update NAS-Bench
This commit is contained in:
@@ -3,7 +3,7 @@
|
||||
######################################################################################
|
||||
# One-Shot Neural Architecture Search via Self-Evaluated Template Network, ICCV 2019 #
|
||||
######################################################################################
|
||||
import os, sys, time, glob, random, argparse
|
||||
import sys, time, random, argparse
|
||||
import numpy as np
|
||||
from copy import deepcopy
|
||||
import torch
|
||||
@@ -93,8 +93,7 @@ def get_best_arch(xloader, network, n_samples):
|
||||
_, logits = network(inputs)
|
||||
val_top1, val_top5 = obtain_accuracy(logits.cpu().data, targets.data, topk=(1, 5))
|
||||
|
||||
valid_accs.append( val_top1.item() )
|
||||
#print ('--- {:}/{:} : {:} : {:}'.format(i, len(archs), sampled_arch, val_top1))
|
||||
valid_accs.append(val_top1.item())
|
||||
|
||||
best_idx = np.argmax(valid_accs)
|
||||
best_arch, best_valid_acc = archs[best_idx], valid_accs[best_idx]
|
||||
@@ -142,10 +141,13 @@ def main(xargs):
|
||||
logger.log('||||||| {:10s} ||||||| Config={:}'.format(xargs.dataset, config))
|
||||
|
||||
search_space = get_search_spaces('cell', xargs.search_space_name)
|
||||
model_config = dict2config({'name': 'SETN', 'C': xargs.channel, 'N': xargs.num_cells,
|
||||
'max_nodes': xargs.max_nodes, 'num_classes': class_num,
|
||||
'space' : search_space,
|
||||
'affine' : False, 'track_running_stats': bool(xargs.track_running_stats)}, None)
|
||||
if xargs.model_config is None:
|
||||
model_config = dict2config(
|
||||
dict(name='SETN', C=xargs.channel, N=xargs.num_cells, max_nodes=xargs.max_nodes, num_classes=class_num,
|
||||
space=search_space, affine=False, track_running_stats=bool(xargs.track_running_stats)), None)
|
||||
else:
|
||||
model_config = load_config(xargs.model_config, dict(num_classes=class_num, space=search_space, affine=False,
|
||||
track_running_stats=bool(xargs.track_running_stats)), None)
|
||||
logger.log('search space : {:}'.format(search_space))
|
||||
search_model = get_cell_based_tiny_net(model_config)
|
||||
|
||||
@@ -156,7 +158,6 @@ def main(xargs):
|
||||
logger.log('w-scheduler : {:}'.format(w_scheduler))
|
||||
logger.log('criterion : {:}'.format(criterion))
|
||||
flop, param = get_model_infos(search_model, xshape)
|
||||
#logger.log('{:}'.format(search_model))
|
||||
logger.log('FLOP = {:.2f} M, Params = {:.2f} MB'.format(flop, param))
|
||||
logger.log('search-space : {:}'.format(search_space))
|
||||
if xargs.arch_nas_dataset is None:
|
||||
@@ -233,7 +234,7 @@ def main(xargs):
|
||||
'last_checkpoint': save_path,
|
||||
}, logger.path('info'), logger)
|
||||
with torch.no_grad():
|
||||
logger.log('arch-parameters :\n{:}'.format( nn.functional.softmax(search_model.arch_parameters, dim=-1).cpu() ))
|
||||
logger.log('{:}'.format(search_model.show_alphas()))
|
||||
if api is not None: logger.log('{:}'.format(api.query_by_arch( genotypes[epoch] )))
|
||||
# measure elapsed time
|
||||
epoch_time.update(time.time() - start_time)
|
||||
|
Reference in New Issue
Block a user