Update the query_by_arch function in API to be compatiable with the submission version of NAS-Bench-201
This commit is contained in:
@@ -53,7 +53,7 @@ def train_and_eval(arch, nas_bench, extra_info, dataname='cifar10-valid', use_01
|
||||
assert arch_index >= 0, 'can not find this arch : {:}'.format(arch)
|
||||
xoinfo = nas_bench.get_more_info(arch_index, 'cifar10-valid', iepoch=None, hp='12')
|
||||
xocost = nas_bench.get_cost_info(arch_index, 'cifar10-valid', hp='200')
|
||||
info = nas_bench.get_more_info(arch_index, dataname, nepoch, hp='200', True) # use the validation accuracy after 25 training epochs, which is used in our ICLR submission (not the camera ready).
|
||||
info = nas_bench.get_more_info(arch_index, dataname, nepoch, hp='200', is_random=True) # use the validation accuracy after 25 training epochs, which is used in our ICLR submission (not the camera ready).
|
||||
cost = nas_bench.get_cost_info(arch_index, dataname, hp='200')
|
||||
# The following codes are used to estimate the time cost.
|
||||
# When we build NAS-Bench-201, architectures are trained on different machines and we can not use that time record.
|
||||
@@ -218,7 +218,7 @@ def main(xargs, nas_bench):
|
||||
best_arch = best_arch.arch
|
||||
logger.log('{:} best arch is {:}'.format(time_string(), best_arch))
|
||||
|
||||
info = nas_bench.query_by_arch( best_arch )
|
||||
info = nas_bench.query_by_arch(best_arch, '200')
|
||||
if info is None: logger.log('Did not find this architecture : {:}.'.format(best_arch))
|
||||
else : logger.log('{:}'.format(info))
|
||||
logger.log('-'*100)
|
||||
|
Reference in New Issue
Block a user