fix bugs in RANDOM-NAS and BOHB
This commit is contained in:
@@ -104,14 +104,19 @@ class NASBench102API(object):
|
||||
print ('Find this arch-index : {:}, but this arch is not evaluated.'.format(arch_index))
|
||||
return None
|
||||
|
||||
def query_by_index(self, arch_index, dataname, use_12epochs_result=False):
|
||||
# query information with the training of 12 epochs or 200 epochs
|
||||
# if dataname is None, return the ArchResults
|
||||
# else, return a dict with all trials on that dataset (the key is the seed)
|
||||
def query_by_index(self, arch_index, dataname=None, use_12epochs_result=False):
|
||||
if use_12epochs_result: basestr, arch2infos = '12epochs' , self.arch2infos_less
|
||||
else : basestr, arch2infos = '200epochs', self.arch2infos_full
|
||||
assert arch_index in arch2infos, 'arch_index [{:}] does not in arch2info with {:}'.format(arch_index, basestr)
|
||||
archInfo = copy.deepcopy( arch2infos[ arch_index ] )
|
||||
assert dataname in archInfo.get_dataset_names(), 'invalid dataset-name : {:}'.format(dataname)
|
||||
info = archInfo.query(dataname)
|
||||
return info
|
||||
if dataname is None: return archInfo
|
||||
else:
|
||||
assert dataname in archInfo.get_dataset_names(), 'invalid dataset-name : {:}'.format(dataname)
|
||||
info = archInfo.query(dataname)
|
||||
return info
|
||||
|
||||
def query_meta_info_by_index(self, arch_index, use_12epochs_result=False):
|
||||
if use_12epochs_result: basestr, arch2infos = '12epochs' , self.arch2infos_less
|
||||
@@ -266,7 +271,7 @@ class ArchResults(object):
|
||||
def query(self, dataset, seed=None):
|
||||
if seed is None:
|
||||
x_seeds = self.dataset_seed[dataset]
|
||||
return [self.all_results[ (dataset, seed) ] for seed in x_seeds]
|
||||
return {seed: self.all_results[ (dataset, seed) ] for seed in x_seeds}
|
||||
else:
|
||||
return self.all_results[ (dataset, seed) ]
|
||||
|
||||
|
Reference in New Issue
Block a user