support get_net_config for NAS-Bench-201

This commit is contained in:
D-X-Y
2020-02-02 17:20:38 +11:00
parent 133fd21ecc
commit 25e529f788
3 changed files with 69 additions and 9 deletions

View File

@@ -16,6 +16,7 @@ from .cell_searchs import CellStructure, CellArchitectures
# Cell-based NAS Models
def get_cell_based_tiny_net(config):
if isinstance(config, dict): config = dict2config(config, None) # to support the argument being a dict
super_type = getattr(config, 'super_type', 'basic')
group_names = ['DARTS-V1', 'DARTS-V2', 'GDAS', 'SETN', 'ENAS', 'RANDOM']
if super_type == 'basic' and config.name in group_names:
@@ -30,7 +31,12 @@ def get_cell_based_tiny_net(config):
config.stem_multiplier, config.num_classes, config.space, config.affine, config.track_running_stats)
elif config.name == 'infer.tiny':
from .cell_infers import TinyNetwork
return TinyNetwork(config.C, config.N, config.genotype, config.num_classes)
if hasattr(config, 'genotype'):
genotype = config.genotype
elif hasattr(config, 'arch_str'):
genotype = CellStructure.str2structure(config.arch_str)
else: raise ValueError('Can not find genotype from this config : {:}'.format(config))
return TinyNetwork(config.C, config.N, genotype, config.num_classes)
else:
raise ValueError('invalid network name : {:}'.format(config.name))