Support accumulate and reset time function for API
This commit is contained in:
@@ -61,6 +61,25 @@ class NASBenchMetaAPI(metaclass=abc.ABCMeta):
|
||||
def avaliable_hps(self):
|
||||
return list(copy.deepcopy(self._avaliable_hps))
|
||||
|
||||
@property
|
||||
def used_time(self):
|
||||
return self._used_time
|
||||
|
||||
def reset_time(self):
|
||||
self._used_time = 0
|
||||
|
||||
def simulate_train_eval(self, arch, dataset, hp='12'):
|
||||
index = self.query_index_by_arch(arch)
|
||||
all_names = ('cifar10', 'cifar100', 'ImageNet16-120')
|
||||
assert dataset in all_names, 'Invalid dataset name : {:} vs {:}'.format(dataset, all_names)
|
||||
if dataset == 'cifar10':
|
||||
info = self.get_more_info(index, 'cifar10-valid', iepoch=None, hp=hp, is_random=True)
|
||||
else:
|
||||
info = self.get_more_info(index, dataset, iepoch=None, hp=hp, is_random=True)
|
||||
valid_acc, time_cost = info['valid-accuracy'], info['train-all-time'] + info['valid-per-time']
|
||||
self._used_time += time_cost
|
||||
return valid_acc, time_cost, self._used_time
|
||||
|
||||
def random(self):
|
||||
"""Return a random index of all architectures."""
|
||||
return random.randint(0, len(self.meta_archs)-1)
|
||||
|
Reference in New Issue
Block a user