simplify DARTS codes and update affine/track
This commit is contained in:
@@ -126,7 +126,8 @@ def main(xargs):
|
||||
search_space = get_search_spaces('cell', xargs.search_space_name)
|
||||
model_config = dict2config({'name': 'RANDOM', 'C': xargs.channel, 'N': xargs.num_cells,
|
||||
'max_nodes': xargs.max_nodes, 'num_classes': class_num,
|
||||
'space' : search_space}, None)
|
||||
'space' : search_space,
|
||||
'affine' : False, 'track_running_stats': bool(xargs.track_running_stats)}, None)
|
||||
search_model = get_cell_based_tiny_net(model_config)
|
||||
|
||||
w_optimizer, w_scheduler, criterion = get_optim_scheduler(search_model.parameters(), config)
|
||||
@@ -222,6 +223,7 @@ if __name__ == '__main__':
|
||||
parser.add_argument('--channel', type=int, help='The number of channels.')
|
||||
parser.add_argument('--num_cells', type=int, help='The number of cells in one stage.')
|
||||
parser.add_argument('--select_num', type=int, help='The number of selected architectures to evaluate.')
|
||||
parser.add_argument('--track_running_stats',type=int, choices=[0,1],help='Whether use track_running_stats or not in the BN layer.')
|
||||
# log
|
||||
parser.add_argument('--workers', type=int, default=2, help='number of data loading workers (default: 2)')
|
||||
parser.add_argument('--save_dir', type=str, help='Folder to save checkpoints and log.')
|
||||
|
Reference in New Issue
Block a user