simplify DARTS codes and update affine/track

This commit is contained in:
D-X-Y
2020-01-11 18:46:31 +11:00
parent c66afa4df8
commit 654015bf9d
15 changed files with 30 additions and 110 deletions

View File

@@ -114,7 +114,8 @@ def main(xargs):
search_space = get_search_spaces('cell', xargs.search_space_name)
model_config = dict2config({'name': 'DARTS-V1', 'C': xargs.channel, 'N': xargs.num_cells,
'max_nodes': xargs.max_nodes, 'num_classes': class_num,
'space' : search_space}, None)
'space' : search_space,
'affine' : False, 'track_running_stats': bool(xargs.track_running_stats)}, None)
search_model = get_cell_based_tiny_net(model_config)
logger.log('search-model :\n{:}'.format(search_model))
@@ -217,6 +218,7 @@ if __name__ == '__main__':
parser.add_argument('--max_nodes', type=int, help='The maximum number of nodes.')
parser.add_argument('--channel', type=int, help='The number of channels.')
parser.add_argument('--num_cells', type=int, help='The number of cells in one stage.')
parser.add_argument('--track_running_stats',type=int, choices=[0,1],help='Whether use track_running_stats or not in the BN layer.')
# architecture leraning rate
parser.add_argument('--arch_learning_rate', type=float, default=3e-4, help='learning rate for arch encoding')
parser.add_argument('--arch_weight_decay', type=float, default=1e-3, help='weight decay for arch encoding')

View File

@@ -177,7 +177,8 @@ def main(xargs):
search_space = get_search_spaces('cell', xargs.search_space_name)
model_config = dict2config({'name': 'DARTS-V2', 'C': xargs.channel, 'N': xargs.num_cells,
'max_nodes': xargs.max_nodes, 'num_classes': class_num,
'space' : search_space}, None)
'space' : search_space,
'affine' : False, 'track_running_stats': bool(xargs.track_running_stats)}, None)
search_model = get_cell_based_tiny_net(model_config)
logger.log('search-model :\n{:}'.format(search_model))
@@ -282,6 +283,7 @@ if __name__ == '__main__':
parser.add_argument('--max_nodes', type=int, help='The maximum number of nodes.')
parser.add_argument('--channel', type=int, help='The number of channels.')
parser.add_argument('--num_cells', type=int, help='The number of cells in one stage.')
parser.add_argument('--track_running_stats',type=int, choices=[0,1],help='Whether use track_running_stats or not in the BN layer.')
# architecture leraning rate
parser.add_argument('--arch_learning_rate', type=float, default=3e-4, help='learning rate for arch encoding')
parser.add_argument('--arch_weight_decay', type=float, default=1e-3, help='weight decay for arch encoding')

View File

@@ -198,7 +198,8 @@ def main(xargs):
search_space = get_search_spaces('cell', xargs.search_space_name)
model_config = dict2config({'name': 'ENAS', 'C': xargs.channel, 'N': xargs.num_cells,
'max_nodes': xargs.max_nodes, 'num_classes': class_num,
'space' : search_space}, None)
'space' : search_space,
'affine' : False, 'track_running_stats': bool(xargs.track_running_stats)}, None)
shared_cnn = get_cell_based_tiny_net(model_config)
controller = shared_cnn.create_controller()
@@ -319,6 +320,7 @@ if __name__ == '__main__':
parser.add_argument('--data_path', type=str, help='Path to dataset')
parser.add_argument('--dataset', type=str, choices=['cifar10', 'cifar100', 'ImageNet16-120'], help='Choose between Cifar10/100 and ImageNet-16.')
# channels and number-of-cells
parser.add_argument('--track_running_stats',type=int, choices=[0,1],help='Whether use track_running_stats or not in the BN layer.')
parser.add_argument('--search_space_name', type=str, help='The search space name.')
parser.add_argument('--max_nodes', type=int, help='The maximum number of nodes.')
parser.add_argument('--channel', type=int, help='The number of channels.')

View File

@@ -126,7 +126,8 @@ def main(xargs):
search_space = get_search_spaces('cell', xargs.search_space_name)
model_config = dict2config({'name': 'RANDOM', 'C': xargs.channel, 'N': xargs.num_cells,
'max_nodes': xargs.max_nodes, 'num_classes': class_num,
'space' : search_space}, None)
'space' : search_space,
'affine' : False, 'track_running_stats': bool(xargs.track_running_stats)}, None)
search_model = get_cell_based_tiny_net(model_config)
w_optimizer, w_scheduler, criterion = get_optim_scheduler(search_model.parameters(), config)
@@ -222,6 +223,7 @@ if __name__ == '__main__':
parser.add_argument('--channel', type=int, help='The number of channels.')
parser.add_argument('--num_cells', type=int, help='The number of cells in one stage.')
parser.add_argument('--select_num', type=int, help='The number of selected architectures to evaluate.')
parser.add_argument('--track_running_stats',type=int, choices=[0,1],help='Whether use track_running_stats or not in the BN layer.')
# log
parser.add_argument('--workers', type=int, default=2, help='number of data loading workers (default: 2)')
parser.add_argument('--save_dir', type=str, help='Folder to save checkpoints and log.')