Update the query_by_arch function in API to be compatiable with the submission version of NAS-Bench-201
This commit is contained in:
@@ -199,7 +199,7 @@ def main(xargs):
|
||||
if find_best:
|
||||
logger.log('<<<--->>> The {:}-th epoch : find the highest validation accuracy : {:.2f}%.'.format(epoch_str, valid_a_top1))
|
||||
copy_checkpoint(model_base_path, model_best_path, logger)
|
||||
if api is not None: logger.log('{:}'.format(api.query_by_arch( genotypes[epoch] )))
|
||||
if api is not None: logger.log('{:}'.format(api.query_by_arch(genotypes[epoch], '200')))
|
||||
# measure elapsed time
|
||||
epoch_time.update(time.time() - start_time)
|
||||
start_time = time.time()
|
||||
@@ -210,7 +210,7 @@ def main(xargs):
|
||||
best_arch, best_acc = search_find_best(valid_loader, network, xargs.select_num)
|
||||
search_time.update(time.time() - start_time)
|
||||
logger.log('RANDOM-NAS finds the best one : {:} with accuracy={:.2f}%, with {:.1f} s.'.format(best_arch, best_acc, search_time.sum))
|
||||
if api is not None: logger.log('{:}'.format( api.query_by_arch(best_arch) ))
|
||||
if api is not None: logger.log('{:}'.format(api.query_by_arch(best_arch, '200')))
|
||||
logger.close()
|
||||
|
||||
|
||||
|
Reference in New Issue
Block a user