Add int search space
This commit is contained in:
@@ -53,7 +53,11 @@ def copy_data(source_dir, target_dir, meta_path):
|
||||
target_path = os.path.join(target_dir, file_name)
|
||||
if os.path.exists(source_path):
|
||||
s2t[source_path] = target_path
|
||||
print("Map from {:} to {:}, find {:} missed ckps.".format(source_dir, target_dir, len(s2t)))
|
||||
print(
|
||||
"Map from {:} to {:}, find {:} missed ckps.".format(
|
||||
source_dir, target_dir, len(s2t)
|
||||
)
|
||||
)
|
||||
for s, t in s2t.items():
|
||||
copyfile(s, t)
|
||||
|
||||
@@ -63,9 +67,18 @@ if __name__ == "__main__":
|
||||
description="NATS-Bench (topology search space) file manager.",
|
||||
formatter_class=argparse.ArgumentDefaultsHelpFormatter,
|
||||
)
|
||||
parser.add_argument("--mode", type=str, required=True, choices=["check", "copy"], help="The script mode.")
|
||||
parser.add_argument(
|
||||
"--save_dir", type=str, default="output/NATS-Bench-topology", help="Folder to save checkpoints and log."
|
||||
"--mode",
|
||||
type=str,
|
||||
required=True,
|
||||
choices=["check", "copy"],
|
||||
help="The script mode.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--save_dir",
|
||||
type=str,
|
||||
default="output/NATS-Bench-topology",
|
||||
help="Folder to save checkpoints and log.",
|
||||
)
|
||||
parser.add_argument("--check_N", type=int, default=15625, help="For safety.")
|
||||
# use for train the model
|
||||
@@ -75,8 +88,13 @@ if __name__ == "__main__":
|
||||
if args.mode == "check":
|
||||
for config, possible_seeds in zip(possible_configs, possible_seedss):
|
||||
cur_save_dir = "{:}/raw-data-{:}".format(args.save_dir, config)
|
||||
seed2ckps, miss2ckps = obtain_valid_ckp(cur_save_dir, args.check_N, possible_seeds)
|
||||
torch.save(dict(seed2ckps=seed2ckps, miss2ckps=miss2ckps), "{:}/meta-{:}.pth".format(args.save_dir, config))
|
||||
seed2ckps, miss2ckps = obtain_valid_ckp(
|
||||
cur_save_dir, args.check_N, possible_seeds
|
||||
)
|
||||
torch.save(
|
||||
dict(seed2ckps=seed2ckps, miss2ckps=miss2ckps),
|
||||
"{:}/meta-{:}.pth".format(args.save_dir, config),
|
||||
)
|
||||
elif args.mode == "copy":
|
||||
for config in possible_configs:
|
||||
cur_save_dir = "{:}/raw-data-{:}".format(args.save_dir, config)
|
||||
|
Reference in New Issue
Block a user