update GDAS and SETN
This commit is contained in:
@@ -3,11 +3,36 @@
|
||||
##################################################
|
||||
import torch
|
||||
from os import path as osp
|
||||
# our modules
|
||||
# useful modules
|
||||
from config_utils import dict2config
|
||||
from .SharedUtils import change_key
|
||||
from .clone_weights import init_from_model
|
||||
|
||||
# Cell-based NAS Models
|
||||
def get_cell_based_tiny_net(config):
|
||||
if config.name == 'DARTS-V1':
|
||||
from .cell_searchs import TinyNetworkDartsV1
|
||||
return TinyNetworkDartsV1(config.C, config.N, config.max_nodes, config.num_classes, config.space)
|
||||
elif config.name == 'DARTS-V2':
|
||||
from .cell_searchs import TinyNetworkDartsV2
|
||||
return TinyNetworkDartsV2(config.C, config.N, config.max_nodes, config.num_classes, config.space)
|
||||
elif config.name == 'GDAS':
|
||||
from .cell_searchs import TinyNetworkGDAS
|
||||
return TinyNetworkGDAS(config.C, config.N, config.max_nodes, config.num_classes, config.space)
|
||||
elif config.name == 'SETN':
|
||||
from .cell_searchs import TinyNetworkSETN
|
||||
return TinyNetworkSETN(config.C, config.N, config.max_nodes, config.num_classes, config.space)
|
||||
else:
|
||||
raise ValueError('invalid network name : {:}'.format(config.name))
|
||||
|
||||
# obtain the search space, i.e., a dict mapping the operation name into a python-function for this op
|
||||
def get_search_spaces(xtype, name):
|
||||
if xtype == 'cell':
|
||||
from .cell_operations import SearchSpaceNames
|
||||
return SearchSpaceNames[name]
|
||||
else:
|
||||
raise ValueError('invalid search-space type is {:}'.format(xtype))
|
||||
|
||||
|
||||
def get_cifar_models(config):
|
||||
from .CifarResNet import CifarResNet
|
||||
@@ -22,9 +47,9 @@ def get_cifar_models(config):
|
||||
else:
|
||||
raise ValueError('invalid module type : {:}'.format(config.arch))
|
||||
elif super_type.startswith('infer'):
|
||||
from .infers import InferWidthCifarResNet
|
||||
from .infers import InferDepthCifarResNet
|
||||
from .infers import InferCifarResNet
|
||||
from .shape_infers import InferWidthCifarResNet
|
||||
from .shape_infers import InferDepthCifarResNet
|
||||
from .shape_infers import InferCifarResNet
|
||||
assert len(super_type.split('-')) == 2, 'invalid super_type : {:}'.format(super_type)
|
||||
infer_mode = super_type.split('-')[1]
|
||||
if infer_mode == 'width':
|
||||
@@ -46,8 +71,8 @@ def get_imagenet_models(config):
|
||||
assert len(super_type.split('-')) == 2, 'invalid super_type : {:}'.format(super_type)
|
||||
infer_mode = super_type.split('-')[1]
|
||||
if infer_mode == 'shape':
|
||||
from .infers import InferImagenetResNet
|
||||
from .infers import InferMobileNetV2
|
||||
from .shape_infers import InferImagenetResNet
|
||||
from .shape_infers import InferMobileNetV2
|
||||
if config.arch == 'resnet':
|
||||
return InferImagenetResNet(config.block_name, config.layers, config.xblocks, config.xchannels, config.deep_stem, config.class_num, config.zero_init_residual)
|
||||
elif config.arch == "MobileNetV2":
|
||||
@@ -72,9 +97,9 @@ def obtain_model(config):
|
||||
def obtain_search_model(config):
|
||||
if config.dataset == 'cifar':
|
||||
if config.arch == 'resnet':
|
||||
from .searchs import SearchWidthCifarResNet
|
||||
from .searchs import SearchDepthCifarResNet
|
||||
from .searchs import SearchShapeCifarResNet
|
||||
from .shape_searchs import SearchWidthCifarResNet
|
||||
from .shape_searchs import SearchDepthCifarResNet
|
||||
from .shape_searchs import SearchShapeCifarResNet
|
||||
if config.search_mode == 'width':
|
||||
return SearchWidthCifarResNet(config.module, config.depth, config.class_num)
|
||||
elif config.search_mode == 'depth':
|
||||
@@ -85,7 +110,7 @@ def obtain_search_model(config):
|
||||
else:
|
||||
raise ValueError('invalid arch : {:} for dataset [{:}]'.format(config.arch, config.dataset))
|
||||
elif config.dataset == 'imagenet':
|
||||
from .searchs import SearchShapeImagenetResNet
|
||||
from .shape_searchs import SearchShapeImagenetResNet
|
||||
assert config.search_mode == 'shape', 'invalid search-mode : {:}'.format( config.search_mode )
|
||||
if config.arch == 'resnet':
|
||||
return SearchShapeImagenetResNet(config.block_name, config.layers, config.deep_stem, config.class_num)
|
||||
|
Reference in New Issue
Block a user