update NAS-Bench-102 baselines

This commit is contained in:
D-X-Y
2019-12-24 17:36:47 +11:00
parent af4212b4db
commit 44a0d51449
18 changed files with 105 additions and 110 deletions

View File

@@ -15,6 +15,7 @@ from procedures import prepare_seed, prepare_logger, save_checkpoint, copy_che
from utils import get_model_infos, obtain_accuracy
from log_utils import AverageMeter, time_string, convert_secs2time
from models import get_cell_based_tiny_net, get_search_spaces
from nas_102_api import NASBench102API as API
def search_func(xloader, network, criterion, scheduler, w_optimizer, epoch_str, print_freq, logger):
@@ -130,6 +131,9 @@ def main(xargs):
logger.log('w-optimizer : {:}'.format(w_optimizer))
logger.log('w-scheduler : {:}'.format(w_scheduler))
logger.log('criterion : {:}'.format(criterion))
if xargs.arch_nas_dataset is None: api = None
else : api = API(xargs.arch_nas_dataset)
logger.log('{:} create API = {:} done'.format(time_string(), api))
last_info, model_base_path, model_best_path = logger.path('info'), logger.path('model'), logger.path('best')
network, criterion = torch.nn.DataParallel(search_model).cuda(), criterion.cuda()
@@ -149,7 +153,7 @@ def main(xargs):
start_epoch, valid_accuracies = 0, {'best': -1}
# start training
start_time, epoch_time, total_epoch = time.time(), AverageMeter(), config.epochs + config.warmup
start_time, search_time, epoch_time, total_epoch = time.time(), AverageMeter(), AverageMeter(), config.epochs + config.warmup
for epoch in range(start_epoch, total_epoch):
w_scheduler.update(epoch, 0.0)
need_time = 'Time Left: {:}'.format( convert_secs2time(epoch_time.val * (total_epoch-epoch), True) )
@@ -157,7 +161,8 @@ def main(xargs):
logger.log('\n[Search the {:}-th epoch] {:}, LR={:}'.format(epoch_str, need_time, min(w_scheduler.get_lr())))
search_w_loss, search_w_top1, search_w_top5 = search_func(search_loader, network, criterion, w_scheduler, w_optimizer, epoch_str, xargs.print_freq, logger)
logger.log('[{:}] searching : loss={:.2f}, accuracy@1={:.2f}%, accuracy@5={:.2f}%'.format(epoch_str, search_w_loss, search_w_top1, search_w_top5))
search_time.update(time.time() - start_time)
logger.log('[{:}] searching : loss={:.2f}, accuracy@1={:.2f}%, accuracy@5={:.2f}%, time-cost={:.1f} s'.format(epoch_str, search_w_loss, search_w_top1, search_w_top5, search_time.sum))
valid_a_loss , valid_a_top1 , valid_a_top5 = valid_func(valid_loader, network, criterion)
logger.log('[{:}] evaluate : loss={:.2f}, accuracy@1={:.2f}%, accuracy@5={:.2f}%'.format(epoch_str, valid_a_loss, valid_a_top1, valid_a_top5))
# check the best accuracy
@@ -188,7 +193,8 @@ def main(xargs):
start_time = time.time()
logger.log('\n' + '-'*200)
logger.log('Pre-searching costs {:.1f} s'.format(search_time.sum))
start_time = time.time()
best_arch, best_acc = None, -1
for iarch in range(xargs.select_num):
arch = search_model.random_genotype( True )
@@ -197,24 +203,10 @@ def main(xargs):
if best_arch is None or best_acc < valid_a_top1:
best_arch, best_acc = arch, valid_a_top1
logger.log('Find the best one : {:} with accuracy={:.2f}%'.format(best_arch, best_acc))
logger.log('\n' + '-'*100)
"""
# check the performance from the architecture dataset
if xargs.arch_nas_dataset is None or not os.path.isfile(xargs.arch_nas_dataset):
logger.log('Can not find the architecture dataset : {:}.'.format(xargs.arch_nas_dataset))
else:
nas_bench = TinyNASBenchmarkAPI(xargs.arch_nas_dataset)
geno = best_arch
logger.log('The last model is {:}'.format(geno))
info = nas_bench.query_by_arch( geno )
if info is None: logger.log('Did not find this architecture : {:}.'.format(geno))
else : logger.log('{:}'.format(info))
logger.log('-'*100)
search_time.update(time.time() - start_time)
logger.log('RANDOM-NAS finds the best one : {:} with accuracy={:.2f}%, with {:.1f} s.'.format(best_arch, best_acc, search_time.sum))
if api is not None: logger.log('{:}'.format( api.query_by_arch(best_arch) ))
logger.close()
"""
if __name__ == '__main__':