Add int search space
This commit is contained in:
@@ -30,15 +30,28 @@ from log_utils import time_string
|
||||
def get_valid_test_acc(api, arch, dataset):
|
||||
is_size_space = api.search_space_name == "size"
|
||||
if dataset == "cifar10":
|
||||
xinfo = api.get_more_info(arch, dataset=dataset, hp=90 if is_size_space else 200, is_random=False)
|
||||
xinfo = api.get_more_info(
|
||||
arch, dataset=dataset, hp=90 if is_size_space else 200, is_random=False
|
||||
)
|
||||
test_acc = xinfo["test-accuracy"]
|
||||
xinfo = api.get_more_info(arch, dataset="cifar10-valid", hp=90 if is_size_space else 200, is_random=False)
|
||||
xinfo = api.get_more_info(
|
||||
arch,
|
||||
dataset="cifar10-valid",
|
||||
hp=90 if is_size_space else 200,
|
||||
is_random=False,
|
||||
)
|
||||
valid_acc = xinfo["valid-accuracy"]
|
||||
else:
|
||||
xinfo = api.get_more_info(arch, dataset=dataset, hp=90 if is_size_space else 200, is_random=False)
|
||||
xinfo = api.get_more_info(
|
||||
arch, dataset=dataset, hp=90 if is_size_space else 200, is_random=False
|
||||
)
|
||||
valid_acc = xinfo["valid-accuracy"]
|
||||
test_acc = xinfo["test-accuracy"]
|
||||
return valid_acc, test_acc, "validation = {:.2f}, test = {:.2f}\n".format(valid_acc, test_acc)
|
||||
return (
|
||||
valid_acc,
|
||||
test_acc,
|
||||
"validation = {:.2f}, test = {:.2f}\n".format(valid_acc, test_acc),
|
||||
)
|
||||
|
||||
|
||||
def compute_kendalltau(vectori, vectorj):
|
||||
@@ -61,9 +74,17 @@ if __name__ == "__main__":
|
||||
formatter_class=argparse.ArgumentDefaultsHelpFormatter,
|
||||
)
|
||||
parser.add_argument(
|
||||
"--save_dir", type=str, default="output/vis-nas-bench/nas-algos", help="Folder to save checkpoints and log."
|
||||
"--save_dir",
|
||||
type=str,
|
||||
default="output/vis-nas-bench/nas-algos",
|
||||
help="Folder to save checkpoints and log.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--search_space",
|
||||
type=str,
|
||||
choices=["tss", "sss"],
|
||||
help="Choose the search space.",
|
||||
)
|
||||
parser.add_argument("--search_space", type=str, choices=["tss", "sss"], help="Choose the search space.")
|
||||
args = parser.parse_args()
|
||||
|
||||
save_dir = Path(args.save_dir)
|
||||
@@ -77,9 +98,17 @@ if __name__ == "__main__":
|
||||
scores_1.append(valid_acc)
|
||||
scores_2.append(test_acc)
|
||||
correlation = compute_kendalltau(scores_1, scores_2)
|
||||
print("The kendall tau correlation of {:} samples : {:}".format(len(indexes), correlation))
|
||||
print(
|
||||
"The kendall tau correlation of {:} samples : {:}".format(
|
||||
len(indexes), correlation
|
||||
)
|
||||
)
|
||||
correlation = compute_spearmanr(scores_1, scores_2)
|
||||
print("The spearmanr correlation of {:} samples : {:}".format(len(indexes), correlation))
|
||||
print(
|
||||
"The spearmanr correlation of {:} samples : {:}".format(
|
||||
len(indexes), correlation
|
||||
)
|
||||
)
|
||||
# scores_1 = ['{:.2f}'.format(x) for x in scores_1]
|
||||
# scores_2 = ['{:.2f}'.format(x) for x in scores_2]
|
||||
# print(', '.join(scores_1))
|
||||
|
Reference in New Issue
Block a user