Add int search space

This commit is contained in:
D-X-Y
2021-03-18 16:02:55 +08:00
parent ece6ac5f41
commit 63c8bb9bc8
67 changed files with 5150 additions and 1474 deletions

View File

@@ -30,15 +30,28 @@ from log_utils import time_string
def get_valid_test_acc(api, arch, dataset):
is_size_space = api.search_space_name == "size"
if dataset == "cifar10":
xinfo = api.get_more_info(arch, dataset=dataset, hp=90 if is_size_space else 200, is_random=False)
xinfo = api.get_more_info(
arch, dataset=dataset, hp=90 if is_size_space else 200, is_random=False
)
test_acc = xinfo["test-accuracy"]
xinfo = api.get_more_info(arch, dataset="cifar10-valid", hp=90 if is_size_space else 200, is_random=False)
xinfo = api.get_more_info(
arch,
dataset="cifar10-valid",
hp=90 if is_size_space else 200,
is_random=False,
)
valid_acc = xinfo["valid-accuracy"]
else:
xinfo = api.get_more_info(arch, dataset=dataset, hp=90 if is_size_space else 200, is_random=False)
xinfo = api.get_more_info(
arch, dataset=dataset, hp=90 if is_size_space else 200, is_random=False
)
valid_acc = xinfo["valid-accuracy"]
test_acc = xinfo["test-accuracy"]
return valid_acc, test_acc, "validation = {:.2f}, test = {:.2f}\n".format(valid_acc, test_acc)
return (
valid_acc,
test_acc,
"validation = {:.2f}, test = {:.2f}\n".format(valid_acc, test_acc),
)
def compute_kendalltau(vectori, vectorj):
@@ -61,9 +74,17 @@ if __name__ == "__main__":
formatter_class=argparse.ArgumentDefaultsHelpFormatter,
)
parser.add_argument(
"--save_dir", type=str, default="output/vis-nas-bench/nas-algos", help="Folder to save checkpoints and log."
"--save_dir",
type=str,
default="output/vis-nas-bench/nas-algos",
help="Folder to save checkpoints and log.",
)
parser.add_argument(
"--search_space",
type=str,
choices=["tss", "sss"],
help="Choose the search space.",
)
parser.add_argument("--search_space", type=str, choices=["tss", "sss"], help="Choose the search space.")
args = parser.parse_args()
save_dir = Path(args.save_dir)
@@ -77,9 +98,17 @@ if __name__ == "__main__":
scores_1.append(valid_acc)
scores_2.append(test_acc)
correlation = compute_kendalltau(scores_1, scores_2)
print("The kendall tau correlation of {:} samples : {:}".format(len(indexes), correlation))
print(
"The kendall tau correlation of {:} samples : {:}".format(
len(indexes), correlation
)
)
correlation = compute_spearmanr(scores_1, scores_2)
print("The spearmanr correlation of {:} samples : {:}".format(len(indexes), correlation))
print(
"The spearmanr correlation of {:} samples : {:}".format(
len(indexes), correlation
)
)
# scores_1 = ['{:.2f}'.format(x) for x in scores_1]
# scores_2 = ['{:.2f}'.format(x) for x in scores_2]
# print(', '.join(scores_1))