support first-order DARTS on the NASNet search space
This commit is contained in:
@@ -112,10 +112,14 @@ def main(xargs):
|
||||
logger.log('||||||| {:10s} ||||||| Config={:}'.format(xargs.dataset, config))
|
||||
|
||||
search_space = get_search_spaces('cell', xargs.search_space_name)
|
||||
model_config = dict2config({'name': 'DARTS-V1', 'C': xargs.channel, 'N': xargs.num_cells,
|
||||
'max_nodes': xargs.max_nodes, 'num_classes': class_num,
|
||||
'space' : search_space,
|
||||
'affine' : False, 'track_running_stats': bool(xargs.track_running_stats)}, None)
|
||||
if xargs.model_config is None:
|
||||
model_config = dict2config({'name': 'DARTS-V1', 'C': xargs.channel, 'N': xargs.num_cells,
|
||||
'max_nodes': xargs.max_nodes, 'num_classes': class_num,
|
||||
'space' : search_space,
|
||||
'affine' : False, 'track_running_stats': bool(xargs.track_running_stats)}, None)
|
||||
else:
|
||||
model_config = load_config(xargs.model_config, {'num_classes': class_num, 'space' : search_space,
|
||||
'affine' : False, 'track_running_stats': bool(xargs.track_running_stats)}, None)
|
||||
search_model = get_cell_based_tiny_net(model_config)
|
||||
logger.log('search-model :\n{:}'.format(search_model))
|
||||
|
||||
@@ -213,12 +217,13 @@ if __name__ == '__main__':
|
||||
parser.add_argument('--data_path', type=str, help='Path to dataset')
|
||||
parser.add_argument('--dataset', type=str, choices=['cifar10', 'cifar100', 'ImageNet16-120'], help='Choose between Cifar10/100 and ImageNet-16.')
|
||||
# channels and number-of-cells
|
||||
parser.add_argument('--config_path', type=str, help='The config path.')
|
||||
parser.add_argument('--search_space_name', type=str, help='The search space name.')
|
||||
parser.add_argument('--max_nodes', type=int, help='The maximum number of nodes.')
|
||||
parser.add_argument('--channel', type=int, help='The number of channels.')
|
||||
parser.add_argument('--num_cells', type=int, help='The number of cells in one stage.')
|
||||
parser.add_argument('--track_running_stats',type=int, choices=[0,1],help='Whether use track_running_stats or not in the BN layer.')
|
||||
parser.add_argument('--config_path', type=str, help='The config path.')
|
||||
parser.add_argument('--model_config', type=str, help='The path of the model configuration. When this arg is set, it will cover max_nodes / channels / num_cells.')
|
||||
# architecture leraning rate
|
||||
parser.add_argument('--arch_learning_rate', type=float, default=3e-4, help='learning rate for arch encoding')
|
||||
parser.add_argument('--arch_weight_decay', type=float, default=1e-3, help='weight decay for arch encoding')
|
||||
|
Reference in New Issue
Block a user