Add int search space
This commit is contained in:
@@ -30,12 +30,20 @@ from models import get_cell_based_tiny_net
|
||||
from nats_bench import create
|
||||
|
||||
|
||||
name2label = {"cifar10": "CIFAR-10", "cifar100": "CIFAR-100", "ImageNet16-120": "ImageNet-16-120"}
|
||||
name2label = {
|
||||
"cifar10": "CIFAR-10",
|
||||
"cifar100": "CIFAR-100",
|
||||
"ImageNet16-120": "ImageNet-16-120",
|
||||
}
|
||||
|
||||
|
||||
def visualize_relative_info(vis_save_dir, search_space, indicator, topk):
|
||||
vis_save_dir = vis_save_dir.resolve()
|
||||
print("{:} start to visualize {:} with top-{:} information".format(time_string(), search_space, topk))
|
||||
print(
|
||||
"{:} start to visualize {:} with top-{:} information".format(
|
||||
time_string(), search_space, topk
|
||||
)
|
||||
)
|
||||
vis_save_dir.mkdir(parents=True, exist_ok=True)
|
||||
cache_file_path = vis_save_dir / "cache-{:}-info.pth".format(search_space)
|
||||
datasets = ["cifar10", "cifar100", "ImageNet16-120"]
|
||||
@@ -46,8 +54,12 @@ def visualize_relative_info(vis_save_dir, search_space, indicator, topk):
|
||||
all_info = OrderedDict()
|
||||
for dataset in datasets:
|
||||
info_less = api.get_more_info(index, dataset, hp="12", is_random=False)
|
||||
info_more = api.get_more_info(index, dataset, hp=api.full_train_epochs, is_random=False)
|
||||
all_info[dataset] = dict(less=info_less["test-accuracy"], more=info_more["test-accuracy"])
|
||||
info_more = api.get_more_info(
|
||||
index, dataset, hp=api.full_train_epochs, is_random=False
|
||||
)
|
||||
all_info[dataset] = dict(
|
||||
less=info_less["test-accuracy"], more=info_more["test-accuracy"]
|
||||
)
|
||||
all_infos[index] = all_info
|
||||
torch.save(all_infos, cache_file_path)
|
||||
print("{:} save all cache data into {:}".format(time_string(), cache_file_path))
|
||||
@@ -80,12 +92,18 @@ def visualize_relative_info(vis_save_dir, search_space, indicator, topk):
|
||||
for idx in selected_indexes:
|
||||
standard_scores.append(
|
||||
api.get_more_info(
|
||||
idx, dataset, hp=api.full_train_epochs if indicator == "more" else "12", is_random=False
|
||||
idx,
|
||||
dataset,
|
||||
hp=api.full_train_epochs if indicator == "more" else "12",
|
||||
is_random=False,
|
||||
)["test-accuracy"]
|
||||
)
|
||||
random_scores.append(
|
||||
api.get_more_info(
|
||||
idx, dataset, hp=api.full_train_epochs if indicator == "more" else "12", is_random=True
|
||||
idx,
|
||||
dataset,
|
||||
hp=api.full_train_epochs if indicator == "more" else "12",
|
||||
is_random=True,
|
||||
)["test-accuracy"]
|
||||
)
|
||||
indexes = list(range(len(selected_indexes)))
|
||||
@@ -105,11 +123,28 @@ def visualize_relative_info(vis_save_dir, search_space, indicator, topk):
|
||||
ax.set_xticks(np.arange(min(indexes), max(indexes), max(indexes) // 5))
|
||||
ax.scatter(indexes, random_labels, marker="^", s=0.5, c="tab:green", alpha=0.8)
|
||||
ax.scatter(indexes, indexes, marker="o", s=0.5, c="tab:blue", alpha=0.8)
|
||||
ax.scatter([-1], [-1], marker="o", s=100, c="tab:blue", label="Average Over Multi-Trials")
|
||||
ax.scatter([-1], [-1], marker="^", s=100, c="tab:green", label="Randomly Selected Trial")
|
||||
ax.scatter(
|
||||
[-1],
|
||||
[-1],
|
||||
marker="o",
|
||||
s=100,
|
||||
c="tab:blue",
|
||||
label="Average Over Multi-Trials",
|
||||
)
|
||||
ax.scatter(
|
||||
[-1],
|
||||
[-1],
|
||||
marker="^",
|
||||
s=100,
|
||||
c="tab:green",
|
||||
label="Randomly Selected Trial",
|
||||
)
|
||||
|
||||
coef, p = scipy.stats.kendalltau(standard_scores, random_scores)
|
||||
ax.set_xlabel("architecture ranking in {:}".format(name2label[dataset]), fontsize=LabelSize)
|
||||
ax.set_xlabel(
|
||||
"architecture ranking in {:}".format(name2label[dataset]),
|
||||
fontsize=LabelSize,
|
||||
)
|
||||
if dataset == "cifar10":
|
||||
ax.set_ylabel("architecture ranking", fontsize=LabelSize)
|
||||
ax.legend(loc=4, fontsize=LegendFontsize)
|
||||
@@ -117,17 +152,27 @@ def visualize_relative_info(vis_save_dir, search_space, indicator, topk):
|
||||
|
||||
for dataset, ax in zip(datasets, axs):
|
||||
rank_coef = sub_plot_fn(ax, dataset, indicator)
|
||||
print("sub-plot {:} on {:} done, the ranking coefficient is {:.4f}.".format(dataset, search_space, rank_coef))
|
||||
print(
|
||||
"sub-plot {:} on {:} done, the ranking coefficient is {:.4f}.".format(
|
||||
dataset, search_space, rank_coef
|
||||
)
|
||||
)
|
||||
|
||||
save_path = (vis_save_dir / "{:}-rank-{:}-top{:}.pdf".format(search_space, indicator, topk)).resolve()
|
||||
save_path = (
|
||||
vis_save_dir / "{:}-rank-{:}-top{:}.pdf".format(search_space, indicator, topk)
|
||||
).resolve()
|
||||
fig.savefig(save_path, dpi=dpi, bbox_inches="tight", format="pdf")
|
||||
save_path = (vis_save_dir / "{:}-rank-{:}-top{:}.png".format(search_space, indicator, topk)).resolve()
|
||||
save_path = (
|
||||
vis_save_dir / "{:}-rank-{:}-top{:}.png".format(search_space, indicator, topk)
|
||||
).resolve()
|
||||
fig.savefig(save_path, dpi=dpi, bbox_inches="tight", format="png")
|
||||
print("Save into {:}".format(save_path))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser(description="NATS-Bench", formatter_class=argparse.ArgumentDefaultsHelpFormatter)
|
||||
parser = argparse.ArgumentParser(
|
||||
description="NATS-Bench", formatter_class=argparse.ArgumentDefaultsHelpFormatter
|
||||
)
|
||||
parser.add_argument(
|
||||
"--save_dir",
|
||||
type=str,
|
||||
|
Reference in New Issue
Block a user