Fix bugs
This commit is contained in:
@@ -1,5 +1,5 @@
|
||||
# Performance-Aware Template Network for One-Shot Neural Architecture Search
|
||||
from .CifarNet import NetworkCIFAR as CifarNet
|
||||
from .ImageNet import NetworkImageNet as ImageNet
|
||||
from .CifarNet import NetworkCIFAR as CifarNet
|
||||
from .ImageNet import NetworkImageNet as ImageNet
|
||||
from .genotypes import Networks
|
||||
from .genotypes import build_genotype_from_dict
|
||||
|
@@ -8,24 +8,44 @@
|
||||
|
||||
import os, torch
|
||||
|
||||
|
||||
def obtain_nas_infer_model(config, extra_model_path=None):
|
||||
|
||||
if config.arch == 'dxys':
|
||||
from .DXYs import CifarNet, ImageNet, Networks
|
||||
from .DXYs import build_genotype_from_dict
|
||||
if config.genotype is None:
|
||||
if extra_model_path is not None and not os.path.isfile(extra_model_path):
|
||||
raise ValueError('When genotype in confiig is None, extra_model_path must be set as a path instead of {:}'.format(extra_model_path))
|
||||
xdata = torch.load(extra_model_path)
|
||||
current_epoch = xdata['epoch']
|
||||
genotype_dict = xdata['genotypes'][current_epoch-1]
|
||||
genotype = build_genotype_from_dict(genotype_dict)
|
||||
|
||||
if config.arch == "dxys":
|
||||
from .DXYs import CifarNet, ImageNet, Networks
|
||||
from .DXYs import build_genotype_from_dict
|
||||
|
||||
if config.genotype is None:
|
||||
if extra_model_path is not None and not os.path.isfile(extra_model_path):
|
||||
raise ValueError(
|
||||
"When genotype in confiig is None, extra_model_path must be set as a path instead of {:}".format(
|
||||
extra_model_path
|
||||
)
|
||||
)
|
||||
xdata = torch.load(extra_model_path)
|
||||
current_epoch = xdata["epoch"]
|
||||
genotype_dict = xdata["genotypes"][current_epoch - 1]
|
||||
genotype = build_genotype_from_dict(genotype_dict)
|
||||
else:
|
||||
genotype = Networks[config.genotype]
|
||||
if config.dataset == "cifar":
|
||||
return CifarNet(
|
||||
config.ichannel,
|
||||
config.layers,
|
||||
config.stem_multi,
|
||||
config.auxiliary,
|
||||
genotype,
|
||||
config.class_num,
|
||||
)
|
||||
elif config.dataset == "imagenet":
|
||||
return ImageNet(
|
||||
config.ichannel,
|
||||
config.layers,
|
||||
config.auxiliary,
|
||||
genotype,
|
||||
config.class_num,
|
||||
)
|
||||
else:
|
||||
raise ValueError("invalid dataset : {:}".format(config.dataset))
|
||||
else:
|
||||
genotype = Networks[config.genotype]
|
||||
if config.dataset == 'cifar':
|
||||
return CifarNet(config.ichannel, config.layers, config.stem_multi, config.auxiliary, genotype, config.class_num)
|
||||
elif config.dataset == 'imagenet':
|
||||
return ImageNet(config.ichannel, config.layers, config.auxiliary, genotype, config.class_num)
|
||||
else: raise ValueError('invalid dataset : {:}'.format(config.dataset))
|
||||
else:
|
||||
raise ValueError('invalid nas arch type : {:}'.format(config.arch))
|
||||
raise ValueError("invalid nas arch type : {:}".format(config.arch))
|
||||
|
Reference in New Issue
Block a user