update CVPR-2019-GDAS re-train NASNet-search-space searched models

This commit is contained in:
D-X-Y
2020-03-06 19:29:07 +11:00
parent 8b6df42f1f
commit 9a83814a46
17 changed files with 278 additions and 21 deletions

View File

@@ -168,5 +168,15 @@ Networks = {'DARTS_V1': DARTS_V1,
'SETN' : SETN,
}
# This function will return a Genotype from a dict.
def build_genotype_from_dict(xdict):
import pdb; pdb.set_trace()
def remove_value(nodes):
return [tuple([(x[0], x[1]) for x in node]) for node in nodes]
genotype = Genotype(
normal=remove_value(xdict['normal']),
normal_concat=xdict['normal_concat'],
reduce=remove_value(xdict['reduce']),
reduce_concat=xdict['reduce_concat'],
connectN=None, connects=None
)
return genotype

View File

@@ -6,12 +6,22 @@
# Currently, this package is used to reproduce the results in GDAS (Searching for A Robust Neural Architecture in Four GPU Hours, CVPR 2019).
##################################################
import torch
import os, torch
def obtain_nas_infer_model(config):
def obtain_nas_infer_model(config, extra_model_path=None):
if config.arch == 'dxys':
from .DXYs import CifarNet, ImageNet, Networks
genotype = Networks[config.genotype]
from .DXYs import build_genotype_from_dict
if config.genotype is None:
if extra_model_path is not None and not os.path.isfile(extra_model_path):
raise ValueError('When genotype in confiig is None, extra_model_path must be set as a path instead of {:}'.format(extra_model_path))
xdata = torch.load(extra_model_path)
current_epoch = xdata['epoch']
genotype_dict = xdata['genotypes'][current_epoch-1]
genotype = build_genotype_from_dict(genotype_dict)
else:
genotype = Networks[config.genotype]
if config.dataset == 'cifar':
return CifarNet(config.ichannel, config.layers, config.stem_multi, config.auxiliary, genotype, config.class_num)
elif config.dataset == 'imagenet':