update GDAS
This commit is contained in:
@@ -16,10 +16,15 @@ from .cell_searchs import CellStructure, CellArchitectures
|
||||
|
||||
# Cell-based NAS Models
|
||||
def get_cell_based_tiny_net(config):
|
||||
super_type = getattr(config, 'super_type', 'basic')
|
||||
group_names = ['DARTS-V1', 'DARTS-V2', 'GDAS', 'SETN', 'ENAS', 'RANDOM']
|
||||
from .cell_searchs import nas_super_nets
|
||||
if config.name in group_names:
|
||||
if super_type == 'basic' and config.name in group_names:
|
||||
from .cell_searchs import nas_super_nets
|
||||
return nas_super_nets[config.name](config.C, config.N, config.max_nodes, config.num_classes, config.space)
|
||||
elif super_type == 'l2s-base' and config.name in group_names:
|
||||
from .l2s_cell_searchs import nas_super_nets
|
||||
return nas_super_nets[config.name](config.C, config.N, config.max_nodes, config.num_classes, config.space \
|
||||
,config.n_piece)
|
||||
elif config.name == 'infer.tiny':
|
||||
from .cell_infers import TinyNetwork
|
||||
return TinyNetwork(config.C, config.N, config.genotype, config.num_classes)
|
||||
|
Reference in New Issue
Block a user