Add int search space
This commit is contained in:
@@ -35,9 +35,15 @@ def visualize_relative_info(api, vis_save_dir, indicator):
|
||||
# print ('{:} start to visualize {:} information'.format(time_string(), api))
|
||||
vis_save_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
cifar010_cache_path = vis_save_dir / "{:}-cache-{:}-info.pth".format("cifar10", indicator)
|
||||
cifar100_cache_path = vis_save_dir / "{:}-cache-{:}-info.pth".format("cifar100", indicator)
|
||||
imagenet_cache_path = vis_save_dir / "{:}-cache-{:}-info.pth".format("ImageNet16-120", indicator)
|
||||
cifar010_cache_path = vis_save_dir / "{:}-cache-{:}-info.pth".format(
|
||||
"cifar10", indicator
|
||||
)
|
||||
cifar100_cache_path = vis_save_dir / "{:}-cache-{:}-info.pth".format(
|
||||
"cifar100", indicator
|
||||
)
|
||||
imagenet_cache_path = vis_save_dir / "{:}-cache-{:}-info.pth".format(
|
||||
"ImageNet16-120", indicator
|
||||
)
|
||||
cifar010_info = torch.load(cifar010_cache_path)
|
||||
cifar100_info = torch.load(cifar100_cache_path)
|
||||
imagenet_info = torch.load(imagenet_cache_path)
|
||||
@@ -65,8 +71,15 @@ def visualize_relative_info(api, vis_save_dir, indicator):
|
||||
plt.xlim(min(indexes), max(indexes))
|
||||
plt.ylim(min(indexes), max(indexes))
|
||||
# plt.ylabel('y').set_rotation(30)
|
||||
plt.yticks(np.arange(min(indexes), max(indexes), max(indexes) // 3), fontsize=LegendFontsize, rotation="vertical")
|
||||
plt.xticks(np.arange(min(indexes), max(indexes), max(indexes) // 5), fontsize=LegendFontsize)
|
||||
plt.yticks(
|
||||
np.arange(min(indexes), max(indexes), max(indexes) // 3),
|
||||
fontsize=LegendFontsize,
|
||||
rotation="vertical",
|
||||
)
|
||||
plt.xticks(
|
||||
np.arange(min(indexes), max(indexes), max(indexes) // 5),
|
||||
fontsize=LegendFontsize,
|
||||
)
|
||||
ax.scatter(indexes, cifar100_labels, marker="^", s=0.5, c="tab:green", alpha=0.8)
|
||||
ax.scatter(indexes, imagenet_labels, marker="*", s=0.5, c="tab:red", alpha=0.8)
|
||||
ax.scatter(indexes, indexes, marker="o", s=0.5, c="tab:blue", alpha=0.8)
|
||||
@@ -102,7 +115,9 @@ def visualize_sss_info(api, dataset, vis_save_dir):
|
||||
train_accs.append(info["train-accuracy"])
|
||||
test_accs.append(info["test-accuracy"])
|
||||
if dataset == "cifar10":
|
||||
info = api.get_more_info(index, "cifar10-valid", hp="90", is_random=False)
|
||||
info = api.get_more_info(
|
||||
index, "cifar10-valid", hp="90", is_random=False
|
||||
)
|
||||
valid_accs.append(info["valid-accuracy"])
|
||||
else:
|
||||
valid_accs.append(info["valid-accuracy"])
|
||||
@@ -263,7 +278,9 @@ def visualize_tss_info(api, dataset, vis_save_dir):
|
||||
train_accs.append(info["train-accuracy"])
|
||||
test_accs.append(info["test-accuracy"])
|
||||
if dataset == "cifar10":
|
||||
info = api.get_more_info(index, "cifar10-valid", hp="200", is_random=False)
|
||||
info = api.get_more_info(
|
||||
index, "cifar10-valid", hp="200", is_random=False
|
||||
)
|
||||
valid_accs.append(info["valid-accuracy"])
|
||||
else:
|
||||
valid_accs.append(info["valid-accuracy"])
|
||||
@@ -288,7 +305,9 @@ def visualize_tss_info(api, dataset, vis_save_dir):
|
||||
)
|
||||
print("{:} collect data done.".format(time_string()))
|
||||
|
||||
resnet = ["|nor_conv_3x3~0|+|none~0|nor_conv_3x3~1|+|skip_connect~0|none~1|skip_connect~2|"]
|
||||
resnet = [
|
||||
"|nor_conv_3x3~0|+|none~0|nor_conv_3x3~1|+|skip_connect~0|none~1|skip_connect~2|"
|
||||
]
|
||||
resnet_indexes = [api.query_index_by_arch(x) for x in resnet]
|
||||
largest_indexes = [
|
||||
api.query_index_by_arch(
|
||||
@@ -415,9 +434,15 @@ def visualize_rank_info(api, vis_save_dir, indicator):
|
||||
# print ('{:} start to visualize {:} information'.format(time_string(), api))
|
||||
vis_save_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
cifar010_cache_path = vis_save_dir / "{:}-cache-{:}-info.pth".format("cifar10", indicator)
|
||||
cifar100_cache_path = vis_save_dir / "{:}-cache-{:}-info.pth".format("cifar100", indicator)
|
||||
imagenet_cache_path = vis_save_dir / "{:}-cache-{:}-info.pth".format("ImageNet16-120", indicator)
|
||||
cifar010_cache_path = vis_save_dir / "{:}-cache-{:}-info.pth".format(
|
||||
"cifar10", indicator
|
||||
)
|
||||
cifar100_cache_path = vis_save_dir / "{:}-cache-{:}-info.pth".format(
|
||||
"cifar100", indicator
|
||||
)
|
||||
imagenet_cache_path = vis_save_dir / "{:}-cache-{:}-info.pth".format(
|
||||
"ImageNet16-120", indicator
|
||||
)
|
||||
cifar010_info = torch.load(cifar010_cache_path)
|
||||
cifar100_info = torch.load(cifar100_cache_path)
|
||||
imagenet_info = torch.load(imagenet_cache_path)
|
||||
@@ -452,8 +477,17 @@ def visualize_rank_info(api, vis_save_dir, indicator):
|
||||
ax.xaxis.set_ticks(np.arange(min(indexes), max(indexes), max(indexes) // 5))
|
||||
ax.scatter(indexes, labels, marker="^", s=0.5, c="tab:green", alpha=0.8)
|
||||
ax.scatter(indexes, indexes, marker="o", s=0.5, c="tab:blue", alpha=0.8)
|
||||
ax.scatter([-1], [-1], marker="^", s=100, c="tab:green", label="{:} test".format(name))
|
||||
ax.scatter([-1], [-1], marker="o", s=100, c="tab:blue", label="{:} validation".format(name))
|
||||
ax.scatter(
|
||||
[-1], [-1], marker="^", s=100, c="tab:green", label="{:} test".format(name)
|
||||
)
|
||||
ax.scatter(
|
||||
[-1],
|
||||
[-1],
|
||||
marker="o",
|
||||
s=100,
|
||||
c="tab:blue",
|
||||
label="{:} validation".format(name),
|
||||
)
|
||||
ax.legend(loc=4, fontsize=LegendFontsize)
|
||||
ax.set_xlabel("ranking on the {:} validation".format(name), fontsize=LabelSize)
|
||||
ax.set_ylabel("architecture ranking", fontsize=LabelSize)
|
||||
@@ -465,9 +499,13 @@ def visualize_rank_info(api, vis_save_dir, indicator):
|
||||
labels = get_labels(imagenet_info)
|
||||
plot_ax(labels, ax3, "ImageNet-16-120")
|
||||
|
||||
save_path = (vis_save_dir / "{:}-same-relative-rank.pdf".format(indicator)).resolve()
|
||||
save_path = (
|
||||
vis_save_dir / "{:}-same-relative-rank.pdf".format(indicator)
|
||||
).resolve()
|
||||
fig.savefig(save_path, dpi=dpi, bbox_inches="tight", format="pdf")
|
||||
save_path = (vis_save_dir / "{:}-same-relative-rank.png".format(indicator)).resolve()
|
||||
save_path = (
|
||||
vis_save_dir / "{:}-same-relative-rank.png".format(indicator)
|
||||
).resolve()
|
||||
fig.savefig(save_path, dpi=dpi, bbox_inches="tight", format="png")
|
||||
print("{:} save into {:}".format(time_string(), save_path))
|
||||
plt.close("all")
|
||||
@@ -496,9 +534,15 @@ def visualize_all_rank_info(api, vis_save_dir, indicator):
|
||||
# print ('{:} start to visualize {:} information'.format(time_string(), api))
|
||||
vis_save_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
cifar010_cache_path = vis_save_dir / "{:}-cache-{:}-info.pth".format("cifar10", indicator)
|
||||
cifar100_cache_path = vis_save_dir / "{:}-cache-{:}-info.pth".format("cifar100", indicator)
|
||||
imagenet_cache_path = vis_save_dir / "{:}-cache-{:}-info.pth".format("ImageNet16-120", indicator)
|
||||
cifar010_cache_path = vis_save_dir / "{:}-cache-{:}-info.pth".format(
|
||||
"cifar10", indicator
|
||||
)
|
||||
cifar100_cache_path = vis_save_dir / "{:}-cache-{:}-info.pth".format(
|
||||
"cifar100", indicator
|
||||
)
|
||||
imagenet_cache_path = vis_save_dir / "{:}-cache-{:}-info.pth".format(
|
||||
"ImageNet16-120", indicator
|
||||
)
|
||||
cifar010_info = torch.load(cifar010_cache_path)
|
||||
cifar100_info = torch.load(cifar100_cache_path)
|
||||
imagenet_info = torch.load(imagenet_cache_path)
|
||||
@@ -564,7 +608,9 @@ def visualize_all_rank_info(api, vis_save_dir, indicator):
|
||||
yticklabels=["C10-V", "C10-T", "C100-V", "C100-T", "I120-V", "I120-T"],
|
||||
)
|
||||
ax1.set_title("Correlation coefficient over ALL candidates")
|
||||
ax2.set_title("Correlation coefficient over candidates with accuracy > {:}%".format(acc_bar))
|
||||
ax2.set_title(
|
||||
"Correlation coefficient over candidates with accuracy > {:}%".format(acc_bar)
|
||||
)
|
||||
save_path = (vis_save_dir / "{:}-all-relative-rank.png".format(indicator)).resolve()
|
||||
fig.savefig(save_path, dpi=dpi, bbox_inches="tight", format="png")
|
||||
print("{:} save into {:}".format(time_string(), save_path))
|
||||
@@ -572,9 +618,14 @@ def visualize_all_rank_info(api, vis_save_dir, indicator):
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser(description="NATS-Bench", formatter_class=argparse.ArgumentDefaultsHelpFormatter)
|
||||
parser = argparse.ArgumentParser(
|
||||
description="NATS-Bench", formatter_class=argparse.ArgumentDefaultsHelpFormatter
|
||||
)
|
||||
parser.add_argument(
|
||||
"--save_dir", type=str, default="output/vis-nas-bench", help="Folder to save checkpoints and log."
|
||||
"--save_dir",
|
||||
type=str,
|
||||
default="output/vis-nas-bench",
|
||||
help="Folder to save checkpoints and log.",
|
||||
)
|
||||
# use for train the model
|
||||
args = parser.parse_args()
|
||||
|
Reference in New Issue
Block a user