support get_net_config for NAS-Bench-201

This commit is contained in:
D-X-Y
2020-02-02 17:20:38 +11:00
parent 133fd21ecc
commit 25e529f788
3 changed files with 69 additions and 9 deletions

View File

@@ -93,6 +93,8 @@ class NASBench201API(object):
else: arch_index = -1
return arch_index
# Overwrite all information of the 'index'-th architecture in the search space.
# It will load its data from 'archive_root'.
def reload(self, archive_root, index):
assert os.path.isdir(archive_root), 'invalid directory : {:}'.format(archive_root)
xfile_path = os.path.join(archive_root, '{:06d}-FULL.pth'.format(index))
@@ -123,9 +125,18 @@ class NASBench201API(object):
print ('Find this arch-index : {:}, but this arch is not evaluated.'.format(arch_index))
return None
# query information with the training of 12 epochs or 200 epochs
# if dataname is None, return the ArchResults
# This 'query_by_index' function is used to query information with the training of 12 epochs or 200 epochs.
# ------
# If use_12epochs_result=True, we train the model by 12 epochs (see config in configs/nas-benchmark/LESS.config)
# If use_12epochs_result=False, we train the model by 200 epochs (see config in configs/nas-benchmark/CIFAR.config)
# ------
# If dataname is None, return the ArchResults
# else, return a dict with all trials on that dataset (the key is the seed)
# Options are 'cifar10-valid', 'cifar10', 'cifar100', 'ImageNet16-120'.
# -- cifar10-valid : training the model on the CIFAR-10 training set.
# -- cifar10 : training the model on the CIFAR-10 training + validation set.
# -- cifar100 : training the model on the CIFAR-100 training set.
# -- ImageNet16-120 : training the model on the ImageNet16-120 training set.
def query_by_index(self, arch_index, dataname=None, use_12epochs_result=False):
if use_12epochs_result: basestr, arch2infos = '12epochs' , self.arch2infos_less
else : basestr, arch2infos = '200epochs', self.arch2infos_full
@@ -166,12 +177,40 @@ class NASBench201API(object):
assert 0 <= index < len(self.meta_archs), 'invalid index : {:} vs. {:}.'.format(index, len(self.meta_archs))
return copy.deepcopy(self.meta_archs[index])
# obtain the trained weights of the `index`-th architecture on `dataset` with the seed of `seed`
"""
This function is used to obtain the trained weights of the `index`-th architecture on `dataset` with the seed of `seed`
Args [seed]:
-- None : return a dict containing the trained weights of all trials, where each key is a seed and its corresponding value is the weights.
-- a interger : return the weights of a specific trial, whose seed is this interger.
Args [use_12epochs_result]:
-- True : train the model by 12 epochs
-- False : train the model by 200 epochs
"""
def get_net_param(self, index, dataset, seed, use_12epochs_result=False):
if use_12epochs_result: basestr, arch2infos = '12epochs' , self.arch2infos_less
else : basestr, arch2infos = '200epochs', self.arch2infos_full
archresult = arch2infos[index]
return archresult.get_net_param(dataset, seed)
"""
This function is used to obtain the configuration for the `index`-th architecture on `dataset`.
Args [dataset] (4 possible options):
-- cifar10-valid : training the model on the CIFAR-10 training set.
-- cifar10 : training the model on the CIFAR-10 training + validation set.
-- cifar100 : training the model on the CIFAR-100 training set.
-- ImageNet16-120 : training the model on the ImageNet16-120 training set.
This function will return a dict.
========= Some examlpes for using this function:
config = api.get_net_config(128, 'cifar10')
"""
def get_net_config(self, index, dataset):
archresult = self.arch2infos_full[index]
all_results = archresult.query(dataset, None)
if len(all_results) == 0: raise ValueError('can not find one valid trial for the {:}-th architecture on {:}'.format(index, dataset))
for seed, result in all_results.items():
return result.get_config(None)
#print ('SEED [{:}] : {:}'.format(seed, result))
raise ValueError('Impossible to reach here!')
# obtain the cost metric for the `index`-th architecture on a dataset
def get_cost_info(self, index, dataset, use_12epochs_result=False):
@@ -333,6 +372,7 @@ class NASBench201API(object):
class ArchResults(object):
def __init__(self, arch_index, arch_str):
@@ -615,11 +655,16 @@ class ResultsCount(object):
def get_net_param(self):
return self.net_state_dict
# This function is used to obtain the config dict for this architecture.
def get_config(self, str2structure):
#return copy.deepcopy(self.arch_config)
return {'name': 'infer.tiny', 'C': self.arch_config['channel'], \
'N' : self.arch_config['num_cells'], \
'genotype': str2structure(self.arch_config['arch_str']), 'num_classes': self.arch_config['class_num']}
if str2structure is None:
return {'name': 'infer.tiny', 'C': self.arch_config['channel'], \
'N' : self.arch_config['num_cells'], \
'arch_str': self.arch_config['arch_str'], 'num_classes': self.arch_config['class_num']}
else:
return {'name': 'infer.tiny', 'C': self.arch_config['channel'], \
'N' : self.arch_config['num_cells'], \
'genotype': str2structure(self.arch_config['arch_str']), 'num_classes': self.arch_config['class_num']}
def state_dict(self):
_state_dict = {key: value for key, value in self.__dict__.items()}