update NAS-Bench-102 baselines

This commit is contained in:
D-X-Y
2019-12-25 10:30:50 +11:00
parent 44a0d51449
commit 1d5e8debad
5 changed files with 48 additions and 17 deletions

View File

@@ -82,6 +82,16 @@ def valid_func(xloader, network, criterion):
return arch_losses.avg, arch_top1.avg, arch_top5.avg
def search_find_best(valid_loader, network, criterion, select_num):
best_arch, best_acc = None, -1
for iarch in range(select_num):
arch = network.module.random_genotype( True )
valid_a_loss, valid_a_top1, valid_a_top5 = valid_func(valid_loader, network, criterion)
if best_arch is None or best_acc < valid_a_top1:
best_arch, best_acc = arch, valid_a_top1
return best_arch
def main(xargs):
assert torch.cuda.is_available(), 'CUDA is not available.'
torch.backends.cudnn.enabled = True
@@ -143,6 +153,7 @@ def main(xargs):
last_info = torch.load(last_info)
start_epoch = last_info['epoch']
checkpoint = torch.load(last_info['last_checkpoint'])
genotypes = checkpoint['genotypes']
valid_accuracies = checkpoint['valid_accuracies']
search_model.load_state_dict( checkpoint['search_model'] )
w_scheduler.load_state_dict ( checkpoint['w_scheduler'] )
@@ -150,7 +161,7 @@ def main(xargs):
logger.log("=> loading checkpoint of the last-info '{:}' start with {:}-th epoch.".format(last_info, start_epoch))
else:
logger.log("=> do not find the last-info file : {:}".format(last_info))
start_epoch, valid_accuracies = 0, {'best': -1}
start_epoch, valid_accuracies, genotypes = 0, {'best': -1}, {}
# start training
start_time, search_time, epoch_time, total_epoch = time.time(), AverageMeter(), AverageMeter(), config.epochs + config.warmup
@@ -160,11 +171,14 @@ def main(xargs):
epoch_str = '{:03d}-{:03d}'.format(epoch, total_epoch)
logger.log('\n[Search the {:}-th epoch] {:}, LR={:}'.format(epoch_str, need_time, min(w_scheduler.get_lr())))
# selected_arch = search_find_best(valid_loader, network, criterion, xargs.select_num)
search_w_loss, search_w_top1, search_w_top5 = search_func(search_loader, network, criterion, w_scheduler, w_optimizer, epoch_str, xargs.print_freq, logger)
search_time.update(time.time() - start_time)
logger.log('[{:}] searching : loss={:.2f}, accuracy@1={:.2f}%, accuracy@5={:.2f}%, time-cost={:.1f} s'.format(epoch_str, search_w_loss, search_w_top1, search_w_top5, search_time.sum))
valid_a_loss , valid_a_top1 , valid_a_top5 = valid_func(valid_loader, network, criterion)
logger.log('[{:}] evaluate : loss={:.2f}, accuracy@1={:.2f}%, accuracy@5={:.2f}%'.format(epoch_str, valid_a_loss, valid_a_top1, valid_a_top5))
cur_arch = search_find_best(valid_loader, network, criterion, xargs.select_num)
genotypes[epoch] = cur_arch
# check the best accuracy
valid_accuracies[epoch] = valid_a_top1
if valid_a_top1 > valid_accuracies['best']:
@@ -178,6 +192,7 @@ def main(xargs):
'search_model': search_model.state_dict(),
'w_optimizer' : w_optimizer.state_dict(),
'w_scheduler' : w_scheduler.state_dict(),
'genotypes' : genotypes,
'valid_accuracies' : valid_accuracies},
model_base_path, logger)
last_info = save_checkpoint({
@@ -188,6 +203,7 @@ def main(xargs):
if find_best:
logger.log('<<<--->>> The {:}-th epoch : find the highest validation accuracy : {:.2f}%.'.format(epoch_str, valid_a_top1))
copy_checkpoint(model_base_path, model_best_path, logger)
if api is not None: logger.log('{:}'.format(api.query_by_arch( genotypes[epoch] )))
# measure elapsed time
epoch_time.update(time.time() - start_time)
start_time = time.time()
@@ -202,7 +218,6 @@ def main(xargs):
logger.log('final evaluation [{:02d}/{:02d}] : {:} : accuracy={:.2f}%, loss={:.3f}'.format(iarch, xargs.select_num, arch, valid_a_top1, valid_a_loss))
if best_arch is None or best_acc < valid_a_top1:
best_arch, best_acc = arch, valid_a_top1
search_time.update(time.time() - start_time)
logger.log('RANDOM-NAS finds the best one : {:} with accuracy={:.2f}%, with {:.1f} s.'.format(best_arch, best_acc, search_time.sum))
if api is not None: logger.log('{:}'.format( api.query_by_arch(best_arch) ))