This commit is contained in:
D-X-Y
2021-05-26 01:53:44 -07:00
parent 30fb8fad67
commit 299c8a085b
12 changed files with 137 additions and 115 deletions

View File

@@ -8,24 +8,44 @@
import os, torch
def obtain_nas_infer_model(config, extra_model_path=None):
if config.arch == 'dxys':
from .DXYs import CifarNet, ImageNet, Networks
from .DXYs import build_genotype_from_dict
if config.genotype is None:
if extra_model_path is not None and not os.path.isfile(extra_model_path):
raise ValueError('When genotype in confiig is None, extra_model_path must be set as a path instead of {:}'.format(extra_model_path))
xdata = torch.load(extra_model_path)
current_epoch = xdata['epoch']
genotype_dict = xdata['genotypes'][current_epoch-1]
genotype = build_genotype_from_dict(genotype_dict)
if config.arch == "dxys":
from .DXYs import CifarNet, ImageNet, Networks
from .DXYs import build_genotype_from_dict
if config.genotype is None:
if extra_model_path is not None and not os.path.isfile(extra_model_path):
raise ValueError(
"When genotype in confiig is None, extra_model_path must be set as a path instead of {:}".format(
extra_model_path
)
)
xdata = torch.load(extra_model_path)
current_epoch = xdata["epoch"]
genotype_dict = xdata["genotypes"][current_epoch - 1]
genotype = build_genotype_from_dict(genotype_dict)
else:
genotype = Networks[config.genotype]
if config.dataset == "cifar":
return CifarNet(
config.ichannel,
config.layers,
config.stem_multi,
config.auxiliary,
genotype,
config.class_num,
)
elif config.dataset == "imagenet":
return ImageNet(
config.ichannel,
config.layers,
config.auxiliary,
genotype,
config.class_num,
)
else:
raise ValueError("invalid dataset : {:}".format(config.dataset))
else:
genotype = Networks[config.genotype]
if config.dataset == 'cifar':
return CifarNet(config.ichannel, config.layers, config.stem_multi, config.auxiliary, genotype, config.class_num)
elif config.dataset == 'imagenet':
return ImageNet(config.ichannel, config.layers, config.auxiliary, genotype, config.class_num)
else: raise ValueError('invalid dataset : {:}'.format(config.dataset))
else:
raise ValueError('invalid nas arch type : {:}'.format(config.arch))
raise ValueError("invalid nas arch type : {:}".format(config.arch))