Update the query_by_arch function in API to be compatiable with the submission version of NAS-Bench-201

This commit is contained in:
D-X-Y
2020-07-08 04:46:25 +00:00
parent 4892692622
commit 233a829bd7
11 changed files with 23 additions and 16 deletions

View File

@@ -260,7 +260,7 @@ def main(xargs):
copy_checkpoint(model_base_path, model_best_path, logger)
with torch.no_grad():
logger.log('arch-parameters :\n{:}'.format( nn.functional.softmax(search_model.arch_parameters, dim=-1).cpu() ))
if api is not None: logger.log('{:}'.format(api.query_by_arch( genotypes[epoch] )))
if api is not None: logger.log('{:}'.format(api.query_by_arch(genotypes[epoch], '200')))
# measure elapsed time
epoch_time.update(time.time() - start_time)
start_time = time.time()
@@ -268,7 +268,7 @@ def main(xargs):
logger.log('\n' + '-'*100)
# check the performance from the architecture dataset
logger.log('DARTS-V2 : run {:} epochs, cost {:.1f} s, last-geno is {:}.'.format(total_epoch, search_time.sum, genotypes[total_epoch-1]))
if api is not None: logger.log('{:}'.format( api.query_by_arch(genotypes[total_epoch-1]) ))
if api is not None: logger.log('{:}'.format(api.query_by_arch(genotypes[total_epoch-1]), '200'))
logger.close()