update CVPR-2019-GDAS re-train NASNet-search-space searched models
This commit is contained in:
@@ -168,5 +168,15 @@ Networks = {'DARTS_V1': DARTS_V1,
|
||||
'SETN' : SETN,
|
||||
}
|
||||
|
||||
# This function will return a Genotype from a dict.
|
||||
def build_genotype_from_dict(xdict):
|
||||
import pdb; pdb.set_trace()
|
||||
def remove_value(nodes):
|
||||
return [tuple([(x[0], x[1]) for x in node]) for node in nodes]
|
||||
genotype = Genotype(
|
||||
normal=remove_value(xdict['normal']),
|
||||
normal_concat=xdict['normal_concat'],
|
||||
reduce=remove_value(xdict['reduce']),
|
||||
reduce_concat=xdict['reduce_concat'],
|
||||
connectN=None, connects=None
|
||||
)
|
||||
return genotype
|
||||
|
@@ -6,12 +6,22 @@
|
||||
# Currently, this package is used to reproduce the results in GDAS (Searching for A Robust Neural Architecture in Four GPU Hours, CVPR 2019).
|
||||
##################################################
|
||||
|
||||
import torch
|
||||
import os, torch
|
||||
|
||||
def obtain_nas_infer_model(config):
|
||||
def obtain_nas_infer_model(config, extra_model_path=None):
|
||||
|
||||
if config.arch == 'dxys':
|
||||
from .DXYs import CifarNet, ImageNet, Networks
|
||||
genotype = Networks[config.genotype]
|
||||
from .DXYs import build_genotype_from_dict
|
||||
if config.genotype is None:
|
||||
if extra_model_path is not None and not os.path.isfile(extra_model_path):
|
||||
raise ValueError('When genotype in confiig is None, extra_model_path must be set as a path instead of {:}'.format(extra_model_path))
|
||||
xdata = torch.load(extra_model_path)
|
||||
current_epoch = xdata['epoch']
|
||||
genotype_dict = xdata['genotypes'][current_epoch-1]
|
||||
genotype = build_genotype_from_dict(genotype_dict)
|
||||
else:
|
||||
genotype = Networks[config.genotype]
|
||||
if config.dataset == 'cifar':
|
||||
return CifarNet(config.ichannel, config.layers, config.stem_multi, config.auxiliary, genotype, config.class_num)
|
||||
elif config.dataset == 'imagenet':
|
||||
|
Reference in New Issue
Block a user