Add more algorithms
This commit is contained in:
16
lib/nas_infer_model/__init__.py
Normal file
16
lib/nas_infer_model/__init__.py
Normal file
@@ -0,0 +1,16 @@
|
||||
##################################################
|
||||
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019 #
|
||||
##################################################
|
||||
import torch
|
||||
|
||||
def obtain_nas_infer_model(config):
|
||||
if config.arch == 'dxys':
|
||||
from .DXYs import CifarNet, ImageNet, Networks
|
||||
genotype = Networks[config.genotype]
|
||||
if config.dataset == 'cifar':
|
||||
return CifarNet(config.ichannel, config.layers, config.stem_multi, config.auxiliary, genotype, config.class_num)
|
||||
elif config.dataset == 'imagenet':
|
||||
return ImageNet(config.ichannel, config.layers, config.auxiliary, genotype, config.class_num)
|
||||
else: raise ValueError('invalid dataset : {:}'.format(config.dataset))
|
||||
else:
|
||||
raise ValueError('invalid nas arch type : {:}'.format(config.arch))
|
Reference in New Issue
Block a user