Prototype generic nas model.

This commit is contained in:
D-X-Y
2020-07-16 10:34:34 +00:00
parent a99df6dc31
commit 68f9d037eb
7 changed files with 510 additions and 4 deletions

View File

@@ -20,7 +20,7 @@ from .cell_searchs import CellStructure, CellArchitectures
def get_cell_based_tiny_net(config):
if isinstance(config, dict): config = dict2config(config, None) # to support the argument being a dict
super_type = getattr(config, 'super_type', 'basic')
group_names = ['DARTS-V1', 'DARTS-V2', 'GDAS', 'SETN', 'ENAS', 'RANDOM']
group_names = ['DARTS-V1', 'DARTS-V2', 'GDAS', 'SETN', 'ENAS', 'RANDOM', 'generic']
if super_type == 'basic' and config.name in group_names:
from .cell_searchs import nas201_super_nets as nas_super_nets
try: