Reformulate via black

This commit is contained in:
D-X-Y
2021-03-17 09:25:58 +00:00
parent a9093e41e1
commit f98edea22a
59 changed files with 12289 additions and 8918 deletions

View File

@@ -9,72 +9,82 @@ import os, sys, time, torch, argparse
from typing import List, Text, Dict, Any
from shutil import copyfile
from collections import defaultdict
from copy import deepcopy
from copy import deepcopy
from pathlib import Path
lib_dir = (Path(__file__).parent / '..' / '..' / 'lib').resolve()
if str(lib_dir) not in sys.path: sys.path.insert(0, str(lib_dir))
lib_dir = (Path(__file__).parent / ".." / ".." / "lib").resolve()
if str(lib_dir) not in sys.path:
sys.path.insert(0, str(lib_dir))
from config_utils import dict2config, load_config
from procedures import bench_evaluate_for_seed
from procedures import get_machine_info
from datasets import get_datasets
from log_utils import Logger, AverageMeter, time_string, convert_secs2time
from procedures import bench_evaluate_for_seed
from procedures import get_machine_info
from datasets import get_datasets
from log_utils import Logger, AverageMeter, time_string, convert_secs2time
def obtain_valid_ckp(save_dir: Text, total: int, possible_seeds: List[int]):
seed2ckps = defaultdict(list)
miss2ckps = defaultdict(list)
for i in range(total):
for seed in possible_seeds:
path = os.path.join(save_dir, 'arch-{:06d}-seed-{:04d}.pth'.format(i, seed))
if os.path.exists(path):
seed2ckps[seed].append(i)
else:
miss2ckps[seed].append(i)
for seed, xlist in seed2ckps.items():
print('[{:}] [seed={:}] has {:5d}/{:5d} | miss {:5d}/{:5d}'.format(save_dir, seed, len(xlist), total, total-len(xlist), total))
return dict(seed2ckps), dict(miss2ckps)
seed2ckps = defaultdict(list)
miss2ckps = defaultdict(list)
for i in range(total):
for seed in possible_seeds:
path = os.path.join(save_dir, "arch-{:06d}-seed-{:04d}.pth".format(i, seed))
if os.path.exists(path):
seed2ckps[seed].append(i)
else:
miss2ckps[seed].append(i)
for seed, xlist in seed2ckps.items():
print(
"[{:}] [seed={:}] has {:5d}/{:5d} | miss {:5d}/{:5d}".format(
save_dir, seed, len(xlist), total, total - len(xlist), total
)
)
return dict(seed2ckps), dict(miss2ckps)
def copy_data(source_dir, target_dir, meta_path):
target_dir = Path(target_dir)
target_dir.mkdir(parents=True, exist_ok=True)
miss2ckps = torch.load(meta_path)['miss2ckps']
s2t = {}
for seed, xlist in miss2ckps.items():
for i in xlist:
file_name = 'arch-{:06d}-seed-{:04d}.pth'.format(i, seed)
source_path = os.path.join(source_dir, file_name)
target_path = os.path.join(target_dir, file_name)
if os.path.exists(source_path):
s2t[source_path] = target_path
print('Map from {:} to {:}, find {:} missed ckps.'.format(source_dir, target_dir, len(s2t)))
for s, t in s2t.items():
copyfile(s, t)
target_dir = Path(target_dir)
target_dir.mkdir(parents=True, exist_ok=True)
miss2ckps = torch.load(meta_path)["miss2ckps"]
s2t = {}
for seed, xlist in miss2ckps.items():
for i in xlist:
file_name = "arch-{:06d}-seed-{:04d}.pth".format(i, seed)
source_path = os.path.join(source_dir, file_name)
target_path = os.path.join(target_dir, file_name)
if os.path.exists(source_path):
s2t[source_path] = target_path
print("Map from {:} to {:}, find {:} missed ckps.".format(source_dir, target_dir, len(s2t)))
for s, t in s2t.items():
copyfile(s, t)
if __name__ == '__main__':
parser = argparse.ArgumentParser(description='NATS-Bench (topology search space) file manager.', formatter_class=argparse.ArgumentDefaultsHelpFormatter)
parser.add_argument('--mode', type=str, required=True, choices=['check', 'copy'], help='The script mode.')
parser.add_argument('--save_dir', type=str, default='output/NATS-Bench-topology', help='Folder to save checkpoints and log.')
parser.add_argument('--check_N', type=int, default=15625, help='For safety.')
# use for train the model
args = parser.parse_args()
possible_configs = ['12', '200']
possible_seedss = [[111, 777], [777, 888, 999]]
if args.mode == 'check':
for config, possible_seeds in zip(possible_configs, possible_seedss):
cur_save_dir = '{:}/raw-data-{:}'.format(args.save_dir, config)
seed2ckps, miss2ckps = obtain_valid_ckp(cur_save_dir, args.check_N, possible_seeds)
torch.save(dict(seed2ckps=seed2ckps, miss2ckps=miss2ckps), '{:}/meta-{:}.pth'.format(args.save_dir, config))
elif args.mode == 'copy':
for config in possible_configs:
cur_save_dir = '{:}/raw-data-{:}'.format(args.save_dir, config)
cur_copy_dir = '{:}/copy-{:}'.format(args.save_dir, config)
cur_meta_path = '{:}/meta-{:}.pth'.format(args.save_dir, config)
if os.path.exists(cur_meta_path):
copy_data(cur_save_dir, cur_copy_dir, cur_meta_path)
else:
print('Do not find : {:}'.format(cur_meta_path))
else:
raise ValueError('invalid mode : {:}'.format(args.mode))
if __name__ == "__main__":
parser = argparse.ArgumentParser(
description="NATS-Bench (topology search space) file manager.",
formatter_class=argparse.ArgumentDefaultsHelpFormatter,
)
parser.add_argument("--mode", type=str, required=True, choices=["check", "copy"], help="The script mode.")
parser.add_argument(
"--save_dir", type=str, default="output/NATS-Bench-topology", help="Folder to save checkpoints and log."
)
parser.add_argument("--check_N", type=int, default=15625, help="For safety.")
# use for train the model
args = parser.parse_args()
possible_configs = ["12", "200"]
possible_seedss = [[111, 777], [777, 888, 999]]
if args.mode == "check":
for config, possible_seeds in zip(possible_configs, possible_seedss):
cur_save_dir = "{:}/raw-data-{:}".format(args.save_dir, config)
seed2ckps, miss2ckps = obtain_valid_ckp(cur_save_dir, args.check_N, possible_seeds)
torch.save(dict(seed2ckps=seed2ckps, miss2ckps=miss2ckps), "{:}/meta-{:}.pth".format(args.save_dir, config))
elif args.mode == "copy":
for config in possible_configs:
cur_save_dir = "{:}/raw-data-{:}".format(args.save_dir, config)
cur_copy_dir = "{:}/copy-{:}".format(args.save_dir, config)
cur_meta_path = "{:}/meta-{:}.pth".format(args.save_dir, config)
if os.path.exists(cur_meta_path):
copy_data(cur_save_dir, cur_copy_dir, cur_meta_path)
else:
print("Do not find : {:}".format(cur_meta_path))
else:
raise ValueError("invalid mode : {:}".format(args.mode))