Add int search space
This commit is contained in:
@@ -18,12 +18,16 @@ def check_files(save_dir, meta_file, basestr):
|
||||
meta_infos = torch.load(meta_file, map_location="cpu")
|
||||
meta_archs = meta_infos["archs"]
|
||||
meta_num_archs = meta_infos["total"]
|
||||
assert meta_num_archs == len(meta_archs), "invalid number of archs : {:} vs {:}".format(
|
||||
meta_num_archs, len(meta_archs)
|
||||
)
|
||||
assert meta_num_archs == len(
|
||||
meta_archs
|
||||
), "invalid number of archs : {:} vs {:}".format(meta_num_archs, len(meta_archs))
|
||||
|
||||
sub_model_dirs = sorted(list(save_dir.glob("*-*-{:}".format(basestr))))
|
||||
print("{:} find {:} directories used to save checkpoints".format(time_string(), len(sub_model_dirs)))
|
||||
print(
|
||||
"{:} find {:} directories used to save checkpoints".format(
|
||||
time_string(), len(sub_model_dirs)
|
||||
)
|
||||
)
|
||||
|
||||
subdir2archs, num_evaluated_arch = collections.OrderedDict(), 0
|
||||
num_seeds = defaultdict(lambda: 0)
|
||||
@@ -34,21 +38,29 @@ def check_files(save_dir, meta_file, basestr):
|
||||
for checkpoint in xcheckpoints:
|
||||
temp_names = checkpoint.name.split("-")
|
||||
assert (
|
||||
len(temp_names) == 4 and temp_names[0] == "arch" and temp_names[2] == "seed"
|
||||
len(temp_names) == 4
|
||||
and temp_names[0] == "arch"
|
||||
and temp_names[2] == "seed"
|
||||
), "invalid checkpoint name : {:}".format(checkpoint.name)
|
||||
arch_indexes.add(temp_names[1])
|
||||
subdir2archs[sub_dir] = sorted(list(arch_indexes))
|
||||
num_evaluated_arch += len(arch_indexes)
|
||||
# count number of seeds for each architecture
|
||||
for arch_index in arch_indexes:
|
||||
num_seeds[len(list(sub_dir.glob("arch-{:}-seed-*.pth".format(arch_index))))] += 1
|
||||
num_seeds[
|
||||
len(list(sub_dir.glob("arch-{:}-seed-*.pth".format(arch_index))))
|
||||
] += 1
|
||||
print(
|
||||
"There are {:5d} architectures that have been evaluated ({:} in total, {:} ckps in total).".format(
|
||||
num_evaluated_arch, meta_num_archs, sum(k * v for k, v in num_seeds.items())
|
||||
)
|
||||
)
|
||||
for key in sorted(list(num_seeds.keys())):
|
||||
print("There are {:5d} architectures that are evaluated {:} times.".format(num_seeds[key], key))
|
||||
print(
|
||||
"There are {:5d} architectures that are evaluated {:} times.".format(
|
||||
num_seeds[key], key
|
||||
)
|
||||
)
|
||||
|
||||
dir2ckps, dir2ckp_exists = dict(), dict()
|
||||
start_time, epoch_time = time.time(), AverageMeter()
|
||||
@@ -62,12 +74,14 @@ def check_files(save_dir, meta_file, basestr):
|
||||
numrs = defaultdict(lambda: 0)
|
||||
all_checkpoints, all_ckp_exists = [], []
|
||||
for arch_index in arch_indexes:
|
||||
checkpoints = ["arch-{:}-seed-{:04d}.pth".format(arch_index, seed) for seed in seeds]
|
||||
checkpoints = [
|
||||
"arch-{:}-seed-{:04d}.pth".format(arch_index, seed) for seed in seeds
|
||||
]
|
||||
ckp_exists = [(sub_dir / x).exists() for x in checkpoints]
|
||||
arch_index = int(arch_index)
|
||||
assert 0 <= arch_index < len(meta_archs), "invalid arch-index {:} (not found in meta_archs)".format(
|
||||
arch_index
|
||||
)
|
||||
assert (
|
||||
0 <= arch_index < len(meta_archs)
|
||||
), "invalid arch-index {:} (not found in meta_archs)".format(arch_index)
|
||||
all_checkpoints += checkpoints
|
||||
all_ckp_exists += ckp_exists
|
||||
numrs[sum(ckp_exists)] += 1
|
||||
@@ -76,7 +90,9 @@ def check_files(save_dir, meta_file, basestr):
|
||||
# measure time
|
||||
epoch_time.update(time.time() - start_time)
|
||||
start_time = time.time()
|
||||
numrstr = ", ".join(["{:}: {:03d}".format(x, numrs[x]) for x in sorted(numrs.keys())])
|
||||
numrstr = ", ".join(
|
||||
["{:}: {:03d}".format(x, numrs[x]) for x in sorted(numrs.keys())]
|
||||
)
|
||||
print(
|
||||
"{:} load [{:2d}/{:2d}] [{:03d} archs] [{:04d}->{:04d} ckps] {:} done, need {:}. {:}".format(
|
||||
time_string(),
|
||||
@@ -95,7 +111,8 @@ def check_files(save_dir, meta_file, basestr):
|
||||
if __name__ == "__main__":
|
||||
|
||||
parser = argparse.ArgumentParser(
|
||||
description="NAS Benchmark 201", formatter_class=argparse.ArgumentDefaultsHelpFormatter
|
||||
description="NAS Benchmark 201",
|
||||
formatter_class=argparse.ArgumentDefaultsHelpFormatter,
|
||||
)
|
||||
parser.add_argument(
|
||||
"--base_save_dir",
|
||||
@@ -104,9 +121,14 @@ if __name__ == "__main__":
|
||||
help="The base-name of folder to save checkpoints and log.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--meta_path", type=str, default="./output/NAS-BENCH-201-4/meta-node-4.pth", help="The meta file path."
|
||||
"--meta_path",
|
||||
type=str,
|
||||
default="./output/NAS-BENCH-201-4/meta-node-4.pth",
|
||||
help="The meta file path.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--base_str", type=str, default="C16-N5", help="The basic string."
|
||||
)
|
||||
parser.add_argument("--base_str", type=str, default="C16-N5", help="The basic string.")
|
||||
args = parser.parse_args()
|
||||
|
||||
save_dir = Path(args.base_save_dir)
|
||||
|
Reference in New Issue
Block a user