update NAS-Bench-102 baselines

This commit is contained in:
D-X-Y
2019-12-24 17:36:47 +11:00
parent af4212b4db
commit 44a0d51449
18 changed files with 105 additions and 110 deletions

View File

@@ -93,8 +93,8 @@ def main(xargs):
logger.log('Load split file from {:}'.format(split_Fpath))
else:
raise ValueError('invalid dataset : {:}'.format(xargs.dataset))
config_path = 'configs/nas-benchmark/algos/GDAS.config'
config = load_config(config_path, {'class_num': class_num, 'xshape': xshape}, logger)
#config_path = 'configs/nas-benchmark/algos/GDAS.config'
config = load_config(xargs.config_path, {'class_num': class_num, 'xshape': xshape}, logger)
search_data = SearchDataset(xargs.dataset, train_data, train_split, valid_split)
# data loader
search_loader = torch.utils.data.DataLoader(search_data, batch_size=config.batch_size, shuffle=True , num_workers=xargs.workers, pin_memory=True)
@@ -105,7 +105,7 @@ def main(xargs):
model_config = dict2config({'name': 'GDAS', 'C': xargs.channel, 'N': xargs.num_cells,
'max_nodes': xargs.max_nodes, 'num_classes': class_num,
'space' : search_space,
'affine' : False, 'track_running_stats': True}, None)
'affine' : False, 'track_running_stats': bool(xargs.track_running_stats)}, None)
search_model = get_cell_based_tiny_net(model_config)
logger.log('search-model :\n{:}'.format(search_model))
@@ -156,7 +156,7 @@ def main(xargs):
search_w_loss, search_w_top1, search_w_top5, valid_a_loss , valid_a_top1 , valid_a_top5 \
= search_func(search_loader, network, criterion, w_scheduler, w_optimizer, a_optimizer, epoch_str, xargs.print_freq, logger)
search_time.update(time.time() - start_time)
logger.log('[{:}] searching : loss={:.2f}, accuracy@1={:.2f}%, accuracy@5={:.2f}%'.format(epoch_str, search_w_loss, search_w_top1, search_w_top5))
logger.log('[{:}] searching : loss={:.2f}, accuracy@1={:.2f}%, accuracy@5={:.2f}%, time-cost={:.1f} s'.format(epoch_str, search_w_loss, search_w_top1, search_w_top5, search_time.sum))
logger.log('[{:}] evaluate : loss={:.2f}, accuracy@1={:.2f}%, accuracy@5={:.2f}%'.format(epoch_str, valid_a_loss , valid_a_top1 , valid_a_top5 ))
# check the best accuracy
valid_accuracies[epoch] = valid_a_top1
@@ -210,6 +210,8 @@ if __name__ == '__main__':
parser.add_argument('--max_nodes', type=int, help='The maximum number of nodes.')
parser.add_argument('--channel', type=int, help='The number of channels.')
parser.add_argument('--num_cells', type=int, help='The number of cells in one stage.')
parser.add_argument('--track_running_stats',type=int, choices=[0,1],help='Whether use track_running_stats or not in the BN layer.')
parser.add_argument('--config_path', type=str, help='The path of the configuration.')
# architecture leraning rate
parser.add_argument('--arch_learning_rate', type=float, default=3e-4, help='learning rate for arch encoding')
parser.add_argument('--arch_weight_decay', type=float, default=1e-3, help='weight decay for arch encoding')