Upgrade API of NAS-Bench-201
This commit is contained in:
@@ -1,7 +1,7 @@
|
||||
#####################################################
|
||||
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019.08 #
|
||||
#####################################################
|
||||
# python exps/NAS-Bench-201/check.py --base_save_dir
|
||||
# python exps/NAS-Bench-201/check.py --base_str C16-N5-LESS
|
||||
#####################################################
|
||||
import sys, time, argparse, collections
|
||||
import torch
|
||||
@@ -13,10 +13,9 @@ from log_utils import AverageMeter, time_string, convert_secs2time
|
||||
|
||||
|
||||
def check_files(save_dir, meta_file, basestr):
|
||||
meta_infos = torch.load(meta_file, map_location='cpu')
|
||||
meta_archs = meta_infos['archs']
|
||||
meta_infos = torch.load(meta_file, map_location='cpu')
|
||||
meta_archs = meta_infos['archs']
|
||||
meta_num_archs = meta_infos['total']
|
||||
meta_max_node = meta_infos['max_node']
|
||||
assert meta_num_archs == len(meta_archs), 'invalid number of archs : {:} vs {:}'.format(meta_num_archs, len(meta_archs))
|
||||
|
||||
sub_model_dirs = sorted(list(save_dir.glob('*-*-{:}'.format(basestr))))
|
||||
@@ -43,7 +42,12 @@ def check_files(save_dir, meta_file, basestr):
|
||||
dir2ckps, dir2ckp_exists = dict(), dict()
|
||||
start_time, epoch_time = time.time(), AverageMeter()
|
||||
for IDX, (sub_dir, arch_indexes) in enumerate(subdir2archs.items()):
|
||||
seeds = [777, 888, 999]
|
||||
if basestr == 'C16-N5':
|
||||
seeds = [777, 888, 999]
|
||||
elif basestr == 'C16-N5-LESS':
|
||||
seeds = [111, 777]
|
||||
else:
|
||||
raise ValueError('Invalid base str : {:}'.format(basestr))
|
||||
numrs = defaultdict(lambda: 0)
|
||||
all_checkpoints, all_ckp_exists = [], []
|
||||
for arch_index in arch_indexes:
|
||||
@@ -66,17 +70,15 @@ def check_files(save_dir, meta_file, basestr):
|
||||
if __name__ == '__main__':
|
||||
|
||||
parser = argparse.ArgumentParser(description='NAS Benchmark 201', formatter_class=argparse.ArgumentDefaultsHelpFormatter)
|
||||
parser.add_argument('--base_save_dir', type=str, default='./output/NAS-BENCH-201-4', help='The base-name of folder to save checkpoints and log.')
|
||||
parser.add_argument('--max_node', type=int, default=4, help='The maximum node in a cell.')
|
||||
parser.add_argument('--channel', type=int, default=16, help='The number of channels.')
|
||||
parser.add_argument('--num_cells', type=int, default=5, help='The number of cells in one stage.')
|
||||
parser.add_argument('--base_save_dir', type=str, default='./output/NAS-BENCH-201-4', help='The base-name of folder to save checkpoints and log.')
|
||||
parser.add_argument('--meta_path', type=str, default='./output/NAS-BENCH-201-4/meta-node-4.pth', help='The meta file path.')
|
||||
parser.add_argument('--base_str', type=str, default='C16-N5', help='The basic string.')
|
||||
args = parser.parse_args()
|
||||
|
||||
save_dir = Path( args.base_save_dir )
|
||||
meta_path = save_dir / 'meta-node-{:}.pth'.format(args.max_node)
|
||||
|
||||
save_dir = Path(args.base_save_dir)
|
||||
meta_path = Path(args.meta_path)
|
||||
assert save_dir.exists(), 'invalid save dir path : {:}'.format(save_dir)
|
||||
assert meta_path.exists(), 'invalid saved meta path : {:}'.format(meta_path)
|
||||
print ('check NAS-Bench-201 in {:}'.format(save_dir))
|
||||
|
||||
basestr = 'C{:}-N{:}'.format(args.channel, args.num_cells)
|
||||
check_files(save_dir, meta_path, basestr)
|
||||
check_files(save_dir, meta_path, args.base_str)
|
||||
|
Reference in New Issue
Block a user