This commit is contained in:
D-X-Y
2021-05-26 01:53:44 -07:00
parent 30fb8fad67
commit 299c8a085b
12 changed files with 137 additions and 115 deletions

View File

@@ -1,5 +1,5 @@
# Performance-Aware Template Network for One-Shot Neural Architecture Search
from .CifarNet import NetworkCIFAR as CifarNet
from .ImageNet import NetworkImageNet as ImageNet
from .CifarNet import NetworkCIFAR as CifarNet
from .ImageNet import NetworkImageNet as ImageNet
from .genotypes import Networks
from .genotypes import build_genotype_from_dict

View File

@@ -8,24 +8,44 @@
import os, torch
def obtain_nas_infer_model(config, extra_model_path=None):
if config.arch == 'dxys':
from .DXYs import CifarNet, ImageNet, Networks
from .DXYs import build_genotype_from_dict
if config.genotype is None:
if extra_model_path is not None and not os.path.isfile(extra_model_path):
raise ValueError('When genotype in confiig is None, extra_model_path must be set as a path instead of {:}'.format(extra_model_path))
xdata = torch.load(extra_model_path)
current_epoch = xdata['epoch']
genotype_dict = xdata['genotypes'][current_epoch-1]
genotype = build_genotype_from_dict(genotype_dict)
if config.arch == "dxys":
from .DXYs import CifarNet, ImageNet, Networks
from .DXYs import build_genotype_from_dict
if config.genotype is None:
if extra_model_path is not None and not os.path.isfile(extra_model_path):
raise ValueError(
"When genotype in confiig is None, extra_model_path must be set as a path instead of {:}".format(
extra_model_path
)
)
xdata = torch.load(extra_model_path)
current_epoch = xdata["epoch"]
genotype_dict = xdata["genotypes"][current_epoch - 1]
genotype = build_genotype_from_dict(genotype_dict)
else:
genotype = Networks[config.genotype]
if config.dataset == "cifar":
return CifarNet(
config.ichannel,
config.layers,
config.stem_multi,
config.auxiliary,
genotype,
config.class_num,
)
elif config.dataset == "imagenet":
return ImageNet(
config.ichannel,
config.layers,
config.auxiliary,
genotype,
config.class_num,
)
else:
raise ValueError("invalid dataset : {:}".format(config.dataset))
else:
genotype = Networks[config.genotype]
if config.dataset == 'cifar':
return CifarNet(config.ichannel, config.layers, config.stem_multi, config.auxiliary, genotype, config.class_num)
elif config.dataset == 'imagenet':
return ImageNet(config.ichannel, config.layers, config.auxiliary, genotype, config.class_num)
else: raise ValueError('invalid dataset : {:}'.format(config.dataset))
else:
raise ValueError('invalid nas arch type : {:}'.format(config.arch))
raise ValueError("invalid nas arch type : {:}".format(config.arch))