Add MobileNetV2

This commit is contained in:
D-X-Y
2020-03-30 16:20:01 -07:00
parent d70b3c528c
commit e29c86d479
3 changed files with 128 additions and 0 deletions

View File

@@ -110,8 +110,11 @@ def get_imagenet_models(config):
super_type = getattr(config, 'super_type', 'basic')
if super_type == 'basic':
from .ImagenetResNet import ResNet
from .ImageNet_MobileNetV2 import MobileNetV2
if config.arch == 'resnet':
return ResNet(config.block_name, config.layers, config.deep_stem, config.class_num, config.zero_init_residual, config.groups, config.width_per_group)
elif config.arch == 'mobilenet_v2':
return MobileNetV2(config.class_num, config.width_multi, config.input_channel, config.last_channel, 'InvertedResidual', config.dropout)
else:
raise ValueError('invalid arch : {:}'.format( config.arch ))
elif super_type.startswith('infer'): # NAS searched architecture