fix bugs in RANDOM-NAS and BOHB

This commit is contained in:
D-X-Y
2019-12-29 20:17:26 +11:00
parent 4c144b7437
commit f8f44bfb31
8 changed files with 469 additions and 67 deletions

View File

@@ -104,14 +104,19 @@ class NASBench102API(object):
print ('Find this arch-index : {:}, but this arch is not evaluated.'.format(arch_index))
return None
def query_by_index(self, arch_index, dataname, use_12epochs_result=False):
# query information with the training of 12 epochs or 200 epochs
# if dataname is None, return the ArchResults
# else, return a dict with all trials on that dataset (the key is the seed)
def query_by_index(self, arch_index, dataname=None, use_12epochs_result=False):
if use_12epochs_result: basestr, arch2infos = '12epochs' , self.arch2infos_less
else : basestr, arch2infos = '200epochs', self.arch2infos_full
assert arch_index in arch2infos, 'arch_index [{:}] does not in arch2info with {:}'.format(arch_index, basestr)
archInfo = copy.deepcopy( arch2infos[ arch_index ] )
assert dataname in archInfo.get_dataset_names(), 'invalid dataset-name : {:}'.format(dataname)
info = archInfo.query(dataname)
return info
if dataname is None: return archInfo
else:
assert dataname in archInfo.get_dataset_names(), 'invalid dataset-name : {:}'.format(dataname)
info = archInfo.query(dataname)
return info
def query_meta_info_by_index(self, arch_index, use_12epochs_result=False):
if use_12epochs_result: basestr, arch2infos = '12epochs' , self.arch2infos_less
@@ -266,7 +271,7 @@ class ArchResults(object):
def query(self, dataset, seed=None):
if seed is None:
x_seeds = self.dataset_seed[dataset]
return [self.all_results[ (dataset, seed) ] for seed in x_seeds]
return {seed: self.all_results[ (dataset, seed) ] for seed in x_seeds}
else:
return self.all_results[ (dataset, seed) ]