re-organize NATS-Bench
This commit is contained in:
@@ -1,8 +1,8 @@
|
||||
###############################################################
|
||||
# NAS-Bench-201, ICLR 2020 (https://arxiv.org/abs/2001.00326) #
|
||||
###############################################################
|
||||
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019.08 #
|
||||
###############################################################
|
||||
##############################################################################
|
||||
# NATS-Bench: Benchmarking NAS algorithms for Architecture Topology and Size #
|
||||
##############################################################################
|
||||
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2020.07 #
|
||||
##############################################################################
|
||||
import os, sys, time, torch, argparse
|
||||
from typing import List, Text, Dict, Any
|
||||
from PIL import ImageFile
|
||||
@@ -189,9 +189,9 @@ def filter_indexes(xlist, mode, save_dir, seeds):
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
parser = argparse.ArgumentParser(description='NAS-Bench-X', formatter_class=argparse.ArgumentDefaultsHelpFormatter)
|
||||
parser = argparse.ArgumentParser(description='NATS-Bench (size search space)', formatter_class=argparse.ArgumentDefaultsHelpFormatter)
|
||||
parser.add_argument('--mode', type=str, required=True, choices=['new', 'cover'], help='The script mode.')
|
||||
parser.add_argument('--save_dir', type=str, default='output/NAS-BENCH-202', help='Folder to save checkpoints and log.')
|
||||
parser.add_argument('--save_dir', type=str, default='output/NATS-Bench-size', help='Folder to save checkpoints and log.')
|
||||
parser.add_argument('--candidateC', type=int, nargs='+', default=[8, 16, 24, 32, 40, 48, 56, 64], help='.')
|
||||
parser.add_argument('--num_layers', type=int, default=5, help='The number of layers in a network.')
|
||||
parser.add_argument('--check_N', type=int, default=32768, help='For safety.')
|
||||
@@ -206,10 +206,12 @@ if __name__ == '__main__':
|
||||
args = parser.parse_args()
|
||||
|
||||
nets = traverse_net(args.candidateC, args.num_layers)
|
||||
if len(nets) != args.check_N: raise ValueError('Pre-num-check failed : {:} vs {:}'.format(len(nets), args.check_N))
|
||||
if len(nets) != args.check_N:
|
||||
raise ValueError('Pre-num-check failed : {:} vs {:}'.format(len(nets), args.check_N))
|
||||
|
||||
opt_config = './configs/nas-benchmark/hyper-opts/{:}E.config'.format(args.hyper)
|
||||
if not os.path.isfile(opt_config): raise ValueError('{:} is not a file.'.format(opt_config))
|
||||
if not os.path.isfile(opt_config):
|
||||
raise ValueError('{:} is not a file.'.format(opt_config))
|
||||
save_dir = Path(args.save_dir) / 'raw-data-{:}'.format(args.hyper)
|
||||
save_dir.mkdir(parents=True, exist_ok=True)
|
||||
if not isinstance(args.srange, str):
|
||||
@@ -218,7 +220,8 @@ if __name__ == '__main__':
|
||||
to_evaluate_indexes = set()
|
||||
for srange in srangestr.split(','):
|
||||
srange = srange.split('-')
|
||||
if len(srange) != 2: raise ValueError('invalid srange : {:}'.format(srange))
|
||||
if len(srange) != 2:
|
||||
raise ValueError('invalid srange : {:}'.format(srange))
|
||||
assert len(srange[0]) == len(srange[1]) == 5, 'invalid srange : {:}'.format(srange)
|
||||
srange = (int(srange[0]), int(srange[1]))
|
||||
if not (0 <= srange[0] <= srange[1] < args.check_N):
|
||||
@@ -226,10 +229,12 @@ if __name__ == '__main__':
|
||||
for i in range(srange[0], srange[1]+1):
|
||||
to_evaluate_indexes.add(i)
|
||||
|
||||
assert len(args.seeds) > 0, 'invalid length of seeds args: {:}'.format(args.seeds)
|
||||
if not len(args.seeds):
|
||||
raise ValueError('invalid length of seeds args: {:}'.format(args.seeds))
|
||||
if not (len(args.datasets) == len(args.xpaths) == len(args.splits)):
|
||||
raise ValueError('invalid infos : {:} vs {:} vs {:}'.format(len(args.datasets), len(args.xpaths), len(args.splits)))
|
||||
assert args.workers > 0, 'invalid number of workers : {:}'.format(args.workers)
|
||||
if args.workers <= 0:
|
||||
raise ValueError('invalid number of workers : {:}'.format(args.workers))
|
||||
|
||||
target_indexes = filter_indexes(to_evaluate_indexes, args.mode, save_dir, args.seeds)
|
||||
|
||||
@@ -239,4 +244,3 @@ if __name__ == '__main__':
|
||||
torch.set_num_threads(args.workers)
|
||||
|
||||
main(save_dir, args.workers, args.datasets, args.xpaths, args.splits, tuple(args.seeds), nets, opt_config, target_indexes, args.mode == 'cover')
|
||||
|
@@ -1,10 +1,10 @@
|
||||
###############################################################
|
||||
# NAS-Bench-201, ICLR 2020 (https://arxiv.org/abs/2001.00326) #
|
||||
###############################################################
|
||||
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2020.01 #
|
||||
###############################################################
|
||||
# Usage: python exps/NAS-Bench-201/xshape-file.py --mode check
|
||||
###############################################################
|
||||
##############################################################################
|
||||
# NATS-Bench: Benchmarking NAS algorithms for Architecture Topology and Size #
|
||||
##############################################################################
|
||||
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2020.01 #
|
||||
##############################################################################
|
||||
# Usage: python exps/NATS-Bench/sss-file-manager.py --mode check #
|
||||
##############################################################################
|
||||
import os, sys, time, torch, argparse
|
||||
from typing import List, Text, Dict, Any
|
||||
from shutil import copyfile
|
||||
@@ -55,9 +55,9 @@ def copy_data(source_dir, target_dir, meta_path):
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
parser = argparse.ArgumentParser(description='NAS-Bench-X', formatter_class=argparse.ArgumentDefaultsHelpFormatter)
|
||||
parser = argparse.ArgumentParser(description='NATS-Bench (size search space) file manager.', formatter_class=argparse.ArgumentDefaultsHelpFormatter)
|
||||
parser.add_argument('--mode', type=str, required=True, choices=['check', 'copy'], help='The script mode.')
|
||||
parser.add_argument('--save_dir', type=str, default='output/NAS-BENCH-202', help='Folder to save checkpoints and log.')
|
||||
parser.add_argument('--save_dir', type=str, default='output/NATS-Bench-size', help='Folder to save checkpoints and log.')
|
||||
parser.add_argument('--check_N', type=int, default=32768, help='For safety.')
|
||||
# use for train the model
|
||||
args = parser.parse_args()
|
Reference in New Issue
Block a user