Add int search space
This commit is contained in:
@@ -28,15 +28,25 @@ from utils import weight_watcher
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser("Analysis of NAS-Bench-201")
|
||||
parser.add_argument(
|
||||
"--api_path", type=str, default=None, help="The path to the NAS-Bench-201 benchmark file and weight dir."
|
||||
"--api_path",
|
||||
type=str,
|
||||
default=None,
|
||||
help="The path to the NAS-Bench-201 benchmark file and weight dir.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--archive_path",
|
||||
type=str,
|
||||
default=None,
|
||||
help="The path to the NAS-Bench-201 weight dir.",
|
||||
)
|
||||
parser.add_argument("--archive_path", type=str, default=None, help="The path to the NAS-Bench-201 weight dir.")
|
||||
args = parser.parse_args()
|
||||
|
||||
meta_file = Path(args.api_path)
|
||||
weight_dir = Path(args.archive_path)
|
||||
assert meta_file.exists(), "invalid path for api : {:}".format(meta_file)
|
||||
assert weight_dir.exists() and weight_dir.is_dir(), "invalid path for weight dir : {:}".format(weight_dir)
|
||||
assert (
|
||||
weight_dir.exists() and weight_dir.is_dir()
|
||||
), "invalid path for weight dir : {:}".format(weight_dir)
|
||||
|
||||
api = NASBench201API(meta_file, verbose=True)
|
||||
|
||||
@@ -46,7 +56,9 @@ if __name__ == "__main__":
|
||||
data = "cifar10" # query the info from CIFAR-10
|
||||
config = api.get_net_config(arch_index, data)
|
||||
net = get_cell_based_tiny_net(config)
|
||||
meta_info = api.query_meta_info_by_index(arch_index, hp="200") # all info about this architecture
|
||||
meta_info = api.query_meta_info_by_index(
|
||||
arch_index, hp="200"
|
||||
) # all info about this architecture
|
||||
params = meta_info.get_net_param(data, 888)
|
||||
|
||||
net.load_state_dict(params)
|
||||
|
@@ -69,7 +69,13 @@ def plot(filename):
|
||||
for xin in range(i):
|
||||
op_i = random.randint(0, len(OPS) - 1)
|
||||
# g.edge(str(xin), str(i), label=OPS[op_i], fillcolor=COLORS[op_i])
|
||||
g.edge(str(xin), str(i), label=OPS[op_i], color=COLORS[op_i], fillcolor=COLORS[op_i])
|
||||
g.edge(
|
||||
str(xin),
|
||||
str(i),
|
||||
label=OPS[op_i],
|
||||
color=COLORS[op_i],
|
||||
fillcolor=COLORS[op_i],
|
||||
)
|
||||
# import pdb; pdb.set_trace()
|
||||
g.render(filename, cleanup=True, view=False)
|
||||
|
||||
@@ -88,7 +94,9 @@ def test_auto_grad():
|
||||
net = Net(10)
|
||||
inputs = torch.rand(256, 10)
|
||||
loss = net(inputs)
|
||||
first_order_grads = torch.autograd.grad(loss, net.parameters(), retain_graph=True, create_graph=True)
|
||||
first_order_grads = torch.autograd.grad(
|
||||
loss, net.parameters(), retain_graph=True, create_graph=True
|
||||
)
|
||||
first_order_grads = torch.cat([x.view(-1) for x in first_order_grads])
|
||||
second_order_grads = []
|
||||
for grads in first_order_grads:
|
||||
@@ -108,9 +116,15 @@ def test_one_shot_model(ckpath, use_train):
|
||||
print("ckpath : {:}".format(ckpath))
|
||||
ckp = torch.load(ckpath)
|
||||
xargs = ckp["args"]
|
||||
train_data, valid_data, xshape, class_num = get_datasets(xargs.dataset, xargs.data_path, -1)
|
||||
train_data, valid_data, xshape, class_num = get_datasets(
|
||||
xargs.dataset, xargs.data_path, -1
|
||||
)
|
||||
# config = load_config(xargs.config_path, {'class_num': class_num, 'xshape': xshape}, None)
|
||||
config = load_config("./configs/nas-benchmark/algos/DARTS.config", {"class_num": class_num, "xshape": xshape}, None)
|
||||
config = load_config(
|
||||
"./configs/nas-benchmark/algos/DARTS.config",
|
||||
{"class_num": class_num, "xshape": xshape},
|
||||
None,
|
||||
)
|
||||
if xargs.dataset == "cifar10":
|
||||
cifar_split = load_config("configs/nas-benchmark/cifar-split.txt", None, None)
|
||||
xvalid_data = deepcopy(train_data)
|
||||
@@ -142,7 +156,9 @@ def test_one_shot_model(ckpath, use_train):
|
||||
search_model.load_state_dict(ckp["search_model"])
|
||||
search_model = search_model.cuda()
|
||||
api = API("/home/dxy/.torch/NAS-Bench-201-v1_0-e61699.pth")
|
||||
archs, probs, accuracies = evaluate_one_shot(search_model, valid_loader, api, use_train)
|
||||
archs, probs, accuracies = evaluate_one_shot(
|
||||
search_model, valid_loader, api, use_train
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
@@ -53,8 +53,12 @@ def evaluate(api, weight_dir, data: str):
|
||||
# compute the weight watcher results
|
||||
config = api.get_net_config(arch_index, data)
|
||||
net = get_cell_based_tiny_net(config)
|
||||
meta_info = api.query_meta_info_by_index(arch_index, hp="200" if api.search_space_name == "topology" else "90")
|
||||
params = meta_info.get_net_param(data, 888 if api.search_space_name == "topology" else 777)
|
||||
meta_info = api.query_meta_info_by_index(
|
||||
arch_index, hp="200" if api.search_space_name == "topology" else "90"
|
||||
)
|
||||
params = meta_info.get_net_param(
|
||||
data, 888 if api.search_space_name == "topology" else 777
|
||||
)
|
||||
with torch.no_grad():
|
||||
net.load_state_dict(params)
|
||||
_, summary = weight_watcher.analyze(net, alphas=False)
|
||||
@@ -73,7 +77,10 @@ def evaluate(api, weight_dir, data: str):
|
||||
norms.append(cur_norm)
|
||||
# query the accuracy
|
||||
info = meta_info.get_metrics(
|
||||
data, "ori-test", iepoch=None, is_random=888 if api.search_space_name == "topology" else 777
|
||||
data,
|
||||
"ori-test",
|
||||
iepoch=None,
|
||||
is_random=888 if api.search_space_name == "topology" else 777,
|
||||
)
|
||||
accuracies.append(info["accuracy"])
|
||||
del net, meta_info
|
||||
@@ -98,7 +105,11 @@ def main(search_space, meta_file: str, weight_dir, save_dir, xdata):
|
||||
for hp in hps:
|
||||
nums = api.statistics(data, hp=hp)
|
||||
total = sum([k * v for k, v in nums.items()])
|
||||
print("Using {:3s} epochs, trained on {:20s} : {:} trials in total ({:}).".format(hp, data, total, nums))
|
||||
print(
|
||||
"Using {:3s} epochs, trained on {:20s} : {:} trials in total ({:}).".format(
|
||||
hp, data, total, nums
|
||||
)
|
||||
)
|
||||
print(time_string() + " " + "=" * 50)
|
||||
|
||||
norms, accuracies = evaluate(api, weight_dir, xdata)
|
||||
@@ -120,8 +131,15 @@ def main(search_space, meta_file: str, weight_dir, save_dir, xdata):
|
||||
plt.xlim(min(indexes), max(indexes))
|
||||
plt.ylim(min(indexes), max(indexes))
|
||||
# plt.ylabel('y').set_rotation(30)
|
||||
plt.yticks(np.arange(min(indexes), max(indexes), max(indexes) // 3), fontsize=LegendFontsize, rotation="vertical")
|
||||
plt.xticks(np.arange(min(indexes), max(indexes), max(indexes) // 5), fontsize=LegendFontsize)
|
||||
plt.yticks(
|
||||
np.arange(min(indexes), max(indexes), max(indexes) // 3),
|
||||
fontsize=LegendFontsize,
|
||||
rotation="vertical",
|
||||
)
|
||||
plt.xticks(
|
||||
np.arange(min(indexes), max(indexes), max(indexes) // 5),
|
||||
fontsize=LegendFontsize,
|
||||
)
|
||||
ax.scatter(indexes, labels, marker="*", s=0.5, c="tab:red", alpha=0.8)
|
||||
ax.scatter(indexes, indexes, marker="o", s=0.5, c="tab:blue", alpha=0.8)
|
||||
ax.scatter([-1], [-1], marker="o", s=100, c="tab:blue", label="Test accuracy")
|
||||
@@ -129,7 +147,9 @@ def main(search_space, meta_file: str, weight_dir, save_dir, xdata):
|
||||
plt.grid(zorder=0)
|
||||
ax.set_axisbelow(True)
|
||||
plt.legend(loc=0, fontsize=LegendFontsize)
|
||||
ax.set_xlabel("architecture ranking sorted by the test accuracy ", fontsize=LabelSize)
|
||||
ax.set_xlabel(
|
||||
"architecture ranking sorted by the test accuracy ", fontsize=LabelSize
|
||||
)
|
||||
ax.set_ylabel("architecture ranking computed by weight watcher", fontsize=LabelSize)
|
||||
save_path = (save_dir / "{:}-{:}-test-ww.pdf".format(search_space, xdata)).resolve()
|
||||
fig.savefig(save_path, dpi=dpi, bbox_inches="tight", format="pdf")
|
||||
@@ -148,9 +168,18 @@ if __name__ == "__main__":
|
||||
default="./output/vis-nas-bench/",
|
||||
help="The base-name of folder to save checkpoints and log.",
|
||||
)
|
||||
parser.add_argument("--search_space", type=str, default=None, choices=["tss", "sss"], help="The search space.")
|
||||
parser.add_argument(
|
||||
"--base_path", type=str, default=None, help="The path to the NAS-Bench-201 benchmark file and weight dir."
|
||||
"--search_space",
|
||||
type=str,
|
||||
default=None,
|
||||
choices=["tss", "sss"],
|
||||
help="The search space.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--base_path",
|
||||
type=str,
|
||||
default=None,
|
||||
help="The path to the NAS-Bench-201 benchmark file and weight dir.",
|
||||
)
|
||||
parser.add_argument("--dataset", type=str, default=None, help=".")
|
||||
args = parser.parse_args()
|
||||
@@ -160,6 +189,8 @@ if __name__ == "__main__":
|
||||
meta_file = Path(args.base_path + ".pth")
|
||||
weight_dir = Path(args.base_path + "-full")
|
||||
assert meta_file.exists(), "invalid path for api : {:}".format(meta_file)
|
||||
assert weight_dir.exists() and weight_dir.is_dir(), "invalid path for weight dir : {:}".format(weight_dir)
|
||||
assert (
|
||||
weight_dir.exists() and weight_dir.is_dir()
|
||||
), "invalid path for weight dir : {:}".format(weight_dir)
|
||||
|
||||
main(args.search_space, str(meta_file), weight_dir, save_dir, args.dataset)
|
||||
|
@@ -42,7 +42,9 @@ def fetch_data(root_dir="./output/search", search_space="tss", dataset=None):
|
||||
for alg, path in alg2path.items():
|
||||
data = torch.load(path)
|
||||
for index, info in data.items():
|
||||
info["time_w_arch"] = [(x, y) for x, y in zip(info["all_total_times"], info["all_archs"])]
|
||||
info["time_w_arch"] = [
|
||||
(x, y) for x, y in zip(info["all_total_times"], info["all_archs"])
|
||||
]
|
||||
for j, arch in enumerate(info["all_archs"]):
|
||||
assert arch != -1, "invalid arch from {:} {:} {:} ({:}, {:})".format(
|
||||
alg, search_space, dataset, index, j
|
||||
@@ -57,12 +59,16 @@ def query_performance(api, data, dataset, ticket):
|
||||
time_w_arch = sorted(info["time_w_arch"], key=lambda x: abs(x[0] - ticket))
|
||||
time_a, arch_a = time_w_arch[0]
|
||||
time_b, arch_b = time_w_arch[1]
|
||||
info_a = api.get_more_info(arch_a, dataset=dataset, hp=90 if is_size_space else 200, is_random=False)
|
||||
info_b = api.get_more_info(arch_b, dataset=dataset, hp=90 if is_size_space else 200, is_random=False)
|
||||
info_a = api.get_more_info(
|
||||
arch_a, dataset=dataset, hp=90 if is_size_space else 200, is_random=False
|
||||
)
|
||||
info_b = api.get_more_info(
|
||||
arch_b, dataset=dataset, hp=90 if is_size_space else 200, is_random=False
|
||||
)
|
||||
accuracy_a, accuracy_b = info_a["test-accuracy"], info_b["test-accuracy"]
|
||||
interplate = (time_b - ticket) / (time_b - time_a) * accuracy_a + (ticket - time_a) / (
|
||||
time_b - time_a
|
||||
) * accuracy_b
|
||||
interplate = (time_b - ticket) / (time_b - time_a) * accuracy_a + (
|
||||
ticket - time_a
|
||||
) / (time_b - time_a) * accuracy_b
|
||||
results.append(interplate)
|
||||
return sum(results) / len(results)
|
||||
|
||||
@@ -85,7 +91,11 @@ y_max_s = {
|
||||
("ImageNet16-120", "sss"): 46,
|
||||
}
|
||||
|
||||
name2label = {"cifar10": "CIFAR-10", "cifar100": "CIFAR-100", "ImageNet16-120": "ImageNet-16-120"}
|
||||
name2label = {
|
||||
"cifar10": "CIFAR-10",
|
||||
"cifar100": "CIFAR-100",
|
||||
"ImageNet16-120": "ImageNet-16-120",
|
||||
}
|
||||
|
||||
|
||||
def visualize_curve(api, vis_save_dir, search_space, max_time):
|
||||
@@ -100,7 +110,9 @@ def visualize_curve(api, vis_save_dir, search_space, max_time):
|
||||
alg2data = fetch_data(search_space=search_space, dataset=dataset)
|
||||
alg2accuracies = OrderedDict()
|
||||
total_tickets = 150
|
||||
time_tickets = [float(i) / total_tickets * max_time for i in range(total_tickets)]
|
||||
time_tickets = [
|
||||
float(i) / total_tickets * max_time for i in range(total_tickets)
|
||||
]
|
||||
colors = ["b", "g", "c", "m", "y"]
|
||||
ax.set_xlim(0, 200)
|
||||
ax.set_ylim(y_min_s[(dataset, search_space)], y_max_s[(dataset, search_space)])
|
||||
@@ -111,10 +123,20 @@ def visualize_curve(api, vis_save_dir, search_space, max_time):
|
||||
accuracy = query_performance(api, data, dataset, ticket)
|
||||
accuracies.append(accuracy)
|
||||
alg2accuracies[alg] = accuracies
|
||||
ax.plot([x / 100 for x in time_tickets], accuracies, c=colors[idx], label="{:}".format(alg))
|
||||
ax.plot(
|
||||
[x / 100 for x in time_tickets],
|
||||
accuracies,
|
||||
c=colors[idx],
|
||||
label="{:}".format(alg),
|
||||
)
|
||||
ax.set_xlabel("Estimated wall-clock time (1e2 seconds)", fontsize=LabelSize)
|
||||
ax.set_ylabel("Test accuracy on {:}".format(name2label[dataset]), fontsize=LabelSize)
|
||||
ax.set_title("Searching results on {:}".format(name2label[dataset]), fontsize=LabelSize + 4)
|
||||
ax.set_ylabel(
|
||||
"Test accuracy on {:}".format(name2label[dataset]), fontsize=LabelSize
|
||||
)
|
||||
ax.set_title(
|
||||
"Searching results on {:}".format(name2label[dataset]),
|
||||
fontsize=LabelSize + 4,
|
||||
)
|
||||
ax.legend(loc=4, fontsize=LegendFontsize)
|
||||
|
||||
fig, axs = plt.subplots(1, 3, figsize=figsize)
|
||||
@@ -129,12 +151,25 @@ def visualize_curve(api, vis_save_dir, search_space, max_time):
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser(description="NAS-Bench-X", formatter_class=argparse.ArgumentDefaultsHelpFormatter)
|
||||
parser.add_argument(
|
||||
"--save_dir", type=str, default="output/vis-nas-bench/nas-algos", help="Folder to save checkpoints and log."
|
||||
parser = argparse.ArgumentParser(
|
||||
description="NAS-Bench-X",
|
||||
formatter_class=argparse.ArgumentDefaultsHelpFormatter,
|
||||
)
|
||||
parser.add_argument(
|
||||
"--save_dir",
|
||||
type=str,
|
||||
default="output/vis-nas-bench/nas-algos",
|
||||
help="Folder to save checkpoints and log.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--search_space",
|
||||
type=str,
|
||||
choices=["tss", "sss"],
|
||||
help="Choose the search space.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--max_time", type=float, default=20000, help="The maximum time budget."
|
||||
)
|
||||
parser.add_argument("--search_space", type=str, choices=["tss", "sss"], help="Choose the search space.")
|
||||
parser.add_argument("--max_time", type=float, default=20000, help="The maximum time budget.")
|
||||
args = parser.parse_args()
|
||||
|
||||
save_dir = Path(args.save_dir)
|
||||
|
@@ -29,7 +29,9 @@ from log_utils import time_string
|
||||
|
||||
|
||||
# def fetch_data(root_dir='./output/search', search_space='tss', dataset=None, suffix='-WARMNone'):
|
||||
def fetch_data(root_dir="./output/search", search_space="tss", dataset=None, suffix="-WARM0.3"):
|
||||
def fetch_data(
|
||||
root_dir="./output/search", search_space="tss", dataset=None, suffix="-WARM0.3"
|
||||
):
|
||||
ss_dir = "{:}-{:}".format(root_dir, search_space)
|
||||
alg2name, alg2path = OrderedDict(), OrderedDict()
|
||||
seeds = [777, 888, 999]
|
||||
@@ -45,8 +47,12 @@ def fetch_data(root_dir="./output/search", search_space="tss", dataset=None, suf
|
||||
# alg2name['TAS'] = 'tas-affine0_BN0{:}'.format(suffix)
|
||||
# alg2name['FBNetV2'] = 'fbv2-affine0_BN0{:}'.format(suffix)
|
||||
# alg2name['TuNAS'] = 'tunas-affine0_BN0{:}'.format(suffix)
|
||||
alg2name["channel-wise interpolation"] = "tas-affine0_BN0-AWD0.001{:}".format(suffix)
|
||||
alg2name["masking + Gumbel-Softmax"] = "mask_gumbel-affine0_BN0-AWD0.001{:}".format(suffix)
|
||||
alg2name["channel-wise interpolation"] = "tas-affine0_BN0-AWD0.001{:}".format(
|
||||
suffix
|
||||
)
|
||||
alg2name[
|
||||
"masking + Gumbel-Softmax"
|
||||
] = "mask_gumbel-affine0_BN0-AWD0.001{:}".format(suffix)
|
||||
alg2name["masking + sampling"] = "mask_rl-affine0_BN0-AWD0.0{:}".format(suffix)
|
||||
for alg, name in alg2name.items():
|
||||
alg2path[alg] = os.path.join(ss_dir, dataset, name, "seed-{:}-last-info.pth")
|
||||
@@ -86,7 +92,11 @@ y_max_s = {
|
||||
("ImageNet16-120", "sss"): 46,
|
||||
}
|
||||
|
||||
name2label = {"cifar10": "CIFAR-10", "cifar100": "CIFAR-100", "ImageNet16-120": "ImageNet-16-120"}
|
||||
name2label = {
|
||||
"cifar10": "CIFAR-10",
|
||||
"cifar100": "CIFAR-100",
|
||||
"ImageNet16-120": "ImageNet-16-120",
|
||||
}
|
||||
|
||||
|
||||
def visualize_curve(api, vis_save_dir, search_space):
|
||||
@@ -111,10 +121,17 @@ def visualize_curve(api, vis_save_dir, search_space):
|
||||
try:
|
||||
structures, accs = [_[iepoch - 1] for _ in data], []
|
||||
except:
|
||||
raise ValueError("This alg {:} on {:} has invalid checkpoints.".format(alg, dataset))
|
||||
raise ValueError(
|
||||
"This alg {:} on {:} has invalid checkpoints.".format(
|
||||
alg, dataset
|
||||
)
|
||||
)
|
||||
for structure in structures:
|
||||
info = api.get_more_info(
|
||||
structure, dataset=dataset, hp=90 if api.search_space_name == "size" else 200, is_random=False
|
||||
structure,
|
||||
dataset=dataset,
|
||||
hp=90 if api.search_space_name == "size" else 200,
|
||||
is_random=False,
|
||||
)
|
||||
accs.append(info["test-accuracy"])
|
||||
accuracies.append(sum(accs) / len(accs))
|
||||
@@ -122,8 +139,13 @@ def visualize_curve(api, vis_save_dir, search_space):
|
||||
alg2accuracies[alg] = accuracies
|
||||
ax.plot(xs, accuracies, c=colors[idx], label="{:}".format(alg))
|
||||
ax.set_xlabel("The searching epoch", fontsize=LabelSize)
|
||||
ax.set_ylabel("Test accuracy on {:}".format(name2label[dataset]), fontsize=LabelSize)
|
||||
ax.set_title("Searching results on {:}".format(name2label[dataset]), fontsize=LabelSize + 4)
|
||||
ax.set_ylabel(
|
||||
"Test accuracy on {:}".format(name2label[dataset]), fontsize=LabelSize
|
||||
)
|
||||
ax.set_title(
|
||||
"Searching results on {:}".format(name2label[dataset]),
|
||||
fontsize=LabelSize + 4,
|
||||
)
|
||||
ax.legend(loc=4, fontsize=LegendFontsize)
|
||||
|
||||
fig, axs = plt.subplots(1, 3, figsize=figsize)
|
||||
@@ -138,12 +160,22 @@ def visualize_curve(api, vis_save_dir, search_space):
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser(description="NAS-Bench-X", formatter_class=argparse.ArgumentDefaultsHelpFormatter)
|
||||
parser.add_argument(
|
||||
"--save_dir", type=str, default="output/vis-nas-bench/nas-algos", help="Folder to save checkpoints and log."
|
||||
parser = argparse.ArgumentParser(
|
||||
description="NAS-Bench-X",
|
||||
formatter_class=argparse.ArgumentDefaultsHelpFormatter,
|
||||
)
|
||||
parser.add_argument(
|
||||
"--search_space", type=str, default="tss", choices=["tss", "sss"], help="Choose the search space."
|
||||
"--save_dir",
|
||||
type=str,
|
||||
default="output/vis-nas-bench/nas-algos",
|
||||
help="Folder to save checkpoints and log.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--search_space",
|
||||
type=str,
|
||||
default="tss",
|
||||
choices=["tss", "sss"],
|
||||
help="Choose the search space.",
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
|
@@ -33,9 +33,15 @@ def visualize_info(api, vis_save_dir, indicator):
|
||||
# print ('{:} start to visualize {:} information'.format(time_string(), api))
|
||||
vis_save_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
cifar010_cache_path = vis_save_dir / "{:}-cache-{:}-info.pth".format("cifar10", indicator)
|
||||
cifar100_cache_path = vis_save_dir / "{:}-cache-{:}-info.pth".format("cifar100", indicator)
|
||||
imagenet_cache_path = vis_save_dir / "{:}-cache-{:}-info.pth".format("ImageNet16-120", indicator)
|
||||
cifar010_cache_path = vis_save_dir / "{:}-cache-{:}-info.pth".format(
|
||||
"cifar10", indicator
|
||||
)
|
||||
cifar100_cache_path = vis_save_dir / "{:}-cache-{:}-info.pth".format(
|
||||
"cifar100", indicator
|
||||
)
|
||||
imagenet_cache_path = vis_save_dir / "{:}-cache-{:}-info.pth".format(
|
||||
"ImageNet16-120", indicator
|
||||
)
|
||||
cifar010_info = torch.load(cifar010_cache_path)
|
||||
cifar100_info = torch.load(cifar100_cache_path)
|
||||
imagenet_info = torch.load(imagenet_cache_path)
|
||||
@@ -63,8 +69,15 @@ def visualize_info(api, vis_save_dir, indicator):
|
||||
plt.xlim(min(indexes), max(indexes))
|
||||
plt.ylim(min(indexes), max(indexes))
|
||||
# plt.ylabel('y').set_rotation(30)
|
||||
plt.yticks(np.arange(min(indexes), max(indexes), max(indexes) // 3), fontsize=LegendFontsize, rotation="vertical")
|
||||
plt.xticks(np.arange(min(indexes), max(indexes), max(indexes) // 5), fontsize=LegendFontsize)
|
||||
plt.yticks(
|
||||
np.arange(min(indexes), max(indexes), max(indexes) // 3),
|
||||
fontsize=LegendFontsize,
|
||||
rotation="vertical",
|
||||
)
|
||||
plt.xticks(
|
||||
np.arange(min(indexes), max(indexes), max(indexes) // 5),
|
||||
fontsize=LegendFontsize,
|
||||
)
|
||||
ax.scatter(indexes, cifar100_labels, marker="^", s=0.5, c="tab:green", alpha=0.8)
|
||||
ax.scatter(indexes, imagenet_labels, marker="*", s=0.5, c="tab:red", alpha=0.8)
|
||||
ax.scatter(indexes, indexes, marker="o", s=0.5, c="tab:blue", alpha=0.8)
|
||||
@@ -100,7 +113,9 @@ def visualize_sss_info(api, dataset, vis_save_dir):
|
||||
train_accs.append(info["train-accuracy"])
|
||||
test_accs.append(info["test-accuracy"])
|
||||
if dataset == "cifar10":
|
||||
info = api.get_more_info(index, "cifar10-valid", hp="90", is_random=False)
|
||||
info = api.get_more_info(
|
||||
index, "cifar10-valid", hp="90", is_random=False
|
||||
)
|
||||
valid_accs.append(info["valid-accuracy"])
|
||||
else:
|
||||
valid_accs.append(info["valid-accuracy"])
|
||||
@@ -272,7 +287,9 @@ def visualize_tss_info(api, dataset, vis_save_dir):
|
||||
train_accs.append(info["train-accuracy"])
|
||||
test_accs.append(info["test-accuracy"])
|
||||
if dataset == "cifar10":
|
||||
info = api.get_more_info(index, "cifar10-valid", hp="200", is_random=False)
|
||||
info = api.get_more_info(
|
||||
index, "cifar10-valid", hp="200", is_random=False
|
||||
)
|
||||
valid_accs.append(info["valid-accuracy"])
|
||||
else:
|
||||
valid_accs.append(info["valid-accuracy"])
|
||||
@@ -297,7 +314,9 @@ def visualize_tss_info(api, dataset, vis_save_dir):
|
||||
)
|
||||
print("{:} collect data done.".format(time_string()))
|
||||
|
||||
resnet = ["|nor_conv_3x3~0|+|none~0|nor_conv_3x3~1|+|skip_connect~0|none~1|skip_connect~2|"]
|
||||
resnet = [
|
||||
"|nor_conv_3x3~0|+|none~0|nor_conv_3x3~1|+|skip_connect~0|none~1|skip_connect~2|"
|
||||
]
|
||||
resnet_indexes = [api.query_index_by_arch(x) for x in resnet]
|
||||
largest_indexes = [
|
||||
api.query_index_by_arch(
|
||||
@@ -429,9 +448,15 @@ def visualize_rank_info(api, vis_save_dir, indicator):
|
||||
# print ('{:} start to visualize {:} information'.format(time_string(), api))
|
||||
vis_save_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
cifar010_cache_path = vis_save_dir / "{:}-cache-{:}-info.pth".format("cifar10", indicator)
|
||||
cifar100_cache_path = vis_save_dir / "{:}-cache-{:}-info.pth".format("cifar100", indicator)
|
||||
imagenet_cache_path = vis_save_dir / "{:}-cache-{:}-info.pth".format("ImageNet16-120", indicator)
|
||||
cifar010_cache_path = vis_save_dir / "{:}-cache-{:}-info.pth".format(
|
||||
"cifar10", indicator
|
||||
)
|
||||
cifar100_cache_path = vis_save_dir / "{:}-cache-{:}-info.pth".format(
|
||||
"cifar100", indicator
|
||||
)
|
||||
imagenet_cache_path = vis_save_dir / "{:}-cache-{:}-info.pth".format(
|
||||
"ImageNet16-120", indicator
|
||||
)
|
||||
cifar010_info = torch.load(cifar010_cache_path)
|
||||
cifar100_info = torch.load(cifar100_cache_path)
|
||||
imagenet_info = torch.load(imagenet_cache_path)
|
||||
@@ -466,8 +491,17 @@ def visualize_rank_info(api, vis_save_dir, indicator):
|
||||
ax.xaxis.set_ticks(np.arange(min(indexes), max(indexes), max(indexes) // 5))
|
||||
ax.scatter(indexes, labels, marker="^", s=0.5, c="tab:green", alpha=0.8)
|
||||
ax.scatter(indexes, indexes, marker="o", s=0.5, c="tab:blue", alpha=0.8)
|
||||
ax.scatter([-1], [-1], marker="^", s=100, c="tab:green", label="{:} test".format(name))
|
||||
ax.scatter([-1], [-1], marker="o", s=100, c="tab:blue", label="{:} validation".format(name))
|
||||
ax.scatter(
|
||||
[-1], [-1], marker="^", s=100, c="tab:green", label="{:} test".format(name)
|
||||
)
|
||||
ax.scatter(
|
||||
[-1],
|
||||
[-1],
|
||||
marker="o",
|
||||
s=100,
|
||||
c="tab:blue",
|
||||
label="{:} validation".format(name),
|
||||
)
|
||||
ax.legend(loc=4, fontsize=LegendFontsize)
|
||||
ax.set_xlabel("ranking on the {:} validation".format(name), fontsize=LabelSize)
|
||||
ax.set_ylabel("architecture ranking", fontsize=LabelSize)
|
||||
@@ -479,9 +513,13 @@ def visualize_rank_info(api, vis_save_dir, indicator):
|
||||
labels = get_labels(imagenet_info)
|
||||
plot_ax(labels, ax3, "ImageNet-16-120")
|
||||
|
||||
save_path = (vis_save_dir / "{:}-same-relative-rank.pdf".format(indicator)).resolve()
|
||||
save_path = (
|
||||
vis_save_dir / "{:}-same-relative-rank.pdf".format(indicator)
|
||||
).resolve()
|
||||
fig.savefig(save_path, dpi=dpi, bbox_inches="tight", format="pdf")
|
||||
save_path = (vis_save_dir / "{:}-same-relative-rank.png".format(indicator)).resolve()
|
||||
save_path = (
|
||||
vis_save_dir / "{:}-same-relative-rank.png".format(indicator)
|
||||
).resolve()
|
||||
fig.savefig(save_path, dpi=dpi, bbox_inches="tight", format="png")
|
||||
print("{:} save into {:}".format(time_string(), save_path))
|
||||
plt.close("all")
|
||||
@@ -502,9 +540,15 @@ def visualize_all_rank_info(api, vis_save_dir, indicator):
|
||||
# print ('{:} start to visualize {:} information'.format(time_string(), api))
|
||||
vis_save_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
cifar010_cache_path = vis_save_dir / "{:}-cache-{:}-info.pth".format("cifar10", indicator)
|
||||
cifar100_cache_path = vis_save_dir / "{:}-cache-{:}-info.pth".format("cifar100", indicator)
|
||||
imagenet_cache_path = vis_save_dir / "{:}-cache-{:}-info.pth".format("ImageNet16-120", indicator)
|
||||
cifar010_cache_path = vis_save_dir / "{:}-cache-{:}-info.pth".format(
|
||||
"cifar10", indicator
|
||||
)
|
||||
cifar100_cache_path = vis_save_dir / "{:}-cache-{:}-info.pth".format(
|
||||
"cifar100", indicator
|
||||
)
|
||||
imagenet_cache_path = vis_save_dir / "{:}-cache-{:}-info.pth".format(
|
||||
"ImageNet16-120", indicator
|
||||
)
|
||||
cifar010_info = torch.load(cifar010_cache_path)
|
||||
cifar100_info = torch.load(cifar100_cache_path)
|
||||
imagenet_info = torch.load(imagenet_cache_path)
|
||||
@@ -570,7 +614,9 @@ def visualize_all_rank_info(api, vis_save_dir, indicator):
|
||||
yticklabels=["C10-V", "C10-T", "C100-V", "C100-T", "I120-V", "I120-T"],
|
||||
)
|
||||
ax1.set_title("Correlation coefficient over ALL candidates")
|
||||
ax2.set_title("Correlation coefficient over candidates with accuracy > {:}%".format(acc_bar))
|
||||
ax2.set_title(
|
||||
"Correlation coefficient over candidates with accuracy > {:}%".format(acc_bar)
|
||||
)
|
||||
save_path = (vis_save_dir / "{:}-all-relative-rank.png".format(indicator)).resolve()
|
||||
fig.savefig(save_path, dpi=dpi, bbox_inches="tight", format="png")
|
||||
print("{:} save into {:}".format(time_string(), save_path))
|
||||
@@ -578,9 +624,15 @@ def visualize_all_rank_info(api, vis_save_dir, indicator):
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser(description="NAS-Bench-X", formatter_class=argparse.ArgumentDefaultsHelpFormatter)
|
||||
parser = argparse.ArgumentParser(
|
||||
description="NAS-Bench-X",
|
||||
formatter_class=argparse.ArgumentDefaultsHelpFormatter,
|
||||
)
|
||||
parser.add_argument(
|
||||
"--save_dir", type=str, default="output/vis-nas-bench", help="Folder to save checkpoints and log."
|
||||
"--save_dir",
|
||||
type=str,
|
||||
default="output/vis-nas-bench",
|
||||
help="Folder to save checkpoints and log.",
|
||||
)
|
||||
# use for train the model
|
||||
args = parser.parse_args()
|
||||
|
Reference in New Issue
Block a user