update README
This commit is contained in:
@@ -38,12 +38,15 @@ def get_search_spaces(xtype, name):
|
||||
|
||||
def get_cifar_models(config):
|
||||
from .CifarResNet import CifarResNet
|
||||
from .CifarDenseNet import DenseNet
|
||||
from .CifarWideResNet import CifarWideResNet
|
||||
|
||||
super_type = getattr(config, 'super_type', 'basic')
|
||||
if super_type == 'basic':
|
||||
if config.arch == 'resnet':
|
||||
return CifarResNet(config.module, config.depth, config.class_num, config.zero_init_residual)
|
||||
elif config.arch == 'densenet':
|
||||
return DenseNet(config.growthRate, config.depth, config.reduction, config.class_num, config.bottleneck)
|
||||
elif config.arch == 'wideresnet':
|
||||
return CifarWideResNet(config.depth, config.wide_factor, config.class_num, config.dropout)
|
||||
else:
|
||||
@@ -68,8 +71,13 @@ def get_cifar_models(config):
|
||||
|
||||
def get_imagenet_models(config):
|
||||
super_type = getattr(config, 'super_type', 'basic')
|
||||
# NAS searched architecture
|
||||
if super_type.startswith('infer'):
|
||||
if super_type == 'basic':
|
||||
from .ImagenetResNet import ResNet
|
||||
if config.arch == 'resnet':
|
||||
return ResNet(config.block_name, config.layers, config.deep_stem, config.class_num, config.zero_init_residual, config.groups, config.width_per_group)
|
||||
else:
|
||||
raise ValueError('invalid arch : {:}'.format( config.arch ))
|
||||
elif super_type.startswith('infer'): # NAS searched architecture
|
||||
assert len(super_type.split('-')) == 2, 'invalid super_type : {:}'.format(super_type)
|
||||
infer_mode = super_type.split('-')[1]
|
||||
if infer_mode == 'shape':
|
||||
|
Reference in New Issue
Block a user