update README

This commit is contained in:
D-X-Y
2019-11-15 17:40:15 +11:00
parent 0630867505
commit c3672648d7
3 changed files with 116 additions and 3 deletions

View File

@@ -38,12 +38,15 @@ def get_search_spaces(xtype, name):
def get_cifar_models(config):
from .CifarResNet import CifarResNet
from .CifarDenseNet import DenseNet
from .CifarWideResNet import CifarWideResNet
super_type = getattr(config, 'super_type', 'basic')
if super_type == 'basic':
if config.arch == 'resnet':
return CifarResNet(config.module, config.depth, config.class_num, config.zero_init_residual)
elif config.arch == 'densenet':
return DenseNet(config.growthRate, config.depth, config.reduction, config.class_num, config.bottleneck)
elif config.arch == 'wideresnet':
return CifarWideResNet(config.depth, config.wide_factor, config.class_num, config.dropout)
else:
@@ -68,8 +71,13 @@ def get_cifar_models(config):
def get_imagenet_models(config):
super_type = getattr(config, 'super_type', 'basic')
# NAS searched architecture
if super_type.startswith('infer'):
if super_type == 'basic':
from .ImagenetResNet import ResNet
if config.arch == 'resnet':
return ResNet(config.block_name, config.layers, config.deep_stem, config.class_num, config.zero_init_residual, config.groups, config.width_per_group)
else:
raise ValueError('invalid arch : {:}'.format( config.arch ))
elif super_type.startswith('infer'): # NAS searched architecture
assert len(super_type.split('-')) == 2, 'invalid super_type : {:}'.format(super_type)
infer_mode = super_type.split('-')[1]
if infer_mode == 'shape':