update NAS-Bench-102 baselines
This commit is contained in:
@@ -17,6 +17,7 @@ from procedures import prepare_seed, prepare_logger, save_checkpoint, copy_che
|
||||
from utils import get_model_infos, obtain_accuracy
|
||||
from log_utils import AverageMeter, time_string, convert_secs2time
|
||||
from models import get_cell_based_tiny_net, get_search_spaces
|
||||
from nas_102_api import NASBench102API as API
|
||||
|
||||
|
||||
def search_func(xloader, network, criterion, scheduler, w_optimizer, a_optimizer, epoch_str, print_freq, logger):
|
||||
@@ -162,7 +163,8 @@ def main(xargs):
|
||||
search_space = get_search_spaces('cell', xargs.search_space_name)
|
||||
model_config = dict2config({'name': 'SETN', 'C': xargs.channel, 'N': xargs.num_cells,
|
||||
'max_nodes': xargs.max_nodes, 'num_classes': class_num,
|
||||
'space' : search_space}, None)
|
||||
'space' : search_space,
|
||||
'affine' : False, 'track_running_stats': bool(xargs.track_running_stats)}, None)
|
||||
logger.log('search space : {:}'.format(search_space))
|
||||
search_model = get_cell_based_tiny_net(model_config)
|
||||
|
||||
@@ -175,6 +177,12 @@ def main(xargs):
|
||||
flop, param = get_model_infos(search_model, xshape)
|
||||
#logger.log('{:}'.format(search_model))
|
||||
logger.log('FLOP = {:.2f} M, Params = {:.2f} MB'.format(flop, param))
|
||||
logger.log('search-space : {:}'.format(search_space))
|
||||
if xargs.arch_nas_dataset is None:
|
||||
api = None
|
||||
else:
|
||||
api = API(xargs.arch_nas_dataset)
|
||||
logger.log('{:} create API = {:} done'.format(time_string(), api))
|
||||
|
||||
last_info, model_base_path, model_best_path = logger.path('info'), logger.path('model'), logger.path('best')
|
||||
network, criterion = torch.nn.DataParallel(search_model).cuda(), criterion.cuda()
|
||||
@@ -196,7 +204,7 @@ def main(xargs):
|
||||
start_epoch, valid_accuracies, genotypes = 0, {'best': -1}, {}
|
||||
|
||||
# start training
|
||||
start_time, epoch_time, total_epoch = time.time(), AverageMeter(), config.epochs + config.warmup
|
||||
start_time, search_time, epoch_time, total_epoch = time.time(), AverageMeter(), AverageMeter(), config.epochs + config.warmup
|
||||
for epoch in range(start_epoch, total_epoch):
|
||||
w_scheduler.update(epoch, 0.0)
|
||||
need_time = 'Time Left: {:}'.format( convert_secs2time(epoch_time.val * (total_epoch-epoch), True) )
|
||||
@@ -205,7 +213,8 @@ def main(xargs):
|
||||
|
||||
search_w_loss, search_w_top1, search_w_top5, search_a_loss, search_a_top1, search_a_top5 \
|
||||
= search_func(search_loader, network, criterion, w_scheduler, w_optimizer, a_optimizer, epoch_str, xargs.print_freq, logger)
|
||||
logger.log('[{:}] search [base] : loss={:.2f}, accuracy@1={:.2f}%, accuracy@5={:.2f}%'.format(epoch_str, search_w_loss, search_w_top1, search_w_top5))
|
||||
search_time.update(time.time() - start_time)
|
||||
logger.log('[{:}] search [base] : loss={:.2f}, accuracy@1={:.2f}%, accuracy@5={:.2f}%, time-cost={:.1f} s'.format(epoch_str, search_w_loss, search_w_top1, search_w_top5, search_time.sum))
|
||||
logger.log('[{:}] search [arch] : loss={:.2f}, accuracy@1={:.2f}%, accuracy@5={:.2f}%'.format(epoch_str, search_a_loss, search_a_top1, search_a_top5))
|
||||
|
||||
genotype, temp_accuracy = get_best_arch(valid_loader, network, xargs.select_num)
|
||||
@@ -243,52 +252,23 @@ def main(xargs):
|
||||
}, logger.path('info'), logger)
|
||||
with torch.no_grad():
|
||||
logger.log('arch-parameters :\n{:}'.format( nn.functional.softmax(search_model.arch_parameters, dim=-1).cpu() ))
|
||||
if api is not None: logger.log('{:}'.format(api.query_by_arch( genotypes[epoch] )))
|
||||
# measure elapsed time
|
||||
epoch_time.update(time.time() - start_time)
|
||||
start_time = time.time()
|
||||
|
||||
#logger.log('During searching, the best gentotype is : {:} , with the validation accuracy of {:.3f}%.'.format(genotypes['best'], valid_accuracies['best']))
|
||||
# the final post procedure : count the time
|
||||
start_time = time.time()
|
||||
genotype, temp_accuracy = get_best_arch(valid_loader, network, xargs.select_num)
|
||||
search_time.update(time.time() - start_time)
|
||||
network.module.set_cal_mode('dynamic', genotype)
|
||||
valid_a_loss , valid_a_top1 , valid_a_top5 = valid_func(valid_loader, network, criterion)
|
||||
logger.log('Last : the gentotype is : {:}, with the validation accuracy of {:.3f}%.'.format(genotype, valid_a_top1))
|
||||
# sampling
|
||||
"""
|
||||
with torch.no_grad():
|
||||
logger.log('arch-parameters :\n{:}'.format( nn.functional.softmax(search_model.arch_parameters, dim=-1).cpu() ))
|
||||
selected_archs = set()
|
||||
while len(selected_archs) < xargs.select_num:
|
||||
architecture = search_model.dync_genotype()
|
||||
selected_archs.add( architecture )
|
||||
logger.log('select {:} architectures based on the learned arch-parameters'.format( len(selected_archs) ))
|
||||
|
||||
best_arch, best_acc = None, -1
|
||||
state_dict = deepcopy( network.state_dict() )
|
||||
for index, arch in enumerate(selected_archs):
|
||||
with torch.no_grad():
|
||||
search_model.set_cal_mode('dynamic', arch)
|
||||
network.load_state_dict( deepcopy(state_dict) )
|
||||
valid_a_loss , valid_a_top1 , valid_a_top5 = valid_func(valid_loader, network, criterion)
|
||||
logger.log('{:} [{:03d}/{:03d}] : {:125s}, loss={:.3f}, accuracy={:.3f}%'.format(time_string(), index, len(selected_archs), str(arch), valid_a_loss , valid_a_top1))
|
||||
if best_arch is None or best_acc < valid_a_top1:
|
||||
best_arch, best_acc = arch, valid_a_top1
|
||||
logger.log('Find the best one : {:} with accuracy={:.2f}%'.format(best_arch, best_acc))
|
||||
"""
|
||||
|
||||
logger.log('\n' + '-'*100)
|
||||
# check the performance from the architecture dataset
|
||||
"""
|
||||
if xargs.arch_nas_dataset is None or not os.path.isfile(xargs.arch_nas_dataset):
|
||||
logger.log('Can not find the architecture dataset : {:}.'.format(xargs.arch_nas_dataset))
|
||||
else:
|
||||
nas_bench = TinyNASBenchmarkAPI(xargs.arch_nas_dataset)
|
||||
geno = best_arch
|
||||
logger.log('The last model is {:}'.format(geno))
|
||||
info = nas_bench.query_by_arch( geno )
|
||||
if info is None: logger.log('Did not find this architecture : {:}.'.format(geno))
|
||||
else : logger.log('{:}'.format(info))
|
||||
logger.log('-'*100)
|
||||
"""
|
||||
logger.log('SETN : run {:} epochs, cost {:.1f} s, last-geno is {:}.'.format(total_epoch, search_time.sum, genotype))
|
||||
if api is not None: logger.log('{:}'.format( api.query_by_arch(genotype) ))
|
||||
logger.close()
|
||||
|
||||
|
||||
@@ -303,7 +283,8 @@ if __name__ == '__main__':
|
||||
parser.add_argument('--channel', type=int, help='The number of channels.')
|
||||
parser.add_argument('--num_cells', type=int, help='The number of cells in one stage.')
|
||||
parser.add_argument('--select_num', type=int, help='The number of selected architectures to evaluate.')
|
||||
parser.add_argument('--config_path', type=str, help='.')
|
||||
parser.add_argument('--track_running_stats',type=int, choices=[0,1],help='Whether use track_running_stats or not in the BN layer.')
|
||||
parser.add_argument('--config_path', type=str, help='The path of the configuration.')
|
||||
# architecture leraning rate
|
||||
parser.add_argument('--arch_learning_rate', type=float, default=3e-4, help='learning rate for arch encoding')
|
||||
parser.add_argument('--arch_weight_decay', type=float, default=1e-3, help='weight decay for arch encoding')
|
||||
|
Reference in New Issue
Block a user