Update TAS abd FBV2 for NAS-Bench
This commit is contained in:
@@ -12,8 +12,8 @@ __all__ = ['change_key', 'get_cell_based_tiny_net', 'get_search_spaces', 'get_ci
|
||||
|
||||
# useful modules
|
||||
from config_utils import dict2config
|
||||
from .SharedUtils import change_key
|
||||
from .cell_searchs import CellStructure, CellArchitectures
|
||||
from models.SharedUtils import change_key
|
||||
from models.cell_searchs import CellStructure, CellArchitectures
|
||||
|
||||
|
||||
# Cell-based NAS Models
|
||||
@@ -27,6 +27,10 @@ def get_cell_based_tiny_net(config):
|
||||
return nas_super_nets[config.name](config.C, config.N, config.max_nodes, config.num_classes, config.space, config.affine, config.track_running_stats)
|
||||
except:
|
||||
return nas_super_nets[config.name](config.C, config.N, config.max_nodes, config.num_classes, config.space)
|
||||
elif super_type == 'search-shape':
|
||||
from .shape_searchs import GenericNAS301Model
|
||||
genotype = CellStructure.str2structure(config.genotype)
|
||||
return GenericNAS301Model(config.candidate_Cs, config.max_num_Cs, genotype, config.num_classes, config.affine, config.track_running_stats)
|
||||
elif super_type == 'nasnet-super':
|
||||
from .cell_searchs import nasnet_super_nets as nas_super_nets
|
||||
return nas_super_nets[config.name](config.C, config.N, config.steps, config.multiplier, \
|
||||
|
Reference in New Issue
Block a user