Update VIS-CODES and SCRIPTS

This commit is contained in:
D-X-Y
2020-07-22 12:48:40 +00:00
parent 8d27050f6f
commit a2a1abcb7d
3 changed files with 55 additions and 21 deletions

View File

@@ -22,8 +22,8 @@
# python ./exps/algos-v2/search-cell.py --dataset ImageNet16-120 --data_path $TORCH_HOME/cifar.python/ImageNet16 --algo random
####
# python ./exps/algos-v2/search-cell.py --dataset cifar10 --data_path $TORCH_HOME/cifar.python --algo enas --arch_weight_decay 0 --arch_learning_rate 0.001 --arch_eps 0.001 --rand_seed 777
# python ./exps/algos-v2/search-cell.py --dataset cifar100 --data_path $TORCH_HOME/cifar.python --algo enas
# python ./exps/algos-v2/search-cell.py --dataset ImageNet16-120 --data_path $TORCH_HOME/cifar.python/ImageNet16 --algo enas
# python ./exps/algos-v2/search-cell.py --dataset cifar100 --data_path $TORCH_HOME/cifar.python --algo enas --arch_weight_decay 0 --arch_learning_rate 0.001 --arch_eps 0.001 --rand_seed 777
# python ./exps/algos-v2/search-cell.py --dataset ImageNet16-120 --data_path $TORCH_HOME/cifar.python/ImageNet16 --algo enas --arch_weight_decay 0 --arch_learning_rate 0.001 --arch_eps 0.001 --rand_seed 777
######################################################################################
import os, sys, time, random, argparse
import numpy as np
@@ -333,7 +333,11 @@ def main(xargs):
logger = prepare_logger(args)
train_data, valid_data, xshape, class_num = get_datasets(xargs.dataset, xargs.data_path, -1)
config = load_config(xargs.config_path, {'class_num': class_num, 'xshape': xshape}, logger)
if xargs.overwite_epochs is None:
extra_info = {'class_num': class_num, 'xshape': xshape}
else:
extra_info = {'class_num': class_num, 'xshape': xshape, 'epochs': xargs.overwite_epochs}
config = load_config(xargs.config_path, extra_info, logger)
search_loader, train_loader, valid_loader = get_nas_search_loaders(train_data, valid_data, xargs.dataset, 'configs/nas-benchmark/', \
(config.batch_size, config.test_batch_size), xargs.workers)
logger.log('||||||| {:10s} ||||||| Search-Loader-Num={:}, Valid-Loader-Num={:}, batch size={:}'.format(xargs.dataset, len(search_loader), len(valid_loader), config.batch_size))
@@ -496,6 +500,7 @@ if __name__ == '__main__':
parser.add_argument('--track_running_stats',type=int, default=0, choices=[0,1],help='Whether use track_running_stats or not in the BN layer.')
parser.add_argument('--affine' , type=int, default=0, choices=[0,1],help='Whether use affine=True or False in the BN layer.')
parser.add_argument('--config_path' , type=str, default='./configs/nas-benchmark/algos/weight-sharing.config', help='The path of configuration.')
parser.add_argument('--overwite_epochs', type=int, help='The number of epochs to overwrite that value in config files.')
# architecture leraning rate
parser.add_argument('--arch_learning_rate', type=float, default=3e-4, help='learning rate for arch encoding')
parser.add_argument('--arch_weight_decay' , type=float, default=1e-3, help='weight decay for arch encoding')
@@ -508,8 +513,13 @@ if __name__ == '__main__':
parser.add_argument('--rand_seed', type=int, help='manual seed')
args = parser.parse_args()
if args.rand_seed is None or args.rand_seed < 0: args.rand_seed = random.randint(1, 100000)
args.save_dir = os.path.join('{:}-{:}'.format(args.save_dir, args.search_space),
args.dataset,
'{:}-affine{:}_BN{:}-{:}'.format(args.algo, args.affine, args.track_running_stats, args.drop_path_rate))
if args.overwite_epochs is None:
args.save_dir = os.path.join('{:}-{:}'.format(args.save_dir, args.search_space),
args.dataset,
'{:}-affine{:}_BN{:}-{:}'.format(args.algo, args.affine, args.track_running_stats, args.drop_path_rate))
else:
args.save_dir = os.path.join('{:}-{:}'.format(args.save_dir, args.search_space),
args.dataset,
'{:}-affine{:}_BN{:}-E{:}-{:}'.format(args.algo, args.affine, args.track_running_stats, args.overwite_epochs, args.drop_path_rate))
main(args)