update NAS-Bench

This commit is contained in:
D-X-Y
2020-03-09 19:38:00 +11:00
parent 9a83814a46
commit e59eb804cb
35 changed files with 693 additions and 64 deletions

View File

@@ -3,7 +3,7 @@
######################################################################################
# One-Shot Neural Architecture Search via Self-Evaluated Template Network, ICCV 2019 #
######################################################################################
import os, sys, time, glob, random, argparse
import sys, time, random, argparse
import numpy as np
from copy import deepcopy
import torch
@@ -93,8 +93,7 @@ def get_best_arch(xloader, network, n_samples):
_, logits = network(inputs)
val_top1, val_top5 = obtain_accuracy(logits.cpu().data, targets.data, topk=(1, 5))
valid_accs.append( val_top1.item() )
#print ('--- {:}/{:} : {:} : {:}'.format(i, len(archs), sampled_arch, val_top1))
valid_accs.append(val_top1.item())
best_idx = np.argmax(valid_accs)
best_arch, best_valid_acc = archs[best_idx], valid_accs[best_idx]
@@ -142,10 +141,13 @@ def main(xargs):
logger.log('||||||| {:10s} ||||||| Config={:}'.format(xargs.dataset, config))
search_space = get_search_spaces('cell', xargs.search_space_name)
model_config = dict2config({'name': 'SETN', 'C': xargs.channel, 'N': xargs.num_cells,
'max_nodes': xargs.max_nodes, 'num_classes': class_num,
'space' : search_space,
'affine' : False, 'track_running_stats': bool(xargs.track_running_stats)}, None)
if xargs.model_config is None:
model_config = dict2config(
dict(name='SETN', C=xargs.channel, N=xargs.num_cells, max_nodes=xargs.max_nodes, num_classes=class_num,
space=search_space, affine=False, track_running_stats=bool(xargs.track_running_stats)), None)
else:
model_config = load_config(xargs.model_config, dict(num_classes=class_num, space=search_space, affine=False,
track_running_stats=bool(xargs.track_running_stats)), None)
logger.log('search space : {:}'.format(search_space))
search_model = get_cell_based_tiny_net(model_config)
@@ -156,7 +158,6 @@ def main(xargs):
logger.log('w-scheduler : {:}'.format(w_scheduler))
logger.log('criterion : {:}'.format(criterion))
flop, param = get_model_infos(search_model, xshape)
#logger.log('{:}'.format(search_model))
logger.log('FLOP = {:.2f} M, Params = {:.2f} MB'.format(flop, param))
logger.log('search-space : {:}'.format(search_space))
if xargs.arch_nas_dataset is None:
@@ -233,7 +234,7 @@ def main(xargs):
'last_checkpoint': save_path,
}, logger.path('info'), logger)
with torch.no_grad():
logger.log('arch-parameters :\n{:}'.format( nn.functional.softmax(search_model.arch_parameters, dim=-1).cpu() ))
logger.log('{:}'.format(search_model.show_alphas()))
if api is not None: logger.log('{:}'.format(api.query_by_arch( genotypes[epoch] )))
# measure elapsed time
epoch_time.update(time.time() - start_time)