Update REA, REINFORCE, and RANDOM
This commit is contained in:
@@ -131,7 +131,7 @@ class NASBench301API(NASBenchMetaAPI):
|
||||
print('Call query_info_str_by_arch with arch={:} and hp={:}'.format(arch, hp))
|
||||
return self._query_info_str_by_arch(arch, hp, print_information)
|
||||
|
||||
def get_more_info(self, index: int, dataset: Text, iepoch=None, hp='12', is_random=True):
|
||||
def get_more_info(self, index, dataset: Text, iepoch=None, hp='12', is_random=True):
|
||||
"""This function will return the metric for the `index`-th architecture
|
||||
`dataset` indicates the dataset:
|
||||
'cifar10-valid' : using the proposed train set of CIFAR-10 as the training set
|
||||
@@ -151,6 +151,9 @@ class NASBench301API(NASBenchMetaAPI):
|
||||
"""
|
||||
if self.verbose:
|
||||
print('Call the get_more_info function with index={:}, dataset={:}, iepoch={:}, hp={:}, and is_random={:}.'.format(index, dataset, iepoch, hp, is_random))
|
||||
index = self.query_index_by_arch(index) # To avoid the input is a string or an instance of a arch object
|
||||
if index not in self.arch2infos_dict:
|
||||
raise ValueError('Did not find {:} from arch2infos_dict.'.format(index))
|
||||
archresult = self.arch2infos_dict[index][str(hp)]
|
||||
# if randomly select one trial, select the seed at first
|
||||
if isinstance(is_random, bool) and is_random:
|
||||
|
Reference in New Issue
Block a user