Add int search space

This commit is contained in:
D-X-Y
2021-03-18 16:02:55 +08:00
parent ece6ac5f41
commit 63c8bb9bc8
67 changed files with 5150 additions and 1474 deletions

View File

@@ -42,7 +42,9 @@ def fetch_data(root_dir="./output/search", search_space="tss", dataset=None):
for alg, path in alg2path.items():
data = torch.load(path)
for index, info in data.items():
info["time_w_arch"] = [(x, y) for x, y in zip(info["all_total_times"], info["all_archs"])]
info["time_w_arch"] = [
(x, y) for x, y in zip(info["all_total_times"], info["all_archs"])
]
for j, arch in enumerate(info["all_archs"]):
assert arch != -1, "invalid arch from {:} {:} {:} ({:}, {:})".format(
alg, search_space, dataset, index, j
@@ -54,15 +56,28 @@ def fetch_data(root_dir="./output/search", search_space="tss", dataset=None):
def get_valid_test_acc(api, arch, dataset):
is_size_space = api.search_space_name == "size"
if dataset == "cifar10":
xinfo = api.get_more_info(arch, dataset=dataset, hp=90 if is_size_space else 200, is_random=False)
xinfo = api.get_more_info(
arch, dataset=dataset, hp=90 if is_size_space else 200, is_random=False
)
test_acc = xinfo["test-accuracy"]
xinfo = api.get_more_info(arch, dataset="cifar10-valid", hp=90 if is_size_space else 200, is_random=False)
xinfo = api.get_more_info(
arch,
dataset="cifar10-valid",
hp=90 if is_size_space else 200,
is_random=False,
)
valid_acc = xinfo["valid-accuracy"]
else:
xinfo = api.get_more_info(arch, dataset=dataset, hp=90 if is_size_space else 200, is_random=False)
xinfo = api.get_more_info(
arch, dataset=dataset, hp=90 if is_size_space else 200, is_random=False
)
valid_acc = xinfo["valid-accuracy"]
test_acc = xinfo["test-accuracy"]
return valid_acc, test_acc, "validation = {:.2f}, test = {:.2f}\n".format(valid_acc, test_acc)
return (
valid_acc,
test_acc,
"validation = {:.2f}, test = {:.2f}\n".format(valid_acc, test_acc),
)
def show_valid_test(api, arch):
@@ -84,8 +99,16 @@ def find_best_valid(api, dataset):
best_test_index = sorted(all_test_accs, key=lambda x: -x[1])[0][0]
print("-" * 50 + "{:10s}".format(dataset) + "-" * 50)
print("Best ({:}) architecture on validation: {:}".format(best_valid_index, api[best_valid_index]))
print("Best ({:}) architecture on test: {:}".format(best_test_index, api[best_test_index]))
print(
"Best ({:}) architecture on validation: {:}".format(
best_valid_index, api[best_valid_index]
)
)
print(
"Best ({:}) architecture on test: {:}".format(
best_test_index, api[best_test_index]
)
)
_, _, perf_str = get_valid_test_acc(api, best_valid_index, dataset)
print("using validation ::: {:}".format(perf_str))
_, _, perf_str = get_valid_test_acc(api, best_test_index, dataset)
@@ -130,10 +153,14 @@ def show_multi_trial(search_space):
v_acc, t_acc = query_performance(api, x, xdataset, float(max_time))
valid_accs.append(v_acc)
test_accs.append(t_acc)
valid_str = "{:.2f}$\pm${:.2f}".format(np.mean(valid_accs), np.std(valid_accs))
valid_str = "{:.2f}$\pm${:.2f}".format(
np.mean(valid_accs), np.std(valid_accs)
)
test_str = "{:.2f}$\pm${:.2f}".format(np.mean(test_accs), np.std(test_accs))
print(
"{:} plot alg : {:10s} | validation = {:} | test = {:}".format(time_string(), alg, valid_str, test_str)
"{:} plot alg : {:10s} | validation = {:} | test = {:}".format(
time_string(), alg, valid_str, test_str
)
)
if search_space == "tss":