update NAS-Bench-102 baselines
This commit is contained in:
@@ -15,6 +15,7 @@ from procedures import prepare_seed, prepare_logger, save_checkpoint, copy_che
|
||||
from utils import get_model_infos, obtain_accuracy
|
||||
from log_utils import AverageMeter, time_string, convert_secs2time
|
||||
from models import get_cell_based_tiny_net, get_search_spaces
|
||||
from nas_102_api import NASBench102API as API
|
||||
|
||||
|
||||
def train_shared_cnn(xloader, shared_cnn, controller, criterion, scheduler, optimizer, epoch_str, print_freq, logger):
|
||||
@@ -224,6 +225,12 @@ def main(xargs):
|
||||
#flop, param = get_model_infos(shared_cnn, xshape)
|
||||
#logger.log('{:}'.format(shared_cnn))
|
||||
#logger.log('FLOP = {:.2f} M, Params = {:.2f} MB'.format(flop, param))
|
||||
logger.log('search-space : {:}'.format(search_space))
|
||||
if xargs.arch_nas_dataset is None:
|
||||
api = None
|
||||
else:
|
||||
api = API(xargs.arch_nas_dataset)
|
||||
logger.log('{:} create API = {:} done'.format(time_string(), api))
|
||||
shared_cnn, controller, criterion = torch.nn.DataParallel(shared_cnn).cuda(), controller.cuda(), criterion.cuda()
|
||||
|
||||
last_info, model_base_path, model_best_path = logger.path('info'), logger.path('model'), logger.path('best')
|
||||
@@ -247,7 +254,7 @@ def main(xargs):
|
||||
start_epoch, valid_accuracies, genotypes, baseline = 0, {'best': -1}, {}, None
|
||||
|
||||
# start training
|
||||
start_time, epoch_time, total_epoch = time.time(), AverageMeter(), config.epochs + config.warmup
|
||||
start_time, search_time, epoch_time, total_epoch = time.time(), AverageMeter(), AverageMeter(), config.epochs + config.warmup
|
||||
for epoch in range(start_epoch, total_epoch):
|
||||
w_scheduler.update(epoch, 0.0)
|
||||
need_time = 'Time Left: {:}'.format( convert_secs2time(epoch_time.val * (total_epoch-epoch), True) )
|
||||
@@ -263,7 +270,8 @@ def main(xargs):
|
||||
'ctl_entropy_w': xargs.controller_entropy_weight,
|
||||
'ctl_bl_dec' : xargs.controller_bl_dec}, None), \
|
||||
epoch_str, xargs.print_freq, logger)
|
||||
logger.log('[{:}] controller : loss={:.2f}, accuracy={:.2f}%, baseline={:.2f}, reward={:.2f}, current-baseline={:.4f}'.format(epoch_str, ctl_loss, ctl_acc, ctl_baseline, ctl_reward, baseline))
|
||||
search_time.update(time.time() - start_time)
|
||||
logger.log('[{:}] controller : loss={:.2f}, accuracy={:.2f}%, baseline={:.2f}, reward={:.2f}, current-baseline={:.4f}, time-cost={:.1f} s'.format(epoch_str, ctl_loss, ctl_acc, ctl_baseline, ctl_reward, baseline, search_time.sum))
|
||||
best_arch, _ = get_best_arch(controller, shared_cnn, valid_loader)
|
||||
shared_cnn.module.update_arch(best_arch)
|
||||
_, best_valid_acc, _ = valid_func(valid_loader, shared_cnn, criterion)
|
||||
@@ -298,6 +306,7 @@ def main(xargs):
|
||||
if find_best:
|
||||
logger.log('<<<--->>> The {:}-th epoch : find the highest validation accuracy : {:.2f}%.'.format(epoch_str, best_valid_acc))
|
||||
copy_checkpoint(model_base_path, model_best_path, logger)
|
||||
if api is not None: logger.log('{:}'.format(api.query_by_arch( genotypes[epoch] )))
|
||||
# measure elapsed time
|
||||
epoch_time.update(time.time() - start_time)
|
||||
start_time = time.time()
|
||||
@@ -306,27 +315,15 @@ def main(xargs):
|
||||
logger.log('During searching, the best architecture is {:}'.format(genotypes['best']))
|
||||
logger.log('Its accuracy is {:.2f}%'.format(valid_accuracies['best']))
|
||||
logger.log('Randomly select {:} architectures and select the best.'.format(xargs.controller_num_samples))
|
||||
start_time = time.time()
|
||||
final_arch, _ = get_best_arch(controller, shared_cnn, valid_loader, xargs.controller_num_samples)
|
||||
search_time.update(time.time() - start_time)
|
||||
shared_cnn.module.update_arch(final_arch)
|
||||
final_loss, final_top1, final_top5 = valid_func(valid_loader, shared_cnn, criterion)
|
||||
logger.log('The Selected Final Architecture : {:}'.format(final_arch))
|
||||
logger.log('Loss={:.3f}, Accuracy@1={:.2f}%, Accuracy@5={:.2f}%'.format(final_loss, final_top1, final_top5))
|
||||
# check the performance from the architecture dataset
|
||||
#if xargs.arch_nas_dataset is None or not os.path.isfile(xargs.arch_nas_dataset):
|
||||
# logger.log('Can not find the architecture dataset : {:}.'.format(xargs.arch_nas_dataset))
|
||||
#else:
|
||||
# nas_bench = NASBenchmarkAPI(xargs.arch_nas_dataset)
|
||||
# geno = genotypes[total_epoch-1]
|
||||
# logger.log('The last model is {:}'.format(geno))
|
||||
# info = nas_bench.query_by_arch( geno )
|
||||
# if info is None: logger.log('Did not find this architecture : {:}.'.format(geno))
|
||||
# else : logger.log('{:}'.format(info))
|
||||
# logger.log('-'*100)
|
||||
# geno = genotypes['best']
|
||||
# logger.log('The best model is {:}'.format(geno))
|
||||
# info = nas_bench.query_by_arch( geno )
|
||||
# if info is None: logger.log('Did not find this architecture : {:}.'.format(geno))
|
||||
# else : logger.log('{:}'.format(info))
|
||||
logger.log('ENAS : run {:} epochs, cost {:.1f} s, last-geno is {:}.'.format(total_epoch, search_time.sum, final_arch))
|
||||
if api is not None: logger.log('{:}'.format( api.query_by_arch(final_arch) ))
|
||||
logger.close()
|
||||
|
||||
|
||||
|
Reference in New Issue
Block a user