Add int search space

D-X-Y
2021-03-18 16:02:55 +08:00
parent ece6ac5f41
commit 63c8bb9bc8
67 changed files with 5150 additions and 1474 deletions

@@ -30,12 +30,20 @@ from models import get_cell_based_tiny_net
 from nats_bench import create
-name2label = {"cifar10": "CIFAR-10", "cifar100": "CIFAR-100", "ImageNet16-120": "ImageNet-16-120"}
+name2label = {
+    "cifar10": "CIFAR-10",
+    "cifar100": "CIFAR-100",
+    "ImageNet16-120": "ImageNet-16-120",
+}
 def visualize_relative_info(vis_save_dir, search_space, indicator, topk):
     vis_save_dir = vis_save_dir.resolve()
-    print("{:} start to visualize {:} with top-{:} information".format(time_string(), search_space, topk))
+    print(
+        "{:} start to visualize {:} with top-{:} information".format(
+            time_string(), search_space, topk
+        )
+    )
     vis_save_dir.mkdir(parents=True, exist_ok=True)
     cache_file_path = vis_save_dir / "cache-{:}-info.pth".format(search_space)
     datasets = ["cifar10", "cifar100", "ImageNet16-120"]
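The hunk above only re-wraps existing calls; for context, below is a minimal sketch of how the `nats_bench` API imported here is typically created and queried. The `"tss"` search-space name, the architecture index `0`, and the assumption that the benchmark data is available locally are placeholders, not taken from this commit.

```python
# Hedged sketch (not part of this commit): create a NATS-Bench API object and
# query one architecture, mirroring the calls used in visualize_relative_info.
from nats_bench import create

# None lets nats_bench look up its default benchmark file; "tss" (topology
# search space) is an assumed example and the data must exist locally.
api = create(None, "tss", fast_mode=True, verbose=False)
info = api.get_more_info(0, "cifar10", hp="12", is_random=False)  # 12-epoch stats
print(info["test-accuracy"])
```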
@@ -46,8 +54,12 @@ def visualize_relative_info(vis_save_dir, search_space, indicator, topk):
             all_info = OrderedDict()
             for dataset in datasets:
                 info_less = api.get_more_info(index, dataset, hp="12", is_random=False)
-                info_more = api.get_more_info(index, dataset, hp=api.full_train_epochs, is_random=False)
-                all_info[dataset] = dict(less=info_less["test-accuracy"], more=info_more["test-accuracy"])
+                info_more = api.get_more_info(
+                    index, dataset, hp=api.full_train_epochs, is_random=False
+                )
+                all_info[dataset] = dict(
+                    less=info_less["test-accuracy"], more=info_more["test-accuracy"]
+                )
             all_infos[index] = all_info
         torch.save(all_infos, cache_file_path)
         print("{:} save all cache data into {:}".format(time_string(), cache_file_path))
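The hunk above keeps the script's cache-then-reuse pattern: the per-architecture statistics are expensive to query, so they are collected once into `all_infos` and written with `torch.save` to `cache_file_path`. A hedged sketch of that pattern follows; the `load_or_build` helper is hypothetical and not part of this commit.

```python
# Hedged sketch of the caching pattern; load_or_build is a hypothetical helper.
from pathlib import Path

import torch


def load_or_build(cache_file_path: Path, build_fn):
    """Return cached data if present; otherwise build it once and cache it."""
    if cache_file_path.exists():
        return torch.load(cache_file_path)
    data = build_fn()  # e.g. a dict of per-dataset test accuracies
    torch.save(data, cache_file_path)
    return data
```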
@@ -80,12 +92,18 @@ def visualize_relative_info(vis_save_dir, search_space, indicator, topk):
         for idx in selected_indexes:
             standard_scores.append(
                 api.get_more_info(
-                    idx, dataset, hp=api.full_train_epochs if indicator == "more" else "12", is_random=False
+                    idx,
+                    dataset,
+                    hp=api.full_train_epochs if indicator == "more" else "12",
+                    is_random=False,
                 )["test-accuracy"]
             )
             random_scores.append(
                 api.get_more_info(
-                    idx, dataset, hp=api.full_train_epochs if indicator == "more" else "12", is_random=True
+                    idx,
+                    dataset,
+                    hp=api.full_train_epochs if indicator == "more" else "12",
+                    is_random=True,
                 )["test-accuracy"]
             )
         indexes = list(range(len(selected_indexes)))
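In the loop above, `is_random=False` asks the benchmark for results averaged over its recorded trials, while `is_random=True` samples a single random trial; the two score lists are later rank-compared. A hedged, self-contained sketch of that comparison is shown below; the architecture indices, the `"tss"` search-space name, and the dataset are placeholders, not values from this commit.

```python
# Hedged sketch (not from this commit): averaged vs. single-trial accuracy.
import scipy.stats
from nats_bench import create

api = create(None, "tss", fast_mode=True, verbose=False)  # "tss" is assumed
indices = [0, 1, 2, 3, 4]  # placeholder architecture indices
averaged = [
    api.get_more_info(i, "cifar10", hp="12", is_random=False)["test-accuracy"]
    for i in indices
]
single_trial = [
    api.get_more_info(i, "cifar10", hp="12", is_random=True)["test-accuracy"]
    for i in indices
]
coef, p = scipy.stats.kendalltau(averaged, single_trial)  # rank correlation
print("Kendall tau: {:.4f} (p={:.4f})".format(coef, p))
```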
@@ -105,11 +123,28 @@ def visualize_relative_info(vis_save_dir, search_space, indicator, topk):
         ax.set_xticks(np.arange(min(indexes), max(indexes), max(indexes) // 5))
         ax.scatter(indexes, random_labels, marker="^", s=0.5, c="tab:green", alpha=0.8)
         ax.scatter(indexes, indexes, marker="o", s=0.5, c="tab:blue", alpha=0.8)
-        ax.scatter([-1], [-1], marker="o", s=100, c="tab:blue", label="Average Over Multi-Trials")
-        ax.scatter([-1], [-1], marker="^", s=100, c="tab:green", label="Randomly Selected Trial")
+        ax.scatter(
+            [-1],
+            [-1],
+            marker="o",
+            s=100,
+            c="tab:blue",
+            label="Average Over Multi-Trials",
+        )
+        ax.scatter(
+            [-1],
+            [-1],
+            marker="^",
+            s=100,
+            c="tab:green",
+            label="Randomly Selected Trial",
+        )
         coef, p = scipy.stats.kendalltau(standard_scores, random_scores)
-        ax.set_xlabel("architecture ranking in {:}".format(name2label[dataset]), fontsize=LabelSize)
+        ax.set_xlabel(
+            "architecture ranking in {:}".format(name2label[dataset]),
+            fontsize=LabelSize,
+        )
         if dataset == "cifar10":
             ax.set_ylabel("architecture ranking", fontsize=LabelSize)
         ax.legend(loc=4, fontsize=LegendFontsize)
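The two reformatted `ax.scatter` calls above are a common matplotlib legend trick: the real data is drawn with tiny markers (`s=0.5`), so two extra points at `(-1, -1)` with large markers exist only to give the legend readable symbols. Below is a hedged, stand-alone sketch of the same idea; the axis limits are added here (an assumption, not from this commit) so the proxy point stays off-canvas.

```python
# Hedged sketch of the legend-proxy trick used above; not part of this commit.
import matplotlib.pyplot as plt

fig, ax = plt.subplots()
xs = list(range(100))
ax.scatter(xs, xs, marker="o", s=0.5, c="tab:blue")                      # real data
ax.scatter([-1], [-1], marker="o", s=100, c="tab:blue", label="series")  # legend proxy
ax.set_xlim(0, 99)  # keep the (-1, -1) proxy point outside the visible area
ax.set_ylim(0, 99)
ax.legend(loc=4)  # loc=4 is the lower-right corner
```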
@@ -117,17 +152,27 @@ def visualize_relative_info(vis_save_dir, search_space, indicator, topk):
     for dataset, ax in zip(datasets, axs):
         rank_coef = sub_plot_fn(ax, dataset, indicator)
-        print("sub-plot {:} on {:} done, the ranking coefficient is {:.4f}.".format(dataset, search_space, rank_coef))
+        print(
+            "sub-plot {:} on {:} done, the ranking coefficient is {:.4f}.".format(
+                dataset, search_space, rank_coef
+            )
+        )
-    save_path = (vis_save_dir / "{:}-rank-{:}-top{:}.pdf".format(search_space, indicator, topk)).resolve()
+    save_path = (
+        vis_save_dir / "{:}-rank-{:}-top{:}.pdf".format(search_space, indicator, topk)
+    ).resolve()
     fig.savefig(save_path, dpi=dpi, bbox_inches="tight", format="pdf")
-    save_path = (vis_save_dir / "{:}-rank-{:}-top{:}.png".format(search_space, indicator, topk)).resolve()
+    save_path = (
+        vis_save_dir / "{:}-rank-{:}-top{:}.png".format(search_space, indicator, topk)
+    ).resolve()
     fig.savefig(save_path, dpi=dpi, bbox_inches="tight", format="png")
     print("Save into {:}".format(save_path))
 if __name__ == "__main__":
-    parser = argparse.ArgumentParser(description="NATS-Bench", formatter_class=argparse.ArgumentDefaultsHelpFormatter)
+    parser = argparse.ArgumentParser(
+        description="NATS-Bench", formatter_class=argparse.ArgumentDefaultsHelpFormatter
+    )
     parser.add_argument(
         "--save_dir",
         type=str,