update codes

This commit is contained in:
D-X-Y
2020-01-12 01:42:17 +11:00
parent 654015bf9d
commit 33384a78af
15 changed files with 288 additions and 21 deletions

View File

@@ -13,20 +13,21 @@ from config_utils import dict2config
from .SharedUtils import change_key
from .cell_searchs import CellStructure, CellArchitectures
# Cell-based NAS Models
def get_cell_based_tiny_net(config):
super_type = getattr(config, 'super_type', 'basic')
group_names = ['DARTS-V1', 'DARTS-V2', 'GDAS', 'SETN', 'ENAS', 'RANDOM']
if super_type == 'basic' and config.name in group_names:
from .cell_searchs import nas_super_nets
from .cell_searchs import nas102_super_nets as nas_super_nets
try:
return nas_super_nets[config.name](config.C, config.N, config.max_nodes, config.num_classes, config.space, config.affine, config.track_running_stats)
except:
return nas_super_nets[config.name](config.C, config.N, config.max_nodes, config.num_classes, config.space)
elif super_type == 'l2s-base' and config.name in group_names:
from .l2s_cell_searchs import nas_super_nets
return nas_super_nets[config.name](config.C, config.N, config.max_nodes, config.num_classes, config.space \
,config.n_piece)
elif super_type == 'nasnet-super':
from .cell_searchs import nasnet_super_nets as nas_super_nets
return nas_super_nets[config.name](config.C, config.N, config.steps, config.multiplier, \
config.stem_multiplier, config.num_classes, config.space, config.affine, config.track_running_stats)
elif config.name == 'infer.tiny':
from .cell_infers import TinyNetwork
return TinyNetwork(config.C, config.N, config.genotype, config.num_classes)