update ENAS

This commit is contained in:
D-X-Y
2019-11-09 01:36:31 +11:00
parent 1da5b49018
commit 34ba8053de
7 changed files with 533 additions and 12 deletions

View File

@@ -16,18 +16,10 @@ from .cell_searchs import CellStructure, CellArchitectures
# Cell-based NAS Models
def get_cell_based_tiny_net(config):
if config.name == 'DARTS-V1':
from .cell_searchs import TinyNetworkDartsV1
return TinyNetworkDartsV1(config.C, config.N, config.max_nodes, config.num_classes, config.space)
elif config.name == 'DARTS-V2':
from .cell_searchs import TinyNetworkDartsV2
return TinyNetworkDartsV2(config.C, config.N, config.max_nodes, config.num_classes, config.space)
elif config.name == 'GDAS':
from .cell_searchs import TinyNetworkGDAS
return TinyNetworkGDAS(config.C, config.N, config.max_nodes, config.num_classes, config.space)
elif config.name == 'SETN':
from .cell_searchs import TinyNetworkSETN
return TinyNetworkSETN(config.C, config.N, config.max_nodes, config.num_classes, config.space)
group_names = ['DARTS-V1', 'DARTS-V2', 'GDAS', 'SETN', 'ENAS']
from .cell_searchs import nas_super_nets
if config.name in group_names:
return nas_super_nets[config.name](config.C, config.N, config.max_nodes, config.num_classes, config.space)
elif config.name == 'infer.tiny':
from .cell_infers import TinyNetwork
return TinyNetwork(config.C, config.N, config.genotype, config.num_classes)