Add int search space

This commit is contained in:
D-X-Y
2021-03-18 16:02:55 +08:00
parent ece6ac5f41
commit 63c8bb9bc8
67 changed files with 5150 additions and 1474 deletions

View File

@@ -25,7 +25,11 @@ def check_unique_arch(meta_file):
def get_unique_matrix(archs, consider_zero):
UniquStrs = [arch.to_unique_str(consider_zero) for arch in archs]
print("{:} create unique-string ({:}/{:}) done".format(time_string(), len(set(UniquStrs)), len(UniquStrs)))
print(
"{:} create unique-string ({:}/{:}) done".format(
time_string(), len(set(UniquStrs)), len(UniquStrs)
)
)
Unique2Index = dict()
for index, xstr in enumerate(UniquStrs):
if xstr not in Unique2Index:
@@ -47,16 +51,32 @@ def check_unique_arch(meta_file):
unique_num += 1
return sm_matrix, unique_ids, unique_num
print("There are {:} valid-archs".format(sum(arch.check_valid() for arch in xarchs)))
print(
"There are {:} valid-archs".format(sum(arch.check_valid() for arch in xarchs))
)
sm_matrix, uniqueIDs, unique_num = get_unique_matrix(xarchs, None)
print("{:} There are {:} unique architectures (considering nothing).".format(time_string(), unique_num))
print(
"{:} There are {:} unique architectures (considering nothing).".format(
time_string(), unique_num
)
)
sm_matrix, uniqueIDs, unique_num = get_unique_matrix(xarchs, False)
print("{:} There are {:} unique architectures (not considering zero).".format(time_string(), unique_num))
print(
"{:} There are {:} unique architectures (not considering zero).".format(
time_string(), unique_num
)
)
sm_matrix, uniqueIDs, unique_num = get_unique_matrix(xarchs, True)
print("{:} There are {:} unique architectures (considering zero).".format(time_string(), unique_num))
print(
"{:} There are {:} unique architectures (considering zero).".format(
time_string(), unique_num
)
)
def check_cor_for_bandit(meta_file, test_epoch, use_less_or_not, is_rand=True, need_print=False):
def check_cor_for_bandit(
meta_file, test_epoch, use_less_or_not, is_rand=True, need_print=False
):
if isinstance(meta_file, API):
api = meta_file
else:
@@ -69,7 +89,9 @@ def check_cor_for_bandit(meta_file, test_epoch, use_less_or_not, is_rand=True, n
imagenet_test = []
imagenet_valid = []
for idx, arch in enumerate(api):
results = api.get_more_info(idx, "cifar10-valid", test_epoch - 1, use_less_or_not, is_rand)
results = api.get_more_info(
idx, "cifar10-valid", test_epoch - 1, use_less_or_not, is_rand
)
cifar10_currs.append(results["valid-accuracy"])
# --->>>>>
results = api.get_more_info(idx, "cifar10-valid", None, False, is_rand)
@@ -89,13 +111,23 @@ def check_cor_for_bandit(meta_file, test_epoch, use_less_or_not, is_rand=True, n
cors = []
for basestr, xlist in zip(
["C-010-V", "C-010-T", "C-100-V", "C-100-T", "I16-V", "I16-T"],
[cifar10_valid, cifar10_test, cifar100_valid, cifar100_test, imagenet_valid, imagenet_test],
[
cifar10_valid,
cifar10_test,
cifar100_valid,
cifar100_test,
imagenet_valid,
imagenet_test,
],
):
correlation = get_cor(cifar10_currs, xlist)
if need_print:
print(
"With {:3d}/{:}-epochs-training, the correlation between cifar10-valid and {:} is : {:}".format(
test_epoch, "012" if use_less_or_not else "200", basestr, correlation
test_epoch,
"012" if use_less_or_not else "200",
basestr,
correlation,
)
)
cors.append(correlation)
@@ -113,7 +145,11 @@ def check_cor_for_bandit_v2(meta_file, test_epoch, use_less_or_not, is_rand):
# xstrs = ['CIFAR-010', 'C-100-V', 'C-100-T', 'I16-V', 'I16-T']
xstrs = ["C-010-V", "C-010-T", "C-100-V", "C-100-T", "I16-V", "I16-T"]
correlations = np.array(corrs)
print("------>>>>>>>> {:03d}/{:} >>>>>>>> ------".format(test_epoch, "012" if use_less_or_not else "200"))
print(
"------>>>>>>>> {:03d}/{:} >>>>>>>> ------".format(
test_epoch, "012" if use_less_or_not else "200"
)
)
for idx, xstr in enumerate(xstrs):
print(
"{:8s} ::: mean={:.4f}, std={:.4f} :: {:.4f}\\pm{:.4f}".format(
@@ -135,7 +171,12 @@ if __name__ == "__main__":
default="./output/search-cell-nas-bench-201/visuals",
help="The base-name of folder to save checkpoints and log.",
)
parser.add_argument("--api_path", type=str, default=None, help="The path to the NAS-Bench-201 benchmark file.")
parser.add_argument(
"--api_path",
type=str,
default=None,
help="The path to the NAS-Bench-201 benchmark file.",
)
args = parser.parse_args()
vis_save_dir = Path(args.save_dir)