Update REA, REINFORCE, and RANDOM

This commit is contained in:
D-X-Y
2020-07-13 10:04:52 +00:00
parent 041a9aa4a3
commit 6dc494be08
12 changed files with 277 additions and 53 deletions

View File

@@ -141,9 +141,12 @@ class NASBench201API(NASBenchMetaAPI):
# `is_random`
# When is_random=True, the performance of a random architecture will be returned
# When is_random=False, the performanceo of all trials will be averaged.
def get_more_info(self, index: int, dataset, iepoch=None, hp='12', is_random=True):
def get_more_info(self, index, dataset, iepoch=None, hp='12', is_random=True):
if self.verbose:
print('Call the get_more_info function with index={:}, dataset={:}, iepoch={:}, hp={:}, and is_random={:}.'.format(index, dataset, iepoch, hp, is_random))
index = self.query_index_by_arch(index) # To avoid the input is a string or an instance of a arch object
if index not in self.arch2infos_dict:
raise ValueError('Did not find {:} from arch2infos_dict.'.format(index))
archresult = self.arch2infos_dict[index][str(hp)]
# if randomly select one trial, select the seed at first
if isinstance(is_random, bool) and is_random:

View File

@@ -131,7 +131,7 @@ class NASBench301API(NASBenchMetaAPI):
print('Call query_info_str_by_arch with arch={:} and hp={:}'.format(arch, hp))
return self._query_info_str_by_arch(arch, hp, print_information)
def get_more_info(self, index: int, dataset: Text, iepoch=None, hp='12', is_random=True):
def get_more_info(self, index, dataset: Text, iepoch=None, hp='12', is_random=True):
"""This function will return the metric for the `index`-th architecture
`dataset` indicates the dataset:
'cifar10-valid' : using the proposed train set of CIFAR-10 as the training set
@@ -151,6 +151,9 @@ class NASBench301API(NASBenchMetaAPI):
"""
if self.verbose:
print('Call the get_more_info function with index={:}, dataset={:}, iepoch={:}, hp={:}, and is_random={:}.'.format(index, dataset, iepoch, hp, is_random))
index = self.query_index_by_arch(index) # To avoid the input is a string or an instance of a arch object
if index not in self.arch2infos_dict:
raise ValueError('Did not find {:} from arch2infos_dict.'.format(index))
archresult = self.arch2infos_dict[index][str(hp)]
# if randomly select one trial, select the seed at first
if isinstance(is_random, bool) and is_random:

View File

@@ -68,7 +68,7 @@ class NASBenchMetaAPI(metaclass=abc.ABCMeta):
def reset_time(self):
self._used_time = 0
def simulate_train_eval(self, arch, dataset, hp='12'):
def simulate_train_eval(self, arch, dataset, hp='12', account_time=True):
index = self.query_index_by_arch(arch)
all_names = ('cifar10', 'cifar100', 'ImageNet16-120')
assert dataset in all_names, 'Invalid dataset name : {:} vs {:}'.format(dataset, all_names)
@@ -77,8 +77,10 @@ class NASBenchMetaAPI(metaclass=abc.ABCMeta):
else:
info = self.get_more_info(index, dataset, iepoch=None, hp=hp, is_random=True)
valid_acc, time_cost = info['valid-accuracy'], info['train-all-time'] + info['valid-per-time']
self._used_time += time_cost
return valid_acc, time_cost, self._used_time
latency = self.get_latency(index, dataset)
if account_time:
self._used_time += time_cost
return valid_acc, latency, time_cost, self._used_time
def random(self):
"""Return a random index of all architectures."""