Upgrade NAS-API to v2.0:
we use an abstract class NASBenchMetaAPI to define the spec of an API; it can be inherited to support different kinds of NAS API, while keep the query interface the same.
This commit is contained in:
@@ -65,7 +65,7 @@ class MyWorker(Worker):
|
||||
assert len(self.seen_archs) > 0
|
||||
best_index, best_acc = -1, None
|
||||
for arch_index in self.seen_archs:
|
||||
info = self._nas_bench.get_more_info(arch_index, self._dataname, None, True, True)
|
||||
info = self._nas_bench.get_more_info(arch_index, self._dataname, None, hp='200', is_random=True)
|
||||
vacc = info['valid-accuracy']
|
||||
if best_acc is None or best_acc < vacc:
|
||||
best_acc = vacc
|
||||
@@ -77,7 +77,7 @@ class MyWorker(Worker):
|
||||
start_time = time.time()
|
||||
structure = self.convert_func( config )
|
||||
arch_index = self._nas_bench.query_index_by_arch( structure )
|
||||
info = self._nas_bench.get_more_info(arch_index, self._dataname, None, True, True)
|
||||
info = self._nas_bench.get_more_info(arch_index, self._dataname, None, hp='200', is_random=True)
|
||||
cur_time = info['train-all-time'] + info['valid-per-time']
|
||||
cur_vacc = info['valid-accuracy']
|
||||
self.real_cost_time += (time.time() - start_time)
|
||||
|
@@ -42,7 +42,7 @@ def train_and_eval(arch, nas_bench, extra_info, dataname='cifar10-valid', use_01
|
||||
if use_012_epoch_training and nas_bench is not None:
|
||||
arch_index = nas_bench.query_index_by_arch( arch )
|
||||
assert arch_index >= 0, 'can not find this arch : {:}'.format(arch)
|
||||
info = nas_bench.get_more_info(arch_index, dataname, None, True)
|
||||
info = nas_bench.get_more_info(arch_index, dataname, iepoch=None, hp='12', is_random=True)
|
||||
valid_acc, time_cost = info['valid-accuracy'], info['train-all-time'] + info['valid-per-time']
|
||||
#_, valid_acc = info.get_metrics('cifar10-valid', 'x-valid' , 25, True) # use the validation accuracy after 25 training epochs
|
||||
elif not use_012_epoch_training and nas_bench is not None:
|
||||
@@ -51,10 +51,10 @@ def train_and_eval(arch, nas_bench, extra_info, dataname='cifar10-valid', use_01
|
||||
# It did return values for cifar100 and ImageNet16-120, but it has some potential issues. (Please email me for more details)
|
||||
arch_index, nepoch = nas_bench.query_index_by_arch( arch ), 25
|
||||
assert arch_index >= 0, 'can not find this arch : {:}'.format(arch)
|
||||
xoinfo = nas_bench.get_more_info(arch_index, 'cifar10-valid', None, True)
|
||||
xocost = nas_bench.get_cost_info(arch_index, 'cifar10-valid', False)
|
||||
info = nas_bench.get_more_info(arch_index, dataname, nepoch, False, True) # use the validation accuracy after 25 training epochs, which is used in our ICLR submission (not the camera ready).
|
||||
cost = nas_bench.get_cost_info(arch_index, dataname, False)
|
||||
xoinfo = nas_bench.get_more_info(arch_index, 'cifar10-valid', iepoch=None, hp='12')
|
||||
xocost = nas_bench.get_cost_info(arch_index, 'cifar10-valid', hp='200')
|
||||
info = nas_bench.get_more_info(arch_index, dataname, nepoch, hp='200', True) # use the validation accuracy after 25 training epochs, which is used in our ICLR submission (not the camera ready).
|
||||
cost = nas_bench.get_cost_info(arch_index, dataname, hp='200')
|
||||
# The following codes are used to estimate the time cost.
|
||||
# When we build NAS-Bench-201, architectures are trained on different machines and we can not use that time record.
|
||||
# When we create checkpoints for converged_LR, we run all experiments on 1080Ti, and thus the time for each architecture can be fairly compared.
|
||||
|
Reference in New Issue
Block a user