beta-0.1
This commit is contained in:
@@ -3,10 +3,16 @@
|
||||
##################################################
|
||||
import torch
|
||||
from os import path as osp
|
||||
|
||||
__all__ = ['change_key', 'get_cell_based_tiny_net', 'get_search_spaces', 'get_cifar_models', 'get_imagenet_models', \
|
||||
'obtain_model', 'obtain_search_model', 'load_net_from_checkpoint', \
|
||||
'CellStructure', 'CellArchitectures'
|
||||
]
|
||||
|
||||
# useful modules
|
||||
from config_utils import dict2config
|
||||
from .SharedUtils import change_key
|
||||
from .clone_weights import init_from_model
|
||||
from .cell_searchs import CellStructure, CellArchitectures
|
||||
|
||||
# Cell-based NAS Models
|
||||
def get_cell_based_tiny_net(config):
|
||||
@@ -22,9 +28,13 @@ def get_cell_based_tiny_net(config):
|
||||
elif config.name == 'SETN':
|
||||
from .cell_searchs import TinyNetworkSETN
|
||||
return TinyNetworkSETN(config.C, config.N, config.max_nodes, config.num_classes, config.space)
|
||||
elif config.name == 'infer.tiny':
|
||||
from .cell_infers import TinyNetwork
|
||||
return TinyNetwork(config.C, config.N, config.genotype, config.num_classes)
|
||||
else:
|
||||
raise ValueError('invalid network name : {:}'.format(config.name))
|
||||
|
||||
|
||||
# obtain the search space, i.e., a dict mapping the operation name into a python-function for this op
|
||||
def get_search_spaces(xtype, name):
|
||||
if xtype == 'cell':
|
||||
|
Reference in New Issue
Block a user