update TF models (beta version)
This commit is contained in:
@@ -1,7 +1,6 @@
|
||||
##################################################
|
||||
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019 #
|
||||
##################################################
|
||||
import torch
|
||||
from os import path as osp
|
||||
|
||||
__all__ = ['change_key', 'get_cell_based_tiny_net', 'get_search_spaces', 'get_cifar_models', 'get_imagenet_models', \
|
||||
@@ -126,6 +125,11 @@ def obtain_search_model(config):
|
||||
elif config.search_mode == 'shape':
|
||||
return SearchShapeCifarResNet(config.module, config.depth, config.class_num)
|
||||
else: raise ValueError('invalid search mode : {:}'.format(config.search_mode))
|
||||
elif config.arch == 'simres':
|
||||
from .shape_searchs import SearchWidthSimResNet
|
||||
if config.search_mode == 'width':
|
||||
return SearchWidthSimResNet(config.depth, config.class_num)
|
||||
else: raise ValueError('invalid search mode : {:}'.format(config.search_mode))
|
||||
else:
|
||||
raise ValueError('invalid arch : {:} for dataset [{:}]'.format(config.arch, config.dataset))
|
||||
elif config.dataset == 'imagenet':
|
||||
@@ -140,6 +144,7 @@ def obtain_search_model(config):
|
||||
|
||||
|
||||
def load_net_from_checkpoint(checkpoint):
|
||||
import torch
|
||||
assert osp.isfile(checkpoint), 'checkpoint {:} does not exist'.format(checkpoint)
|
||||
checkpoint = torch.load(checkpoint)
|
||||
model_config = dict2config(checkpoint['model-config'], None)
|
||||
|
Reference in New Issue
Block a user