update TF models (beta version)

This commit is contained in:
D-X-Y
2020-01-05 22:19:38 +11:00
parent e6ca3628ce
commit 5ac5060a33
18 changed files with 1253 additions and 44 deletions

View File

@@ -1,7 +1,6 @@
##################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019 #
##################################################
import torch
from os import path as osp
__all__ = ['change_key', 'get_cell_based_tiny_net', 'get_search_spaces', 'get_cifar_models', 'get_imagenet_models', \
@@ -126,6 +125,11 @@ def obtain_search_model(config):
elif config.search_mode == 'shape':
return SearchShapeCifarResNet(config.module, config.depth, config.class_num)
else: raise ValueError('invalid search mode : {:}'.format(config.search_mode))
elif config.arch == 'simres':
from .shape_searchs import SearchWidthSimResNet
if config.search_mode == 'width':
return SearchWidthSimResNet(config.depth, config.class_num)
else: raise ValueError('invalid search mode : {:}'.format(config.search_mode))
else:
raise ValueError('invalid arch : {:} for dataset [{:}]'.format(config.arch, config.dataset))
elif config.dataset == 'imagenet':
@@ -140,6 +144,7 @@ def obtain_search_model(config):
def load_net_from_checkpoint(checkpoint):
import torch
assert osp.isfile(checkpoint), 'checkpoint {:} does not exist'.format(checkpoint)
checkpoint = torch.load(checkpoint)
model_config = dict2config(checkpoint['model-config'], None)