rm PD ; update NAS-Bench-102 baselines
This commit is contained in:
@@ -17,6 +17,7 @@ from procedures import prepare_seed, prepare_logger, save_checkpoint, copy_che
|
||||
from utils import get_model_infos, obtain_accuracy
|
||||
from log_utils import AverageMeter, time_string, convert_secs2time
|
||||
from models import get_cell_based_tiny_net, get_search_spaces
|
||||
from nas_102_api import NASBench102API as API
|
||||
|
||||
|
||||
def search_func(xloader, network, criterion, scheduler, w_optimizer, a_optimizer, epoch_str, print_freq, logger):
|
||||
@@ -144,6 +145,11 @@ def main(xargs):
|
||||
flop, param = get_model_infos(search_model, xshape)
|
||||
#logger.log('{:}'.format(search_model))
|
||||
logger.log('FLOP = {:.2f} M, Params = {:.2f} MB'.format(flop, param))
|
||||
if xargs.arch_nas_dataset is None:
|
||||
api = None
|
||||
else:
|
||||
api = API(xargs.arch_nas_dataset)
|
||||
logger.log('{:} create API = {:} done'.format(time_string(), api))
|
||||
|
||||
last_info, model_base_path, model_best_path = logger.path('info'), logger.path('model'), logger.path('best')
|
||||
network, criterion = torch.nn.DataParallel(search_model).cuda(), criterion.cuda()
|
||||
@@ -165,7 +171,7 @@ def main(xargs):
|
||||
start_epoch, valid_accuracies, genotypes = 0, {'best': -1}, {}
|
||||
|
||||
# start training
|
||||
start_time, epoch_time, total_epoch = time.time(), AverageMeter(), config.epochs + config.warmup
|
||||
start_time, search_time, epoch_time, total_epoch = time.time(), AverageMeter(), AverageMeter(), config.epochs + config.warmup
|
||||
for epoch in range(start_epoch, total_epoch):
|
||||
w_scheduler.update(epoch, 0.0)
|
||||
need_time = 'Time Left: {:}'.format( convert_secs2time(epoch_time.val * (total_epoch-epoch), True) )
|
||||
@@ -173,7 +179,8 @@ def main(xargs):
|
||||
logger.log('\n[Search the {:}-th epoch] {:}, LR={:}'.format(epoch_str, need_time, min(w_scheduler.get_lr())))
|
||||
|
||||
search_w_loss, search_w_top1, search_w_top5 = search_func(search_loader, network, criterion, w_scheduler, w_optimizer, a_optimizer, epoch_str, xargs.print_freq, logger)
|
||||
logger.log('[{:}] searching : loss={:.2f}, accuracy@1={:.2f}%, accuracy@5={:.2f}%'.format(epoch_str, search_w_loss, search_w_top1, search_w_top5))
|
||||
search_time.update(time.time() - start_time)
|
||||
logger.log('[{:}] searching : loss={:.2f}, accuracy@1={:.2f}%, accuracy@5={:.2f}%, time-cost={:.1f} s'.format(epoch_str, search_w_loss, search_w_top1, search_w_top5, search_time.sum))
|
||||
valid_a_loss , valid_a_top1 , valid_a_top5 = valid_func(valid_loader, network, criterion)
|
||||
logger.log('[{:}] evaluate : loss={:.2f}, accuracy@1={:.2f}%, accuracy@5={:.2f}%'.format(epoch_str, valid_a_loss, valid_a_top1, valid_a_top5))
|
||||
# check the best accuracy
|
||||
@@ -204,6 +211,8 @@ def main(xargs):
|
||||
if find_best:
|
||||
logger.log('<<<--->>> The {:}-th epoch : find the highest validation accuracy : {:.2f}%.'.format(epoch_str, valid_a_top1))
|
||||
copy_checkpoint(model_base_path, model_best_path, logger)
|
||||
if api is not None:
|
||||
logger.log('{:}'.format(api.query_by_arch( genotypes[epoch] )))
|
||||
with torch.no_grad():
|
||||
logger.log('arch-parameters :\n{:}'.format( nn.functional.softmax(search_model.arch_parameters, dim=-1).cpu() ))
|
||||
# measure elapsed time
|
||||
@@ -211,22 +220,8 @@ def main(xargs):
|
||||
start_time = time.time()
|
||||
|
||||
logger.log('\n' + '-'*100)
|
||||
# check the performance from the architecture dataset
|
||||
#if xargs.arch_nas_dataset is None or not os.path.isfile(xargs.arch_nas_dataset):
|
||||
# logger.log('Can not find the architecture dataset : {:}.'.format(xargs.arch_nas_dataset))
|
||||
#else:
|
||||
# nas_bench = NASBenchmarkAPI(xargs.arch_nas_dataset)
|
||||
# geno = genotypes[total_epoch-1]
|
||||
# logger.log('The last model is {:}'.format(geno))
|
||||
# info = nas_bench.query_by_arch( geno )
|
||||
# if info is None: logger.log('Did not find this architecture : {:}.'.format(geno))
|
||||
# else : logger.log('{:}'.format(info))
|
||||
# logger.log('-'*100)
|
||||
# geno = genotypes['best']
|
||||
# logger.log('The best model is {:}'.format(geno))
|
||||
# info = nas_bench.query_by_arch( geno )
|
||||
# if info is None: logger.log('Did not find this architecture : {:}.'.format(geno))
|
||||
# else : logger.log('{:}'.format(info))
|
||||
logger.log('DARTS-V1 : run {:} epochs, cost {:.1f} s, last-geno is {:}.'.format(total_epoch, search_time.sum, genotypes[total_epoch-1]))
|
||||
if api is not None: logger.log('{:}'.format( api.query_by_arch(genotypes[total_epoch-1]) ))
|
||||
logger.close()
|
||||
|
||||
|
||||
|
Reference in New Issue
Block a user