Add int search space

D-X-Y
2021-03-18 16:02:55 +08:00
parent ece6ac5f41
commit 63c8bb9bc8
67 changed files with 5150 additions and 1474 deletions

View File

@@ -28,15 +28,25 @@ from utils import weight_watcher
 if __name__ == "__main__":
     parser = argparse.ArgumentParser("Analysis of NAS-Bench-201")
     parser.add_argument(
-        "--api_path", type=str, default=None, help="The path to the NAS-Bench-201 benchmark file and weight dir."
+        "--api_path",
+        type=str,
+        default=None,
+        help="The path to the NAS-Bench-201 benchmark file and weight dir.",
+    )
+    parser.add_argument(
+        "--archive_path",
+        type=str,
+        default=None,
+        help="The path to the NAS-Bench-201 weight dir.",
     )
-    parser.add_argument("--archive_path", type=str, default=None, help="The path to the NAS-Bench-201 weight dir.")
     args = parser.parse_args()
     meta_file = Path(args.api_path)
     weight_dir = Path(args.archive_path)
     assert meta_file.exists(), "invalid path for api : {:}".format(meta_file)
-    assert weight_dir.exists() and weight_dir.is_dir(), "invalid path for weight dir : {:}".format(weight_dir)
+    assert (
+        weight_dir.exists() and weight_dir.is_dir()
+    ), "invalid path for weight dir : {:}".format(weight_dir)
     api = NASBench201API(meta_file, verbose=True)
@@ -46,7 +56,9 @@ if __name__ == "__main__":
data = "cifar10" # query the info from CIFAR-10
config = api.get_net_config(arch_index, data)
net = get_cell_based_tiny_net(config)
meta_info = api.query_meta_info_by_index(arch_index, hp="200") # all info about this architecture
meta_info = api.query_meta_info_by_index(
arch_index, hp="200"
) # all info about this architecture
params = meta_info.get_net_param(data, 888)
net.load_state_dict(params)
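
For orientation, the query pattern this file reformats can be run end to end; a minimal sketch, assuming the NAS-Bench-201 benchmark file and its weight archive have been downloaded (the path below is a placeholder):

from pathlib import Path

from nas_201_api import NASBench201API
from models import get_cell_based_tiny_net  # repo-local helper

meta_file = Path("NAS-Bench-201-v1_1-096897.pth")  # placeholder path
api = NASBench201API(meta_file, verbose=True)

arch_index = 0  # any index in [0, len(api))
config = api.get_net_config(arch_index, "cifar10")
net = get_cell_based_tiny_net(config)

# hp="200" selects the 200-epoch statistics; 888 picks the seed-888 trial
meta_info = api.query_meta_info_by_index(arch_index, hp="200")
params = meta_info.get_net_param("cifar10", 888)  # needs the weight archive
net.load_state_dict(params)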

View File

@@ -69,7 +69,13 @@ def plot(filename):
         for xin in range(i):
             op_i = random.randint(0, len(OPS) - 1)
             # g.edge(str(xin), str(i), label=OPS[op_i], fillcolor=COLORS[op_i])
-            g.edge(str(xin), str(i), label=OPS[op_i], color=COLORS[op_i], fillcolor=COLORS[op_i])
+            g.edge(
+                str(xin),
+                str(i),
+                label=OPS[op_i],
+                color=COLORS[op_i],
+                fillcolor=COLORS[op_i],
+            )
     # import pdb; pdb.set_trace()
     g.render(filename, cleanup=True, view=False)
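
The reformatted g.edge call draws one randomly labeled edge per input of each cell node; a self-contained sketch of the same graphviz pattern (OPS and COLORS below are hypothetical stand-ins for the script's own lists):

import random
from graphviz import Digraph

OPS = ["skip_connect", "nor_conv_3x3"]  # hypothetical op names
COLORS = ["chartreuse", "cyan"]  # one color per op

g = Digraph(format="png")
for i in range(1, 4):  # nodes 1..3, each connected to all earlier nodes
    for xin in range(i):
        op_i = random.randint(0, len(OPS) - 1)
        g.edge(str(xin), str(i), label=OPS[op_i], color=COLORS[op_i], fillcolor=COLORS[op_i])
g.render("toy-cell", cleanup=True, view=False)  # writes toy-cell.png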
@@ -88,7 +94,9 @@ def test_auto_grad():
     net = Net(10)
     inputs = torch.rand(256, 10)
     loss = net(inputs)
-    first_order_grads = torch.autograd.grad(loss, net.parameters(), retain_graph=True, create_graph=True)
+    first_order_grads = torch.autograd.grad(
+        loss, net.parameters(), retain_graph=True, create_graph=True
+    )
     first_order_grads = torch.cat([x.view(-1) for x in first_order_grads])
     second_order_grads = []
     for grads in first_order_grads:
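
The loop body falls outside this hunk, but the pattern being reformatted is the standard second-order trick: create_graph=True makes the first-order gradients differentiable, so each scalar entry can be differentiated again. A self-contained sketch with a toy model:

import torch

net = torch.nn.Linear(10, 1)
inputs = torch.rand(256, 10)
loss = net(inputs).sum()

first_order_grads = torch.autograd.grad(
    loss, net.parameters(), retain_graph=True, create_graph=True
)
flat_grads = torch.cat([x.view(-1) for x in first_order_grads])

second_order_grads = []
for g in flat_grads:  # one Hessian row per scalar gradient entry
    row = torch.autograd.grad(g, net.parameters(), retain_graph=True)
    second_order_grads.append(torch.cat([x.view(-1) for x in row]))
hessian = torch.stack(second_order_grads)  # [num_params, num_params]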
@@ -108,9 +116,15 @@ def test_one_shot_model(ckpath, use_train):
print("ckpath : {:}".format(ckpath))
ckp = torch.load(ckpath)
xargs = ckp["args"]
train_data, valid_data, xshape, class_num = get_datasets(xargs.dataset, xargs.data_path, -1)
train_data, valid_data, xshape, class_num = get_datasets(
xargs.dataset, xargs.data_path, -1
)
# config = load_config(xargs.config_path, {'class_num': class_num, 'xshape': xshape}, None)
config = load_config("./configs/nas-benchmark/algos/DARTS.config", {"class_num": class_num, "xshape": xshape}, None)
config = load_config(
"./configs/nas-benchmark/algos/DARTS.config",
{"class_num": class_num, "xshape": xshape},
None,
)
if xargs.dataset == "cifar10":
cifar_split = load_config("configs/nas-benchmark/cifar-split.txt", None, None)
xvalid_data = deepcopy(train_data)
@@ -142,7 +156,9 @@ def test_one_shot_model(ckpath, use_train):
     search_model.load_state_dict(ckp["search_model"])
     search_model = search_model.cuda()
     api = API("/home/dxy/.torch/NAS-Bench-201-v1_0-e61699.pth")
-    archs, probs, accuracies = evaluate_one_shot(search_model, valid_loader, api, use_train)
+    archs, probs, accuracies = evaluate_one_shot(
+        search_model, valid_loader, api, use_train
+    )


 if __name__ == "__main__":

View File

@@ -53,8 +53,12 @@ def evaluate(api, weight_dir, data: str):
         # compute the weight watcher results
         config = api.get_net_config(arch_index, data)
         net = get_cell_based_tiny_net(config)
-        meta_info = api.query_meta_info_by_index(arch_index, hp="200" if api.search_space_name == "topology" else "90")
-        params = meta_info.get_net_param(data, 888 if api.search_space_name == "topology" else 777)
+        meta_info = api.query_meta_info_by_index(
+            arch_index, hp="200" if api.search_space_name == "topology" else "90"
+        )
+        params = meta_info.get_net_param(
+            data, 888 if api.search_space_name == "topology" else 777
+        )
         with torch.no_grad():
             net.load_state_dict(params)
             _, summary = weight_watcher.analyze(net, alphas=False)
@@ -73,7 +77,10 @@ def evaluate(api, weight_dir, data: str):
             norms.append(cur_norm)
         # query the accuracy
         info = meta_info.get_metrics(
-            data, "ori-test", iepoch=None, is_random=888 if api.search_space_name == "topology" else 777
+            data,
+            "ori-test",
+            iepoch=None,
+            is_random=888 if api.search_space_name == "topology" else 777,
         )
         accuracies.append(info["accuracy"])
         del net, meta_info
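
For reference, the metric query being wrapped here can be exercised on its own; a minimal sketch for the topology search space, reusing the hp="200" and seed-888 arguments from this diff:

meta_info = api.query_meta_info_by_index(arch_index, hp="200")
info = meta_info.get_metrics("cifar10", "ori-test", iepoch=None, is_random=888)
print(info["accuracy"])  # final test accuracy of the seed-888 trial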
@@ -98,7 +105,11 @@ def main(search_space, meta_file: str, weight_dir, save_dir, xdata):
         for hp in hps:
             nums = api.statistics(data, hp=hp)
             total = sum([k * v for k, v in nums.items()])
-            print("Using {:3s} epochs, trained on {:20s} : {:} trials in total ({:}).".format(hp, data, total, nums))
+            print(
+                "Using {:3s} epochs, trained on {:20s} : {:} trials in total ({:}).".format(
+                    hp, data, total, nums
+                )
+            )
     print(time_string() + " " + "=" * 50)
     norms, accuracies = evaluate(api, weight_dir, xdata)
@@ -120,8 +131,15 @@ def main(search_space, meta_file: str, weight_dir, save_dir, xdata):
     plt.xlim(min(indexes), max(indexes))
     plt.ylim(min(indexes), max(indexes))
     # plt.ylabel('y').set_rotation(30)
-    plt.yticks(np.arange(min(indexes), max(indexes), max(indexes) // 3), fontsize=LegendFontsize, rotation="vertical")
-    plt.xticks(np.arange(min(indexes), max(indexes), max(indexes) // 5), fontsize=LegendFontsize)
+    plt.yticks(
+        np.arange(min(indexes), max(indexes), max(indexes) // 3),
+        fontsize=LegendFontsize,
+        rotation="vertical",
+    )
+    plt.xticks(
+        np.arange(min(indexes), max(indexes), max(indexes) // 5),
+        fontsize=LegendFontsize,
+    )
     ax.scatter(indexes, labels, marker="*", s=0.5, c="tab:red", alpha=0.8)
     ax.scatter(indexes, indexes, marker="o", s=0.5, c="tab:blue", alpha=0.8)
     ax.scatter([-1], [-1], marker="o", s=100, c="tab:blue", label="Test accuracy")
@@ -129,7 +147,9 @@ def main(search_space, meta_file: str, weight_dir, save_dir, xdata):
     plt.grid(zorder=0)
     ax.set_axisbelow(True)
     plt.legend(loc=0, fontsize=LegendFontsize)
-    ax.set_xlabel("architecture ranking sorted by the test accuracy ", fontsize=LabelSize)
+    ax.set_xlabel(
+        "architecture ranking sorted by the test accuracy ", fontsize=LabelSize
+    )
     ax.set_ylabel("architecture ranking computed by weight watcher", fontsize=LabelSize)
     save_path = (save_dir / "{:}-{:}-test-ww.pdf".format(search_space, xdata)).resolve()
     fig.savefig(save_path, dpi=dpi, bbox_inches="tight", format="pdf")
@@ -148,9 +168,18 @@ if __name__ == "__main__":
default="./output/vis-nas-bench/",
help="The base-name of folder to save checkpoints and log.",
)
parser.add_argument("--search_space", type=str, default=None, choices=["tss", "sss"], help="The search space.")
parser.add_argument(
"--base_path", type=str, default=None, help="The path to the NAS-Bench-201 benchmark file and weight dir."
"--search_space",
type=str,
default=None,
choices=["tss", "sss"],
help="The search space.",
)
parser.add_argument(
"--base_path",
type=str,
default=None,
help="The path to the NAS-Bench-201 benchmark file and weight dir.",
)
parser.add_argument("--dataset", type=str, default=None, help=".")
args = parser.parse_args()
@@ -160,6 +189,8 @@ if __name__ == "__main__":
     meta_file = Path(args.base_path + ".pth")
     weight_dir = Path(args.base_path + "-full")
     assert meta_file.exists(), "invalid path for api : {:}".format(meta_file)
-    assert weight_dir.exists() and weight_dir.is_dir(), "invalid path for weight dir : {:}".format(weight_dir)
+    assert (
+        weight_dir.exists() and weight_dir.is_dir()
+    ), "invalid path for weight dir : {:}".format(weight_dir)
     main(args.search_space, str(meta_file), weight_dir, save_dir, args.dataset)

View File

@@ -42,7 +42,9 @@ def fetch_data(root_dir="./output/search", search_space="tss", dataset=None):
     for alg, path in alg2path.items():
         data = torch.load(path)
         for index, info in data.items():
-            info["time_w_arch"] = [(x, y) for x, y in zip(info["all_total_times"], info["all_archs"])]
+            info["time_w_arch"] = [
+                (x, y) for x, y in zip(info["all_total_times"], info["all_archs"])
+            ]
             for j, arch in enumerate(info["all_archs"]):
                 assert arch != -1, "invalid arch from {:} {:} {:} ({:}, {:})".format(
                     alg, search_space, dataset, index, j
@@ -57,12 +59,16 @@ def query_performance(api, data, dataset, ticket):
         time_w_arch = sorted(info["time_w_arch"], key=lambda x: abs(x[0] - ticket))
         time_a, arch_a = time_w_arch[0]
         time_b, arch_b = time_w_arch[1]
-        info_a = api.get_more_info(arch_a, dataset=dataset, hp=90 if is_size_space else 200, is_random=False)
-        info_b = api.get_more_info(arch_b, dataset=dataset, hp=90 if is_size_space else 200, is_random=False)
+        info_a = api.get_more_info(
+            arch_a, dataset=dataset, hp=90 if is_size_space else 200, is_random=False
+        )
+        info_b = api.get_more_info(
+            arch_b, dataset=dataset, hp=90 if is_size_space else 200, is_random=False
+        )
         accuracy_a, accuracy_b = info_a["test-accuracy"], info_b["test-accuracy"]
-        interplate = (time_b - ticket) / (time_b - time_a) * accuracy_a + (ticket - time_a) / (
-            time_b - time_a
-        ) * accuracy_b
+        interplate = (time_b - ticket) / (time_b - time_a) * accuracy_a + (
+            ticket - time_a
+        ) / (time_b - time_a) * accuracy_b
         results.append(interplate)
     return sum(results) / len(results)
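
The interplate expression above is plain linear interpolation between the two recorded (time, accuracy) points closest to the queried ticket; a standalone check with hypothetical numbers:

time_a, accuracy_a = 100.0, 90.0  # hypothetical nearest point
time_b, accuracy_b = 200.0, 92.0  # hypothetical second-nearest point
ticket = 150.0

interplate = (time_b - ticket) / (time_b - time_a) * accuracy_a + (
    ticket - time_a
) / (time_b - time_a) * accuracy_b
print(interplate)  # 91.0, halfway between the two accuracies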
@@ -85,7 +91,11 @@ y_max_s = {
("ImageNet16-120", "sss"): 46,
}
name2label = {"cifar10": "CIFAR-10", "cifar100": "CIFAR-100", "ImageNet16-120": "ImageNet-16-120"}
name2label = {
"cifar10": "CIFAR-10",
"cifar100": "CIFAR-100",
"ImageNet16-120": "ImageNet-16-120",
}
def visualize_curve(api, vis_save_dir, search_space, max_time):
@@ -100,7 +110,9 @@ def visualize_curve(api, vis_save_dir, search_space, max_time):
         alg2data = fetch_data(search_space=search_space, dataset=dataset)
         alg2accuracies = OrderedDict()
         total_tickets = 150
-        time_tickets = [float(i) / total_tickets * max_time for i in range(total_tickets)]
+        time_tickets = [
+            float(i) / total_tickets * max_time for i in range(total_tickets)
+        ]
         colors = ["b", "g", "c", "m", "y"]
         ax.set_xlim(0, 200)
         ax.set_ylim(y_min_s[(dataset, search_space)], y_max_s[(dataset, search_space)])
@@ -111,10 +123,20 @@ def visualize_curve(api, vis_save_dir, search_space, max_time):
                 accuracy = query_performance(api, data, dataset, ticket)
                 accuracies.append(accuracy)
             alg2accuracies[alg] = accuracies
-            ax.plot([x / 100 for x in time_tickets], accuracies, c=colors[idx], label="{:}".format(alg))
+            ax.plot(
+                [x / 100 for x in time_tickets],
+                accuracies,
+                c=colors[idx],
+                label="{:}".format(alg),
+            )
             ax.set_xlabel("Estimated wall-clock time (1e2 seconds)", fontsize=LabelSize)
-            ax.set_ylabel("Test accuracy on {:}".format(name2label[dataset]), fontsize=LabelSize)
-            ax.set_title("Searching results on {:}".format(name2label[dataset]), fontsize=LabelSize + 4)
+            ax.set_ylabel(
+                "Test accuracy on {:}".format(name2label[dataset]), fontsize=LabelSize
+            )
+            ax.set_title(
+                "Searching results on {:}".format(name2label[dataset]),
+                fontsize=LabelSize + 4,
+            )
         ax.legend(loc=4, fontsize=LegendFontsize)

     fig, axs = plt.subplots(1, 3, figsize=figsize)
@@ -129,12 +151,25 @@ def visualize_curve(api, vis_save_dir, search_space, max_time):
 if __name__ == "__main__":
-    parser = argparse.ArgumentParser(description="NAS-Bench-X", formatter_class=argparse.ArgumentDefaultsHelpFormatter)
-    parser.add_argument(
-        "--save_dir", type=str, default="output/vis-nas-bench/nas-algos", help="Folder to save checkpoints and log."
+    parser = argparse.ArgumentParser(
+        description="NAS-Bench-X",
+        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
+    )
+    parser.add_argument(
+        "--save_dir",
+        type=str,
+        default="output/vis-nas-bench/nas-algos",
+        help="Folder to save checkpoints and log.",
+    )
+    parser.add_argument(
+        "--search_space",
+        type=str,
+        choices=["tss", "sss"],
+        help="Choose the search space.",
+    )
+    parser.add_argument(
+        "--max_time", type=float, default=20000, help="The maximum time budget."
     )
-    parser.add_argument("--search_space", type=str, choices=["tss", "sss"], help="Choose the search space.")
-    parser.add_argument("--max_time", type=float, default=20000, help="The maximum time budget.")
     args = parser.parse_args()
     save_dir = Path(args.save_dir)

View File

@@ -29,7 +29,9 @@ from log_utils import time_string
 # def fetch_data(root_dir='./output/search', search_space='tss', dataset=None, suffix='-WARMNone'):
-def fetch_data(root_dir="./output/search", search_space="tss", dataset=None, suffix="-WARM0.3"):
+def fetch_data(
+    root_dir="./output/search", search_space="tss", dataset=None, suffix="-WARM0.3"
+):
     ss_dir = "{:}-{:}".format(root_dir, search_space)
     alg2name, alg2path = OrderedDict(), OrderedDict()
     seeds = [777, 888, 999]
@@ -45,8 +47,12 @@ def fetch_data(root_dir="./output/search", search_space="tss", dataset=None, suffix="-WARM0.3"):
     # alg2name['TAS'] = 'tas-affine0_BN0{:}'.format(suffix)
     # alg2name['FBNetV2'] = 'fbv2-affine0_BN0{:}'.format(suffix)
     # alg2name['TuNAS'] = 'tunas-affine0_BN0{:}'.format(suffix)
-    alg2name["channel-wise interpolation"] = "tas-affine0_BN0-AWD0.001{:}".format(suffix)
-    alg2name["masking + Gumbel-Softmax"] = "mask_gumbel-affine0_BN0-AWD0.001{:}".format(suffix)
+    alg2name["channel-wise interpolation"] = "tas-affine0_BN0-AWD0.001{:}".format(
+        suffix
+    )
+    alg2name[
+        "masking + Gumbel-Softmax"
+    ] = "mask_gumbel-affine0_BN0-AWD0.001{:}".format(suffix)
     alg2name["masking + sampling"] = "mask_rl-affine0_BN0-AWD0.0{:}".format(suffix)
     for alg, name in alg2name.items():
         alg2path[alg] = os.path.join(ss_dir, dataset, name, "seed-{:}-last-info.pth")
@@ -86,7 +92,11 @@ y_max_s = {
("ImageNet16-120", "sss"): 46,
}
name2label = {"cifar10": "CIFAR-10", "cifar100": "CIFAR-100", "ImageNet16-120": "ImageNet-16-120"}
name2label = {
"cifar10": "CIFAR-10",
"cifar100": "CIFAR-100",
"ImageNet16-120": "ImageNet-16-120",
}
def visualize_curve(api, vis_save_dir, search_space):
@@ -111,10 +121,17 @@ def visualize_curve(api, vis_save_dir, search_space):
             try:
                 structures, accs = [_[iepoch - 1] for _ in data], []
             except:
-                raise ValueError("This alg {:} on {:} has invalid checkpoints.".format(alg, dataset))
+                raise ValueError(
+                    "This alg {:} on {:} has invalid checkpoints.".format(
+                        alg, dataset
+                    )
+                )
             for structure in structures:
                 info = api.get_more_info(
-                    structure, dataset=dataset, hp=90 if api.search_space_name == "size" else 200, is_random=False
+                    structure,
+                    dataset=dataset,
+                    hp=90 if api.search_space_name == "size" else 200,
+                    is_random=False,
                 )
                 accs.append(info["test-accuracy"])
             accuracies.append(sum(accs) / len(accs))
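
get_more_info is the workhorse query in these plots; a minimal sketch of a single call as used here (hp=90 belongs to the size space "sss", hp=200 to the topology space "tss"; is_random=False should return seed-averaged rather than per-trial statistics):

info = api.get_more_info(0, dataset="cifar100", hp=90, is_random=False)  # 0: any arch index
print(info["test-accuracy"])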
@@ -122,8 +139,13 @@ def visualize_curve(api, vis_save_dir, search_space):
             alg2accuracies[alg] = accuracies
             ax.plot(xs, accuracies, c=colors[idx], label="{:}".format(alg))
             ax.set_xlabel("The searching epoch", fontsize=LabelSize)
-            ax.set_ylabel("Test accuracy on {:}".format(name2label[dataset]), fontsize=LabelSize)
-            ax.set_title("Searching results on {:}".format(name2label[dataset]), fontsize=LabelSize + 4)
+            ax.set_ylabel(
+                "Test accuracy on {:}".format(name2label[dataset]), fontsize=LabelSize
+            )
+            ax.set_title(
+                "Searching results on {:}".format(name2label[dataset]),
+                fontsize=LabelSize + 4,
+            )
         ax.legend(loc=4, fontsize=LegendFontsize)

     fig, axs = plt.subplots(1, 3, figsize=figsize)
@@ -138,12 +160,22 @@ def visualize_curve(api, vis_save_dir, search_space):
 if __name__ == "__main__":
-    parser = argparse.ArgumentParser(description="NAS-Bench-X", formatter_class=argparse.ArgumentDefaultsHelpFormatter)
+    parser = argparse.ArgumentParser(
+        description="NAS-Bench-X",
+        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
+    )
     parser.add_argument(
-        "--save_dir", type=str, default="output/vis-nas-bench/nas-algos", help="Folder to save checkpoints and log."
+        "--save_dir",
+        type=str,
+        default="output/vis-nas-bench/nas-algos",
+        help="Folder to save checkpoints and log.",
     )
     parser.add_argument(
-        "--search_space", type=str, default="tss", choices=["tss", "sss"], help="Choose the search space."
+        "--search_space",
+        type=str,
+        default="tss",
+        choices=["tss", "sss"],
+        help="Choose the search space.",
     )
     args = parser.parse_args()

View File

@@ -33,9 +33,15 @@ def visualize_info(api, vis_save_dir, indicator):
     # print ('{:} start to visualize {:} information'.format(time_string(), api))
     vis_save_dir.mkdir(parents=True, exist_ok=True)

-    cifar010_cache_path = vis_save_dir / "{:}-cache-{:}-info.pth".format("cifar10", indicator)
-    cifar100_cache_path = vis_save_dir / "{:}-cache-{:}-info.pth".format("cifar100", indicator)
-    imagenet_cache_path = vis_save_dir / "{:}-cache-{:}-info.pth".format("ImageNet16-120", indicator)
+    cifar010_cache_path = vis_save_dir / "{:}-cache-{:}-info.pth".format(
+        "cifar10", indicator
+    )
+    cifar100_cache_path = vis_save_dir / "{:}-cache-{:}-info.pth".format(
+        "cifar100", indicator
+    )
+    imagenet_cache_path = vis_save_dir / "{:}-cache-{:}-info.pth".format(
+        "ImageNet16-120", indicator
+    )
     cifar010_info = torch.load(cifar010_cache_path)
     cifar100_info = torch.load(cifar100_cache_path)
     imagenet_info = torch.load(imagenet_cache_path)
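
The three cache files loaded above are ordinary torch.save dumps written by an earlier pass of this script; a sketch of the round trip (the payload below is hypothetical):

from pathlib import Path
import torch

vis_save_dir = Path("output/vis-nas-bench")  # placeholder
vis_save_dir.mkdir(parents=True, exist_ok=True)
cache_path = vis_save_dir / "{:}-cache-{:}-info.pth".format("cifar10", "params")

torch.save({"test_accs": [], "valid_accs": []}, cache_path)  # hypothetical payload
info = torch.load(cache_path)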
@@ -63,8 +69,15 @@ def visualize_info(api, vis_save_dir, indicator):
     plt.xlim(min(indexes), max(indexes))
     plt.ylim(min(indexes), max(indexes))
     # plt.ylabel('y').set_rotation(30)
-    plt.yticks(np.arange(min(indexes), max(indexes), max(indexes) // 3), fontsize=LegendFontsize, rotation="vertical")
-    plt.xticks(np.arange(min(indexes), max(indexes), max(indexes) // 5), fontsize=LegendFontsize)
+    plt.yticks(
+        np.arange(min(indexes), max(indexes), max(indexes) // 3),
+        fontsize=LegendFontsize,
+        rotation="vertical",
+    )
+    plt.xticks(
+        np.arange(min(indexes), max(indexes), max(indexes) // 5),
+        fontsize=LegendFontsize,
+    )
     ax.scatter(indexes, cifar100_labels, marker="^", s=0.5, c="tab:green", alpha=0.8)
     ax.scatter(indexes, imagenet_labels, marker="*", s=0.5, c="tab:red", alpha=0.8)
     ax.scatter(indexes, indexes, marker="o", s=0.5, c="tab:blue", alpha=0.8)
@@ -100,7 +113,9 @@ def visualize_sss_info(api, dataset, vis_save_dir):
             train_accs.append(info["train-accuracy"])
             test_accs.append(info["test-accuracy"])
             if dataset == "cifar10":
-                info = api.get_more_info(index, "cifar10-valid", hp="90", is_random=False)
+                info = api.get_more_info(
+                    index, "cifar10-valid", hp="90", is_random=False
+                )
                 valid_accs.append(info["valid-accuracy"])
             else:
                 valid_accs.append(info["valid-accuracy"])
@@ -272,7 +287,9 @@ def visualize_tss_info(api, dataset, vis_save_dir):
             train_accs.append(info["train-accuracy"])
             test_accs.append(info["test-accuracy"])
             if dataset == "cifar10":
-                info = api.get_more_info(index, "cifar10-valid", hp="200", is_random=False)
+                info = api.get_more_info(
+                    index, "cifar10-valid", hp="200", is_random=False
+                )
                 valid_accs.append(info["valid-accuracy"])
             else:
                 valid_accs.append(info["valid-accuracy"])
@@ -297,7 +314,9 @@ def visualize_tss_info(api, dataset, vis_save_dir):
     )
     print("{:} collect data done.".format(time_string()))

-    resnet = ["|nor_conv_3x3~0|+|none~0|nor_conv_3x3~1|+|skip_connect~0|none~1|skip_connect~2|"]
+    resnet = [
+        "|nor_conv_3x3~0|+|none~0|nor_conv_3x3~1|+|skip_connect~0|none~1|skip_connect~2|"
+    ]
     resnet_indexes = [api.query_index_by_arch(x) for x in resnet]
     largest_indexes = [
         api.query_index_by_arch(
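
The resnet string above is the NAS-Bench-201 cell encoding: each |op~k| token applies operation op to the output of node k, and + separates the cell's internal nodes; query_index_by_arch maps such a string back to its benchmark index. A minimal sketch:

arch = "|nor_conv_3x3~0|+|none~0|nor_conv_3x3~1|+|skip_connect~0|none~1|skip_connect~2|"
index = api.query_index_by_arch(arch)  # -1 if the string is not in the benchmark
config = api.get_net_config(index, "cifar10")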
@@ -429,9 +448,15 @@ def visualize_rank_info(api, vis_save_dir, indicator):
     # print ('{:} start to visualize {:} information'.format(time_string(), api))
     vis_save_dir.mkdir(parents=True, exist_ok=True)

-    cifar010_cache_path = vis_save_dir / "{:}-cache-{:}-info.pth".format("cifar10", indicator)
-    cifar100_cache_path = vis_save_dir / "{:}-cache-{:}-info.pth".format("cifar100", indicator)
-    imagenet_cache_path = vis_save_dir / "{:}-cache-{:}-info.pth".format("ImageNet16-120", indicator)
+    cifar010_cache_path = vis_save_dir / "{:}-cache-{:}-info.pth".format(
+        "cifar10", indicator
+    )
+    cifar100_cache_path = vis_save_dir / "{:}-cache-{:}-info.pth".format(
+        "cifar100", indicator
+    )
+    imagenet_cache_path = vis_save_dir / "{:}-cache-{:}-info.pth".format(
+        "ImageNet16-120", indicator
+    )
     cifar010_info = torch.load(cifar010_cache_path)
     cifar100_info = torch.load(cifar100_cache_path)
     imagenet_info = torch.load(imagenet_cache_path)
@@ -466,8 +491,17 @@ def visualize_rank_info(api, vis_save_dir, indicator):
     ax.xaxis.set_ticks(np.arange(min(indexes), max(indexes), max(indexes) // 5))
     ax.scatter(indexes, labels, marker="^", s=0.5, c="tab:green", alpha=0.8)
     ax.scatter(indexes, indexes, marker="o", s=0.5, c="tab:blue", alpha=0.8)
-    ax.scatter([-1], [-1], marker="^", s=100, c="tab:green", label="{:} test".format(name))
-    ax.scatter([-1], [-1], marker="o", s=100, c="tab:blue", label="{:} validation".format(name))
+    ax.scatter(
+        [-1], [-1], marker="^", s=100, c="tab:green", label="{:} test".format(name)
+    )
+    ax.scatter(
+        [-1],
+        [-1],
+        marker="o",
+        s=100,
+        c="tab:blue",
+        label="{:} validation".format(name),
+    )
     ax.legend(loc=4, fontsize=LegendFontsize)
     ax.set_xlabel("ranking on the {:} validation".format(name), fontsize=LabelSize)
     ax.set_ylabel("architecture ranking", fontsize=LabelSize)
@@ -479,9 +513,13 @@ def visualize_rank_info(api, vis_save_dir, indicator):
     labels = get_labels(imagenet_info)
     plot_ax(labels, ax3, "ImageNet-16-120")

-    save_path = (vis_save_dir / "{:}-same-relative-rank.pdf".format(indicator)).resolve()
+    save_path = (
+        vis_save_dir / "{:}-same-relative-rank.pdf".format(indicator)
+    ).resolve()
     fig.savefig(save_path, dpi=dpi, bbox_inches="tight", format="pdf")
-    save_path = (vis_save_dir / "{:}-same-relative-rank.png".format(indicator)).resolve()
+    save_path = (
+        vis_save_dir / "{:}-same-relative-rank.png".format(indicator)
+    ).resolve()
     fig.savefig(save_path, dpi=dpi, bbox_inches="tight", format="png")
     print("{:} save into {:}".format(time_string(), save_path))
     plt.close("all")
@@ -502,9 +540,15 @@ def visualize_all_rank_info(api, vis_save_dir, indicator):
     # print ('{:} start to visualize {:} information'.format(time_string(), api))
     vis_save_dir.mkdir(parents=True, exist_ok=True)

-    cifar010_cache_path = vis_save_dir / "{:}-cache-{:}-info.pth".format("cifar10", indicator)
-    cifar100_cache_path = vis_save_dir / "{:}-cache-{:}-info.pth".format("cifar100", indicator)
-    imagenet_cache_path = vis_save_dir / "{:}-cache-{:}-info.pth".format("ImageNet16-120", indicator)
+    cifar010_cache_path = vis_save_dir / "{:}-cache-{:}-info.pth".format(
+        "cifar10", indicator
+    )
+    cifar100_cache_path = vis_save_dir / "{:}-cache-{:}-info.pth".format(
+        "cifar100", indicator
+    )
+    imagenet_cache_path = vis_save_dir / "{:}-cache-{:}-info.pth".format(
+        "ImageNet16-120", indicator
+    )
     cifar010_info = torch.load(cifar010_cache_path)
     cifar100_info = torch.load(cifar100_cache_path)
     imagenet_info = torch.load(imagenet_cache_path)
@@ -570,7 +614,9 @@ def visualize_all_rank_info(api, vis_save_dir, indicator):
         yticklabels=["C10-V", "C10-T", "C100-V", "C100-T", "I120-V", "I120-T"],
     )
     ax1.set_title("Correlation coefficient over ALL candidates")
-    ax2.set_title("Correlation coefficient over candidates with accuracy > {:}%".format(acc_bar))
+    ax2.set_title(
+        "Correlation coefficient over candidates with accuracy > {:}%".format(acc_bar)
+    )
     save_path = (vis_save_dir / "{:}-all-relative-rank.png".format(indicator)).resolve()
     fig.savefig(save_path, dpi=dpi, bbox_inches="tight", format="png")
     print("{:} save into {:}".format(time_string(), save_path))
@@ -578,9 +624,15 @@ def visualize_all_rank_info(api, vis_save_dir, indicator):
 if __name__ == "__main__":
-    parser = argparse.ArgumentParser(description="NAS-Bench-X", formatter_class=argparse.ArgumentDefaultsHelpFormatter)
+    parser = argparse.ArgumentParser(
+        description="NAS-Bench-X",
+        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
+    )
     parser.add_argument(
-        "--save_dir", type=str, default="output/vis-nas-bench", help="Folder to save checkpoints and log."
+        "--save_dir",
+        type=str,
+        default="output/vis-nas-bench",
+        help="Folder to save checkpoints and log.",
     )
     # use for train the model
     args = parser.parse_args()