update NAS-Bench-102 baselines / support track_running_stats
This commit is contained in:
@@ -135,6 +135,7 @@ def main(xargs):
|
||||
'max_nodes': xargs.max_nodes, 'num_classes': class_num,
|
||||
'space' : search_space}, None)
|
||||
search_model = get_cell_based_tiny_net(model_config)
|
||||
logger.log('search-model :\n{:}'.format(search_model))
|
||||
|
||||
w_optimizer, w_scheduler, criterion = get_optim_scheduler(search_model.get_weights(), config)
|
||||
a_optimizer = torch.optim.Adam(search_model.get_alphas(), lr=xargs.arch_learning_rate, betas=(0.5, 0.999), weight_decay=xargs.arch_weight_decay)
|
||||
@@ -211,10 +212,9 @@ def main(xargs):
|
||||
if find_best:
|
||||
logger.log('<<<--->>> The {:}-th epoch : find the highest validation accuracy : {:.2f}%.'.format(epoch_str, valid_a_top1))
|
||||
copy_checkpoint(model_base_path, model_best_path, logger)
|
||||
if api is not None:
|
||||
logger.log('{:}'.format(api.query_by_arch( genotypes[epoch] )))
|
||||
with torch.no_grad():
|
||||
logger.log('arch-parameters :\n{:}'.format( nn.functional.softmax(search_model.arch_parameters, dim=-1).cpu() ))
|
||||
if api is not None: logger.log('{:}'.format(api.query_by_arch( genotypes[epoch] )))
|
||||
# measure elapsed time
|
||||
epoch_time.update(time.time() - start_time)
|
||||
start_time = time.time()
|
||||
|
Reference in New Issue
Block a user