Add int search space
This commit is contained in:
@@ -25,7 +25,11 @@ def check_unique_arch(meta_file):
|
||||
|
||||
def get_unique_matrix(archs, consider_zero):
|
||||
UniquStrs = [arch.to_unique_str(consider_zero) for arch in archs]
|
||||
print("{:} create unique-string ({:}/{:}) done".format(time_string(), len(set(UniquStrs)), len(UniquStrs)))
|
||||
print(
|
||||
"{:} create unique-string ({:}/{:}) done".format(
|
||||
time_string(), len(set(UniquStrs)), len(UniquStrs)
|
||||
)
|
||||
)
|
||||
Unique2Index = dict()
|
||||
for index, xstr in enumerate(UniquStrs):
|
||||
if xstr not in Unique2Index:
|
||||
@@ -47,16 +51,32 @@ def check_unique_arch(meta_file):
|
||||
unique_num += 1
|
||||
return sm_matrix, unique_ids, unique_num
|
||||
|
||||
print("There are {:} valid-archs".format(sum(arch.check_valid() for arch in xarchs)))
|
||||
print(
|
||||
"There are {:} valid-archs".format(sum(arch.check_valid() for arch in xarchs))
|
||||
)
|
||||
sm_matrix, uniqueIDs, unique_num = get_unique_matrix(xarchs, None)
|
||||
print("{:} There are {:} unique architectures (considering nothing).".format(time_string(), unique_num))
|
||||
print(
|
||||
"{:} There are {:} unique architectures (considering nothing).".format(
|
||||
time_string(), unique_num
|
||||
)
|
||||
)
|
||||
sm_matrix, uniqueIDs, unique_num = get_unique_matrix(xarchs, False)
|
||||
print("{:} There are {:} unique architectures (not considering zero).".format(time_string(), unique_num))
|
||||
print(
|
||||
"{:} There are {:} unique architectures (not considering zero).".format(
|
||||
time_string(), unique_num
|
||||
)
|
||||
)
|
||||
sm_matrix, uniqueIDs, unique_num = get_unique_matrix(xarchs, True)
|
||||
print("{:} There are {:} unique architectures (considering zero).".format(time_string(), unique_num))
|
||||
print(
|
||||
"{:} There are {:} unique architectures (considering zero).".format(
|
||||
time_string(), unique_num
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
def check_cor_for_bandit(meta_file, test_epoch, use_less_or_not, is_rand=True, need_print=False):
|
||||
def check_cor_for_bandit(
|
||||
meta_file, test_epoch, use_less_or_not, is_rand=True, need_print=False
|
||||
):
|
||||
if isinstance(meta_file, API):
|
||||
api = meta_file
|
||||
else:
|
||||
@@ -69,7 +89,9 @@ def check_cor_for_bandit(meta_file, test_epoch, use_less_or_not, is_rand=True, n
|
||||
imagenet_test = []
|
||||
imagenet_valid = []
|
||||
for idx, arch in enumerate(api):
|
||||
results = api.get_more_info(idx, "cifar10-valid", test_epoch - 1, use_less_or_not, is_rand)
|
||||
results = api.get_more_info(
|
||||
idx, "cifar10-valid", test_epoch - 1, use_less_or_not, is_rand
|
||||
)
|
||||
cifar10_currs.append(results["valid-accuracy"])
|
||||
# --->>>>>
|
||||
results = api.get_more_info(idx, "cifar10-valid", None, False, is_rand)
|
||||
@@ -89,13 +111,23 @@ def check_cor_for_bandit(meta_file, test_epoch, use_less_or_not, is_rand=True, n
|
||||
cors = []
|
||||
for basestr, xlist in zip(
|
||||
["C-010-V", "C-010-T", "C-100-V", "C-100-T", "I16-V", "I16-T"],
|
||||
[cifar10_valid, cifar10_test, cifar100_valid, cifar100_test, imagenet_valid, imagenet_test],
|
||||
[
|
||||
cifar10_valid,
|
||||
cifar10_test,
|
||||
cifar100_valid,
|
||||
cifar100_test,
|
||||
imagenet_valid,
|
||||
imagenet_test,
|
||||
],
|
||||
):
|
||||
correlation = get_cor(cifar10_currs, xlist)
|
||||
if need_print:
|
||||
print(
|
||||
"With {:3d}/{:}-epochs-training, the correlation between cifar10-valid and {:} is : {:}".format(
|
||||
test_epoch, "012" if use_less_or_not else "200", basestr, correlation
|
||||
test_epoch,
|
||||
"012" if use_less_or_not else "200",
|
||||
basestr,
|
||||
correlation,
|
||||
)
|
||||
)
|
||||
cors.append(correlation)
|
||||
@@ -113,7 +145,11 @@ def check_cor_for_bandit_v2(meta_file, test_epoch, use_less_or_not, is_rand):
|
||||
# xstrs = ['CIFAR-010', 'C-100-V', 'C-100-T', 'I16-V', 'I16-T']
|
||||
xstrs = ["C-010-V", "C-010-T", "C-100-V", "C-100-T", "I16-V", "I16-T"]
|
||||
correlations = np.array(corrs)
|
||||
print("------>>>>>>>> {:03d}/{:} >>>>>>>> ------".format(test_epoch, "012" if use_less_or_not else "200"))
|
||||
print(
|
||||
"------>>>>>>>> {:03d}/{:} >>>>>>>> ------".format(
|
||||
test_epoch, "012" if use_less_or_not else "200"
|
||||
)
|
||||
)
|
||||
for idx, xstr in enumerate(xstrs):
|
||||
print(
|
||||
"{:8s} ::: mean={:.4f}, std={:.4f} :: {:.4f}\\pm{:.4f}".format(
|
||||
@@ -135,7 +171,12 @@ if __name__ == "__main__":
|
||||
default="./output/search-cell-nas-bench-201/visuals",
|
||||
help="The base-name of folder to save checkpoints and log.",
|
||||
)
|
||||
parser.add_argument("--api_path", type=str, default=None, help="The path to the NAS-Bench-201 benchmark file.")
|
||||
parser.add_argument(
|
||||
"--api_path",
|
||||
type=str,
|
||||
default=None,
|
||||
help="The path to the NAS-Bench-201 benchmark file.",
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
vis_save_dir = Path(args.save_dir)
|
||||
|
Reference in New Issue
Block a user