update README

This commit is contained in:
D-X-Y
2019-09-28 20:18:18 +10:00
parent 180702ab8e
commit f8f3f382e0
18 changed files with 9 additions and 779 deletions

View File

@@ -11,15 +11,12 @@ from .clone_weights import init_from_model
def get_cifar_models(config):
from .CifarResNet import CifarResNet
from .CifarDenseNet import DenseNet
from .CifarWideResNet import CifarWideResNet
super_type = getattr(config, 'super_type', 'basic')
if super_type == 'basic':
if config.arch == 'resnet':
return CifarResNet(config.module, config.depth, config.class_num, config.zero_init_residual)
elif config.arch == 'densenet':
return DenseNet(config.growthRate, config.depth, config.reduction, config.class_num, config.bottleneck)
elif config.arch == 'wideresnet':
return CifarWideResNet(config.depth, config.wide_factor, config.class_num, config.dropout)
else:
@@ -44,10 +41,8 @@ def get_cifar_models(config):
def get_imagenet_models(config):
super_type = getattr(config, 'super_type', 'basic')
if super_type == 'basic':
return get_imagenet_models_basic(config)
# NAS searched architecture
elif super_type.startswith('infer'):
if super_type.startswith('infer'):
assert len(super_type.split('-')) == 2, 'invalid super_type : {:}'.format(super_type)
infer_mode = super_type.split('-')[1]
if infer_mode == 'shape':
@@ -65,20 +60,6 @@ def get_imagenet_models(config):
raise ValueError('invalid super-type : {:}'.format(super_type))
def get_imagenet_models_basic(config):
from .ImagenetResNet import ResNet
from .MobileNet import MobileNetV2
from .ShuffleNetV2 import ShuffleNetV2
if config.arch == 'resnet':
return ResNet(config.block_name, config.layers, config.deep_stem, config.class_num, config.zero_init_residual, config.groups, config.width_per_group)
elif config.arch == 'MobileNetV2':
return MobileNetV2(config.class_num, config.width_mult, config.input_channel, config.last_channel, config.block_name, config.dropout)
elif config.arch == 'ShuffleNetV2':
return ShuffleNetV2(config.class_num, config.stages)
else:
raise ValueError('invalid arch : {:}'.format( config.arch ))
def obtain_model(config):
if config.dataset == 'cifar':
return get_cifar_models(config)