Add int search space

This commit is contained in:
D-X-Y
2021-03-18 16:02:55 +08:00
parent ece6ac5f41
commit 63c8bb9bc8
67 changed files with 5150 additions and 1474 deletions

View File

@@ -35,9 +35,15 @@ def visualize_relative_info(api, vis_save_dir, indicator):
# print ('{:} start to visualize {:} information'.format(time_string(), api))
vis_save_dir.mkdir(parents=True, exist_ok=True)
cifar010_cache_path = vis_save_dir / "{:}-cache-{:}-info.pth".format("cifar10", indicator)
cifar100_cache_path = vis_save_dir / "{:}-cache-{:}-info.pth".format("cifar100", indicator)
imagenet_cache_path = vis_save_dir / "{:}-cache-{:}-info.pth".format("ImageNet16-120", indicator)
cifar010_cache_path = vis_save_dir / "{:}-cache-{:}-info.pth".format(
"cifar10", indicator
)
cifar100_cache_path = vis_save_dir / "{:}-cache-{:}-info.pth".format(
"cifar100", indicator
)
imagenet_cache_path = vis_save_dir / "{:}-cache-{:}-info.pth".format(
"ImageNet16-120", indicator
)
cifar010_info = torch.load(cifar010_cache_path)
cifar100_info = torch.load(cifar100_cache_path)
imagenet_info = torch.load(imagenet_cache_path)
@@ -65,8 +71,15 @@ def visualize_relative_info(api, vis_save_dir, indicator):
plt.xlim(min(indexes), max(indexes))
plt.ylim(min(indexes), max(indexes))
# plt.ylabel('y').set_rotation(30)
plt.yticks(np.arange(min(indexes), max(indexes), max(indexes) // 3), fontsize=LegendFontsize, rotation="vertical")
plt.xticks(np.arange(min(indexes), max(indexes), max(indexes) // 5), fontsize=LegendFontsize)
plt.yticks(
np.arange(min(indexes), max(indexes), max(indexes) // 3),
fontsize=LegendFontsize,
rotation="vertical",
)
plt.xticks(
np.arange(min(indexes), max(indexes), max(indexes) // 5),
fontsize=LegendFontsize,
)
ax.scatter(indexes, cifar100_labels, marker="^", s=0.5, c="tab:green", alpha=0.8)
ax.scatter(indexes, imagenet_labels, marker="*", s=0.5, c="tab:red", alpha=0.8)
ax.scatter(indexes, indexes, marker="o", s=0.5, c="tab:blue", alpha=0.8)
@@ -102,7 +115,9 @@ def visualize_sss_info(api, dataset, vis_save_dir):
train_accs.append(info["train-accuracy"])
test_accs.append(info["test-accuracy"])
if dataset == "cifar10":
info = api.get_more_info(index, "cifar10-valid", hp="90", is_random=False)
info = api.get_more_info(
index, "cifar10-valid", hp="90", is_random=False
)
valid_accs.append(info["valid-accuracy"])
else:
valid_accs.append(info["valid-accuracy"])
@@ -263,7 +278,9 @@ def visualize_tss_info(api, dataset, vis_save_dir):
train_accs.append(info["train-accuracy"])
test_accs.append(info["test-accuracy"])
if dataset == "cifar10":
info = api.get_more_info(index, "cifar10-valid", hp="200", is_random=False)
info = api.get_more_info(
index, "cifar10-valid", hp="200", is_random=False
)
valid_accs.append(info["valid-accuracy"])
else:
valid_accs.append(info["valid-accuracy"])
@@ -288,7 +305,9 @@ def visualize_tss_info(api, dataset, vis_save_dir):
)
print("{:} collect data done.".format(time_string()))
resnet = ["|nor_conv_3x3~0|+|none~0|nor_conv_3x3~1|+|skip_connect~0|none~1|skip_connect~2|"]
resnet = [
"|nor_conv_3x3~0|+|none~0|nor_conv_3x3~1|+|skip_connect~0|none~1|skip_connect~2|"
]
resnet_indexes = [api.query_index_by_arch(x) for x in resnet]
largest_indexes = [
api.query_index_by_arch(
@@ -415,9 +434,15 @@ def visualize_rank_info(api, vis_save_dir, indicator):
# print ('{:} start to visualize {:} information'.format(time_string(), api))
vis_save_dir.mkdir(parents=True, exist_ok=True)
cifar010_cache_path = vis_save_dir / "{:}-cache-{:}-info.pth".format("cifar10", indicator)
cifar100_cache_path = vis_save_dir / "{:}-cache-{:}-info.pth".format("cifar100", indicator)
imagenet_cache_path = vis_save_dir / "{:}-cache-{:}-info.pth".format("ImageNet16-120", indicator)
cifar010_cache_path = vis_save_dir / "{:}-cache-{:}-info.pth".format(
"cifar10", indicator
)
cifar100_cache_path = vis_save_dir / "{:}-cache-{:}-info.pth".format(
"cifar100", indicator
)
imagenet_cache_path = vis_save_dir / "{:}-cache-{:}-info.pth".format(
"ImageNet16-120", indicator
)
cifar010_info = torch.load(cifar010_cache_path)
cifar100_info = torch.load(cifar100_cache_path)
imagenet_info = torch.load(imagenet_cache_path)
@@ -452,8 +477,17 @@ def visualize_rank_info(api, vis_save_dir, indicator):
ax.xaxis.set_ticks(np.arange(min(indexes), max(indexes), max(indexes) // 5))
ax.scatter(indexes, labels, marker="^", s=0.5, c="tab:green", alpha=0.8)
ax.scatter(indexes, indexes, marker="o", s=0.5, c="tab:blue", alpha=0.8)
ax.scatter([-1], [-1], marker="^", s=100, c="tab:green", label="{:} test".format(name))
ax.scatter([-1], [-1], marker="o", s=100, c="tab:blue", label="{:} validation".format(name))
ax.scatter(
[-1], [-1], marker="^", s=100, c="tab:green", label="{:} test".format(name)
)
ax.scatter(
[-1],
[-1],
marker="o",
s=100,
c="tab:blue",
label="{:} validation".format(name),
)
ax.legend(loc=4, fontsize=LegendFontsize)
ax.set_xlabel("ranking on the {:} validation".format(name), fontsize=LabelSize)
ax.set_ylabel("architecture ranking", fontsize=LabelSize)
@@ -465,9 +499,13 @@ def visualize_rank_info(api, vis_save_dir, indicator):
labels = get_labels(imagenet_info)
plot_ax(labels, ax3, "ImageNet-16-120")
save_path = (vis_save_dir / "{:}-same-relative-rank.pdf".format(indicator)).resolve()
save_path = (
vis_save_dir / "{:}-same-relative-rank.pdf".format(indicator)
).resolve()
fig.savefig(save_path, dpi=dpi, bbox_inches="tight", format="pdf")
save_path = (vis_save_dir / "{:}-same-relative-rank.png".format(indicator)).resolve()
save_path = (
vis_save_dir / "{:}-same-relative-rank.png".format(indicator)
).resolve()
fig.savefig(save_path, dpi=dpi, bbox_inches="tight", format="png")
print("{:} save into {:}".format(time_string(), save_path))
plt.close("all")
@@ -496,9 +534,15 @@ def visualize_all_rank_info(api, vis_save_dir, indicator):
# print ('{:} start to visualize {:} information'.format(time_string(), api))
vis_save_dir.mkdir(parents=True, exist_ok=True)
cifar010_cache_path = vis_save_dir / "{:}-cache-{:}-info.pth".format("cifar10", indicator)
cifar100_cache_path = vis_save_dir / "{:}-cache-{:}-info.pth".format("cifar100", indicator)
imagenet_cache_path = vis_save_dir / "{:}-cache-{:}-info.pth".format("ImageNet16-120", indicator)
cifar010_cache_path = vis_save_dir / "{:}-cache-{:}-info.pth".format(
"cifar10", indicator
)
cifar100_cache_path = vis_save_dir / "{:}-cache-{:}-info.pth".format(
"cifar100", indicator
)
imagenet_cache_path = vis_save_dir / "{:}-cache-{:}-info.pth".format(
"ImageNet16-120", indicator
)
cifar010_info = torch.load(cifar010_cache_path)
cifar100_info = torch.load(cifar100_cache_path)
imagenet_info = torch.load(imagenet_cache_path)
@@ -564,7 +608,9 @@ def visualize_all_rank_info(api, vis_save_dir, indicator):
yticklabels=["C10-V", "C10-T", "C100-V", "C100-T", "I120-V", "I120-T"],
)
ax1.set_title("Correlation coefficient over ALL candidates")
ax2.set_title("Correlation coefficient over candidates with accuracy > {:}%".format(acc_bar))
ax2.set_title(
"Correlation coefficient over candidates with accuracy > {:}%".format(acc_bar)
)
save_path = (vis_save_dir / "{:}-all-relative-rank.png".format(indicator)).resolve()
fig.savefig(save_path, dpi=dpi, bbox_inches="tight", format="png")
print("{:} save into {:}".format(time_string(), save_path))
@@ -572,9 +618,14 @@ def visualize_all_rank_info(api, vis_save_dir, indicator):
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="NATS-Bench", formatter_class=argparse.ArgumentDefaultsHelpFormatter)
parser = argparse.ArgumentParser(
description="NATS-Bench", formatter_class=argparse.ArgumentDefaultsHelpFormatter
)
parser.add_argument(
"--save_dir", type=str, default="output/vis-nas-bench", help="Folder to save checkpoints and log."
"--save_dir",
type=str,
default="output/vis-nas-bench",
help="Folder to save checkpoints and log.",
)
# use for train the model
args = parser.parse_args()