Upgrade NAS-API to v2.0:

we use an abstract class NASBenchMetaAPI to define the spec of an API; it can be inherited to support different kinds of NAS API, while keep the query interface the same.
This commit is contained in:
D-X-Y
2020-06-30 09:05:38 +00:00
parent 91ee265bd2
commit 6effb6f127
23 changed files with 1888 additions and 944 deletions

View File

@@ -42,7 +42,7 @@ def train_and_eval(arch, nas_bench, extra_info, dataname='cifar10-valid', use_01
if use_012_epoch_training and nas_bench is not None:
arch_index = nas_bench.query_index_by_arch( arch )
assert arch_index >= 0, 'can not find this arch : {:}'.format(arch)
info = nas_bench.get_more_info(arch_index, dataname, None, True)
info = nas_bench.get_more_info(arch_index, dataname, iepoch=None, hp='12', is_random=True)
valid_acc, time_cost = info['valid-accuracy'], info['train-all-time'] + info['valid-per-time']
#_, valid_acc = info.get_metrics('cifar10-valid', 'x-valid' , 25, True) # use the validation accuracy after 25 training epochs
elif not use_012_epoch_training and nas_bench is not None:
@@ -51,10 +51,10 @@ def train_and_eval(arch, nas_bench, extra_info, dataname='cifar10-valid', use_01
# It did return values for cifar100 and ImageNet16-120, but it has some potential issues. (Please email me for more details)
arch_index, nepoch = nas_bench.query_index_by_arch( arch ), 25
assert arch_index >= 0, 'can not find this arch : {:}'.format(arch)
xoinfo = nas_bench.get_more_info(arch_index, 'cifar10-valid', None, True)
xocost = nas_bench.get_cost_info(arch_index, 'cifar10-valid', False)
info = nas_bench.get_more_info(arch_index, dataname, nepoch, False, True) # use the validation accuracy after 25 training epochs, which is used in our ICLR submission (not the camera ready).
cost = nas_bench.get_cost_info(arch_index, dataname, False)
xoinfo = nas_bench.get_more_info(arch_index, 'cifar10-valid', iepoch=None, hp='12')
xocost = nas_bench.get_cost_info(arch_index, 'cifar10-valid', hp='200')
info = nas_bench.get_more_info(arch_index, dataname, nepoch, hp='200', True) # use the validation accuracy after 25 training epochs, which is used in our ICLR submission (not the camera ready).
cost = nas_bench.get_cost_info(arch_index, dataname, hp='200')
# The following codes are used to estimate the time cost.
# When we build NAS-Bench-201, architectures are trained on different machines and we can not use that time record.
# When we create checkpoints for converged_LR, we run all experiments on 1080Ti, and thus the time for each architecture can be fairly compared.