Add int search space
This commit is contained in:
@@ -42,7 +42,9 @@ def fetch_data(root_dir="./output/search", search_space="tss", dataset=None):
|
||||
for alg, path in alg2path.items():
|
||||
data = torch.load(path)
|
||||
for index, info in data.items():
|
||||
info["time_w_arch"] = [(x, y) for x, y in zip(info["all_total_times"], info["all_archs"])]
|
||||
info["time_w_arch"] = [
|
||||
(x, y) for x, y in zip(info["all_total_times"], info["all_archs"])
|
||||
]
|
||||
for j, arch in enumerate(info["all_archs"]):
|
||||
assert arch != -1, "invalid arch from {:} {:} {:} ({:}, {:})".format(
|
||||
alg, search_space, dataset, index, j
|
||||
@@ -54,15 +56,28 @@ def fetch_data(root_dir="./output/search", search_space="tss", dataset=None):
|
||||
def get_valid_test_acc(api, arch, dataset):
|
||||
is_size_space = api.search_space_name == "size"
|
||||
if dataset == "cifar10":
|
||||
xinfo = api.get_more_info(arch, dataset=dataset, hp=90 if is_size_space else 200, is_random=False)
|
||||
xinfo = api.get_more_info(
|
||||
arch, dataset=dataset, hp=90 if is_size_space else 200, is_random=False
|
||||
)
|
||||
test_acc = xinfo["test-accuracy"]
|
||||
xinfo = api.get_more_info(arch, dataset="cifar10-valid", hp=90 if is_size_space else 200, is_random=False)
|
||||
xinfo = api.get_more_info(
|
||||
arch,
|
||||
dataset="cifar10-valid",
|
||||
hp=90 if is_size_space else 200,
|
||||
is_random=False,
|
||||
)
|
||||
valid_acc = xinfo["valid-accuracy"]
|
||||
else:
|
||||
xinfo = api.get_more_info(arch, dataset=dataset, hp=90 if is_size_space else 200, is_random=False)
|
||||
xinfo = api.get_more_info(
|
||||
arch, dataset=dataset, hp=90 if is_size_space else 200, is_random=False
|
||||
)
|
||||
valid_acc = xinfo["valid-accuracy"]
|
||||
test_acc = xinfo["test-accuracy"]
|
||||
return valid_acc, test_acc, "validation = {:.2f}, test = {:.2f}\n".format(valid_acc, test_acc)
|
||||
return (
|
||||
valid_acc,
|
||||
test_acc,
|
||||
"validation = {:.2f}, test = {:.2f}\n".format(valid_acc, test_acc),
|
||||
)
|
||||
|
||||
|
||||
def show_valid_test(api, arch):
|
||||
@@ -84,8 +99,16 @@ def find_best_valid(api, dataset):
|
||||
best_test_index = sorted(all_test_accs, key=lambda x: -x[1])[0][0]
|
||||
|
||||
print("-" * 50 + "{:10s}".format(dataset) + "-" * 50)
|
||||
print("Best ({:}) architecture on validation: {:}".format(best_valid_index, api[best_valid_index]))
|
||||
print("Best ({:}) architecture on test: {:}".format(best_test_index, api[best_test_index]))
|
||||
print(
|
||||
"Best ({:}) architecture on validation: {:}".format(
|
||||
best_valid_index, api[best_valid_index]
|
||||
)
|
||||
)
|
||||
print(
|
||||
"Best ({:}) architecture on test: {:}".format(
|
||||
best_test_index, api[best_test_index]
|
||||
)
|
||||
)
|
||||
_, _, perf_str = get_valid_test_acc(api, best_valid_index, dataset)
|
||||
print("using validation ::: {:}".format(perf_str))
|
||||
_, _, perf_str = get_valid_test_acc(api, best_test_index, dataset)
|
||||
@@ -130,10 +153,14 @@ def show_multi_trial(search_space):
|
||||
v_acc, t_acc = query_performance(api, x, xdataset, float(max_time))
|
||||
valid_accs.append(v_acc)
|
||||
test_accs.append(t_acc)
|
||||
valid_str = "{:.2f}$\pm${:.2f}".format(np.mean(valid_accs), np.std(valid_accs))
|
||||
valid_str = "{:.2f}$\pm${:.2f}".format(
|
||||
np.mean(valid_accs), np.std(valid_accs)
|
||||
)
|
||||
test_str = "{:.2f}$\pm${:.2f}".format(np.mean(test_accs), np.std(test_accs))
|
||||
print(
|
||||
"{:} plot alg : {:10s} | validation = {:} | test = {:}".format(time_string(), alg, valid_str, test_str)
|
||||
"{:} plot alg : {:10s} | validation = {:} | test = {:}".format(
|
||||
time_string(), alg, valid_str, test_str
|
||||
)
|
||||
)
|
||||
|
||||
if search_space == "tss":
|
||||
|
Reference in New Issue
Block a user