Add int search space

D-X-Y
2021-03-18 16:02:55 +08:00
parent ece6ac5f41
commit 63c8bb9bc8
67 changed files with 5150 additions and 1474 deletions

View File

@@ -30,7 +30,11 @@ def show_time(api, epoch=12):
all_cifar10_time += cifar10_time
all_cifar100_time += cifar100_time
all_imagenet_time += imagenet_time
print("The total training time for CIFAR-10 (held-out train set) is {:} seconds".format(all_cifar10_time))
print(
"The total training time for CIFAR-10 (held-out train set) is {:} seconds".format(
all_cifar10_time
)
)
print(
"The total training time for CIFAR-100 (held-out train set) is {:} seconds, {:.2f} times longer than that on CIFAR-10".format(
all_cifar100_time, all_cifar100_time / all_cifar10_time
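
For reference, the reformatted hunk above computes a simple sum-and-ratio report. A minimal standalone sketch of the same pattern, using made-up timings rather than benchmark values:

def report_total_times(dataset2seconds):
    # dataset2seconds: dataset name -> total training seconds (made-up values below).
    base = dataset2seconds["cifar10"]
    for name, seconds in dataset2seconds.items():
        print(
            "The total training time for {:} is {:} seconds, {:.2f} times longer than that on CIFAR-10".format(
                name, seconds, seconds / base
            )
        )

report_total_times({"cifar10": 1000.0, "cifar100": 2100.0, "ImageNet16-120": 6400.0})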

View File

@@ -30,15 +30,28 @@ from log_utils import time_string
def get_valid_test_acc(api, arch, dataset):
is_size_space = api.search_space_name == "size"
if dataset == "cifar10":
xinfo = api.get_more_info(arch, dataset=dataset, hp=90 if is_size_space else 200, is_random=False)
xinfo = api.get_more_info(
arch, dataset=dataset, hp=90 if is_size_space else 200, is_random=False
)
test_acc = xinfo["test-accuracy"]
xinfo = api.get_more_info(arch, dataset="cifar10-valid", hp=90 if is_size_space else 200, is_random=False)
xinfo = api.get_more_info(
arch,
dataset="cifar10-valid",
hp=90 if is_size_space else 200,
is_random=False,
)
valid_acc = xinfo["valid-accuracy"]
else:
xinfo = api.get_more_info(arch, dataset=dataset, hp=90 if is_size_space else 200, is_random=False)
xinfo = api.get_more_info(
arch, dataset=dataset, hp=90 if is_size_space else 200, is_random=False
)
valid_acc = xinfo["valid-accuracy"]
test_acc = xinfo["test-accuracy"]
return valid_acc, test_acc, "validation = {:.2f}, test = {:.2f}\n".format(valid_acc, test_acc)
return (
valid_acc,
test_acc,
"validation = {:.2f}, test = {:.2f}\n".format(valid_acc, test_acc),
)
def compute_kendalltau(vectori, vectorj):
@@ -61,9 +74,17 @@ if __name__ == "__main__":
formatter_class=argparse.ArgumentDefaultsHelpFormatter,
)
parser.add_argument(
"--save_dir", type=str, default="output/vis-nas-bench/nas-algos", help="Folder to save checkpoints and log."
"--save_dir",
type=str,
default="output/vis-nas-bench/nas-algos",
help="Folder to save checkpoints and log.",
)
parser.add_argument(
"--search_space",
type=str,
choices=["tss", "sss"],
help="Choose the search space.",
)
parser.add_argument("--search_space", type=str, choices=["tss", "sss"], help="Choose the search space.")
args = parser.parse_args()
save_dir = Path(args.save_dir)
@@ -77,9 +98,17 @@ if __name__ == "__main__":
scores_1.append(valid_acc)
scores_2.append(test_acc)
correlation = compute_kendalltau(scores_1, scores_2)
print("The kendall tau correlation of {:} samples : {:}".format(len(indexes), correlation))
print(
"The kendall tau correlation of {:} samples : {:}".format(
len(indexes), correlation
)
)
correlation = compute_spearmanr(scores_1, scores_2)
print("The spearmanr correlation of {:} samples : {:}".format(len(indexes), correlation))
print(
"The spearmanr correlation of {:} samples : {:}".format(
len(indexes), correlation
)
)
# scores_1 = ['{:.2f}'.format(x) for x in scores_1]
# scores_2 = ['{:.2f}'.format(x) for x in scores_2]
# print(', '.join(scores_1))
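
The compute_kendalltau and compute_spearmanr helpers called above are defined outside this hunk; a plausible sketch, assuming they wrap scipy.stats (scipy.stats.kendalltau is used verbatim elsewhere in this commit):

import scipy.stats

def compute_kendalltau(vectori, vectorj):
    # Kendall rank-correlation coefficient between two score vectors.
    coef, _p_value = scipy.stats.kendalltau(vectori, vectorj)
    return coef

def compute_spearmanr(vectori, vectorj):
    # Spearman rank-correlation coefficient between two score vectors.
    coef, _p_value = scipy.stats.spearmanr(vectori, vectorj)
    return coef

print(compute_kendalltau([90.1, 91.2, 89.5], [70.3, 71.0, 69.8]))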

View File

@@ -35,9 +35,15 @@ def visualize_relative_info(api, vis_save_dir, indicator):
# print ('{:} start to visualize {:} information'.format(time_string(), api))
vis_save_dir.mkdir(parents=True, exist_ok=True)
cifar010_cache_path = vis_save_dir / "{:}-cache-{:}-info.pth".format("cifar10", indicator)
cifar100_cache_path = vis_save_dir / "{:}-cache-{:}-info.pth".format("cifar100", indicator)
imagenet_cache_path = vis_save_dir / "{:}-cache-{:}-info.pth".format("ImageNet16-120", indicator)
cifar010_cache_path = vis_save_dir / "{:}-cache-{:}-info.pth".format(
"cifar10", indicator
)
cifar100_cache_path = vis_save_dir / "{:}-cache-{:}-info.pth".format(
"cifar100", indicator
)
imagenet_cache_path = vis_save_dir / "{:}-cache-{:}-info.pth".format(
"ImageNet16-120", indicator
)
cifar010_info = torch.load(cifar010_cache_path)
cifar100_info = torch.load(cifar100_cache_path)
imagenet_info = torch.load(imagenet_cache_path)
@@ -65,8 +71,15 @@ def visualize_relative_info(api, vis_save_dir, indicator):
plt.xlim(min(indexes), max(indexes))
plt.ylim(min(indexes), max(indexes))
# plt.ylabel('y').set_rotation(30)
plt.yticks(np.arange(min(indexes), max(indexes), max(indexes) // 3), fontsize=LegendFontsize, rotation="vertical")
plt.xticks(np.arange(min(indexes), max(indexes), max(indexes) // 5), fontsize=LegendFontsize)
plt.yticks(
np.arange(min(indexes), max(indexes), max(indexes) // 3),
fontsize=LegendFontsize,
rotation="vertical",
)
plt.xticks(
np.arange(min(indexes), max(indexes), max(indexes) // 5),
fontsize=LegendFontsize,
)
ax.scatter(indexes, cifar100_labels, marker="^", s=0.5, c="tab:green", alpha=0.8)
ax.scatter(indexes, imagenet_labels, marker="*", s=0.5, c="tab:red", alpha=0.8)
ax.scatter(indexes, indexes, marker="o", s=0.5, c="tab:blue", alpha=0.8)
@@ -102,7 +115,9 @@ def visualize_sss_info(api, dataset, vis_save_dir):
train_accs.append(info["train-accuracy"])
test_accs.append(info["test-accuracy"])
if dataset == "cifar10":
info = api.get_more_info(index, "cifar10-valid", hp="90", is_random=False)
info = api.get_more_info(
index, "cifar10-valid", hp="90", is_random=False
)
valid_accs.append(info["valid-accuracy"])
else:
valid_accs.append(info["valid-accuracy"])
@@ -263,7 +278,9 @@ def visualize_tss_info(api, dataset, vis_save_dir):
train_accs.append(info["train-accuracy"])
test_accs.append(info["test-accuracy"])
if dataset == "cifar10":
info = api.get_more_info(index, "cifar10-valid", hp="200", is_random=False)
info = api.get_more_info(
index, "cifar10-valid", hp="200", is_random=False
)
valid_accs.append(info["valid-accuracy"])
else:
valid_accs.append(info["valid-accuracy"])
@@ -288,7 +305,9 @@ def visualize_tss_info(api, dataset, vis_save_dir):
)
print("{:} collect data done.".format(time_string()))
resnet = ["|nor_conv_3x3~0|+|none~0|nor_conv_3x3~1|+|skip_connect~0|none~1|skip_connect~2|"]
resnet = [
"|nor_conv_3x3~0|+|none~0|nor_conv_3x3~1|+|skip_connect~0|none~1|skip_connect~2|"
]
resnet_indexes = [api.query_index_by_arch(x) for x in resnet]
largest_indexes = [
api.query_index_by_arch(
@@ -415,9 +434,15 @@ def visualize_rank_info(api, vis_save_dir, indicator):
# print ('{:} start to visualize {:} information'.format(time_string(), api))
vis_save_dir.mkdir(parents=True, exist_ok=True)
cifar010_cache_path = vis_save_dir / "{:}-cache-{:}-info.pth".format("cifar10", indicator)
cifar100_cache_path = vis_save_dir / "{:}-cache-{:}-info.pth".format("cifar100", indicator)
imagenet_cache_path = vis_save_dir / "{:}-cache-{:}-info.pth".format("ImageNet16-120", indicator)
cifar010_cache_path = vis_save_dir / "{:}-cache-{:}-info.pth".format(
"cifar10", indicator
)
cifar100_cache_path = vis_save_dir / "{:}-cache-{:}-info.pth".format(
"cifar100", indicator
)
imagenet_cache_path = vis_save_dir / "{:}-cache-{:}-info.pth".format(
"ImageNet16-120", indicator
)
cifar010_info = torch.load(cifar010_cache_path)
cifar100_info = torch.load(cifar100_cache_path)
imagenet_info = torch.load(imagenet_cache_path)
@@ -452,8 +477,17 @@ def visualize_rank_info(api, vis_save_dir, indicator):
ax.xaxis.set_ticks(np.arange(min(indexes), max(indexes), max(indexes) // 5))
ax.scatter(indexes, labels, marker="^", s=0.5, c="tab:green", alpha=0.8)
ax.scatter(indexes, indexes, marker="o", s=0.5, c="tab:blue", alpha=0.8)
ax.scatter([-1], [-1], marker="^", s=100, c="tab:green", label="{:} test".format(name))
ax.scatter([-1], [-1], marker="o", s=100, c="tab:blue", label="{:} validation".format(name))
ax.scatter(
[-1], [-1], marker="^", s=100, c="tab:green", label="{:} test".format(name)
)
ax.scatter(
[-1],
[-1],
marker="o",
s=100,
c="tab:blue",
label="{:} validation".format(name),
)
ax.legend(loc=4, fontsize=LegendFontsize)
ax.set_xlabel("ranking on the {:} validation".format(name), fontsize=LabelSize)
ax.set_ylabel("architecture ranking", fontsize=LabelSize)
@@ -465,9 +499,13 @@ def visualize_rank_info(api, vis_save_dir, indicator):
labels = get_labels(imagenet_info)
plot_ax(labels, ax3, "ImageNet-16-120")
save_path = (vis_save_dir / "{:}-same-relative-rank.pdf".format(indicator)).resolve()
save_path = (
vis_save_dir / "{:}-same-relative-rank.pdf".format(indicator)
).resolve()
fig.savefig(save_path, dpi=dpi, bbox_inches="tight", format="pdf")
save_path = (vis_save_dir / "{:}-same-relative-rank.png".format(indicator)).resolve()
save_path = (
vis_save_dir / "{:}-same-relative-rank.png".format(indicator)
).resolve()
fig.savefig(save_path, dpi=dpi, bbox_inches="tight", format="png")
print("{:} save into {:}".format(time_string(), save_path))
plt.close("all")
@@ -496,9 +534,15 @@ def visualize_all_rank_info(api, vis_save_dir, indicator):
# print ('{:} start to visualize {:} information'.format(time_string(), api))
vis_save_dir.mkdir(parents=True, exist_ok=True)
cifar010_cache_path = vis_save_dir / "{:}-cache-{:}-info.pth".format("cifar10", indicator)
cifar100_cache_path = vis_save_dir / "{:}-cache-{:}-info.pth".format("cifar100", indicator)
imagenet_cache_path = vis_save_dir / "{:}-cache-{:}-info.pth".format("ImageNet16-120", indicator)
cifar010_cache_path = vis_save_dir / "{:}-cache-{:}-info.pth".format(
"cifar10", indicator
)
cifar100_cache_path = vis_save_dir / "{:}-cache-{:}-info.pth".format(
"cifar100", indicator
)
imagenet_cache_path = vis_save_dir / "{:}-cache-{:}-info.pth".format(
"ImageNet16-120", indicator
)
cifar010_info = torch.load(cifar010_cache_path)
cifar100_info = torch.load(cifar100_cache_path)
imagenet_info = torch.load(imagenet_cache_path)
@@ -564,7 +608,9 @@ def visualize_all_rank_info(api, vis_save_dir, indicator):
yticklabels=["C10-V", "C10-T", "C100-V", "C100-T", "I120-V", "I120-T"],
)
ax1.set_title("Correlation coefficient over ALL candidates")
ax2.set_title("Correlation coefficient over candidates with accuracy > {:}%".format(acc_bar))
ax2.set_title(
"Correlation coefficient over candidates with accuracy > {:}%".format(acc_bar)
)
save_path = (vis_save_dir / "{:}-all-relative-rank.png".format(indicator)).resolve()
fig.savefig(save_path, dpi=dpi, bbox_inches="tight", format="png")
print("{:} save into {:}".format(time_string(), save_path))
@@ -572,9 +618,14 @@ def visualize_all_rank_info(api, vis_save_dir, indicator):
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="NATS-Bench", formatter_class=argparse.ArgumentDefaultsHelpFormatter)
parser = argparse.ArgumentParser(
description="NATS-Bench", formatter_class=argparse.ArgumentDefaultsHelpFormatter
)
parser.add_argument(
"--save_dir", type=str, default="output/vis-nas-bench", help="Folder to save checkpoints and log."
"--save_dir",
type=str,
default="output/vis-nas-bench",
help="Folder to save checkpoints and log.",
)
# used for training the model
args = parser.parse_args()
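
The three cache paths above follow a compute-once-then-reuse pattern built on torch.save and torch.load; a minimal sketch of that pattern, with a hypothetical helper name and path:

import torch
from pathlib import Path

def load_or_build(cache_path: Path, build_fn):
    # Reuse the cached .pth file when it exists; otherwise build and cache the data.
    if cache_path.exists():
        return torch.load(cache_path)
    data = build_fn()
    cache_path.parent.mkdir(parents=True, exist_ok=True)
    torch.save(data, cache_path)
    return data

# Hypothetical usage mirroring the cifar010_cache_path above:
cifar010_info = load_or_build(Path("cache/cifar10-cache-test-info.pth"), dict)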

View File

@@ -43,7 +43,9 @@ def fetch_data(root_dir="./output/search", search_space="tss", dataset=None):
for alg, path in alg2path.items():
data = torch.load(path)
for index, info in data.items():
info["time_w_arch"] = [(x, y) for x, y in zip(info["all_total_times"], info["all_archs"])]
info["time_w_arch"] = [
(x, y) for x, y in zip(info["all_total_times"], info["all_archs"])
]
for j, arch in enumerate(info["all_archs"]):
assert arch != -1, "invalid arch from {:} {:} {:} ({:}, {:})".format(
alg, search_space, dataset, index, j
@@ -58,12 +60,16 @@ def query_performance(api, data, dataset, ticket):
time_w_arch = sorted(info["time_w_arch"], key=lambda x: abs(x[0] - ticket))
time_a, arch_a = time_w_arch[0]
time_b, arch_b = time_w_arch[1]
info_a = api.get_more_info(arch_a, dataset=dataset, hp=90 if is_size_space else 200, is_random=False)
info_b = api.get_more_info(arch_b, dataset=dataset, hp=90 if is_size_space else 200, is_random=False)
info_a = api.get_more_info(
arch_a, dataset=dataset, hp=90 if is_size_space else 200, is_random=False
)
info_b = api.get_more_info(
arch_b, dataset=dataset, hp=90 if is_size_space else 200, is_random=False
)
accuracy_a, accuracy_b = info_a["test-accuracy"], info_b["test-accuracy"]
interpolate = (time_b - ticket) / (time_b - time_a) * accuracy_a + (ticket - time_a) / (
time_b - time_a
) * accuracy_b
interpolate = (time_b - ticket) / (time_b - time_a) * accuracy_a + (
ticket - time_a
) / (time_b - time_a) * accuracy_b
results.append(interpolate)
# return sum(results) / len(results)
return np.mean(results), np.std(results)
@@ -74,12 +80,21 @@ def show_valid_test(api, data, dataset):
for i, info in data.items():
time, arch = info["time_w_arch"][-1]
if dataset == "cifar10":
xinfo = api.get_more_info(arch, dataset=dataset, hp=90 if is_size_space else 200, is_random=False)
xinfo = api.get_more_info(
arch, dataset=dataset, hp=90 if is_size_space else 200, is_random=False
)
test_accs.append(xinfo["test-accuracy"])
xinfo = api.get_more_info(arch, dataset="cifar10-valid", hp=90 if is_size_space else 200, is_random=False)
xinfo = api.get_more_info(
arch,
dataset="cifar10-valid",
hp=90 if is_size_space else 200,
is_random=False,
)
valid_accs.append(xinfo["valid-accuracy"])
else:
xinfo = api.get_more_info(arch, dataset=dataset, hp=90 if is_size_space else 200, is_random=False)
xinfo = api.get_more_info(
arch, dataset=dataset, hp=90 if is_size_space else 200, is_random=False
)
valid_accs.append(xinfo["valid-accuracy"])
test_accs.append(xinfo["test-accuracy"])
valid_str = "{:.2f}$\pm${:.2f}".format(np.mean(valid_accs), np.std(valid_accs))
@@ -114,7 +129,11 @@ x_axis_s = {
("ImageNet16-120", "sss"): 600,
}
name2label = {"cifar10": "CIFAR-10", "cifar100": "CIFAR-100", "ImageNet16-120": "ImageNet-16-120"}
name2label = {
"cifar10": "CIFAR-10",
"cifar100": "CIFAR-100",
"ImageNet16-120": "ImageNet-16-120",
}
def visualize_curve(api, vis_save_dir, search_space):
@@ -130,10 +149,14 @@ def visualize_curve(api, vis_save_dir, search_space):
alg2data = fetch_data(search_space=search_space, dataset=dataset)
alg2accuracies = OrderedDict()
total_tickets = 150
time_tickets = [float(i) / total_tickets * int(max_time) for i in range(total_tickets)]
time_tickets = [
float(i) / total_tickets * int(max_time) for i in range(total_tickets)
]
colors = ["b", "g", "c", "m", "y"]
ax.set_xlim(0, x_axis_s[(xdataset, search_space)])
ax.set_ylim(y_min_s[(xdataset, search_space)], y_max_s[(xdataset, search_space)])
ax.set_ylim(
y_min_s[(xdataset, search_space)], y_max_s[(xdataset, search_space)]
)
for idx, (alg, data) in enumerate(alg2data.items()):
accuracies = []
for ticket in time_tickets:
@@ -142,13 +165,25 @@ def visualize_curve(api, vis_save_dir, search_space):
valid_str, test_str = show_valid_test(api, data, xdataset)
# print('{:} plot alg : {:10s}, final accuracy = {:.2f}$\pm${:.2f}'.format(time_string(), alg, accuracy, accuracy_std))
print(
"{:} plot alg : {:10s} | validation = {:} | test = {:}".format(time_string(), alg, valid_str, test_str)
"{:} plot alg : {:10s} | validation = {:} | test = {:}".format(
time_string(), alg, valid_str, test_str
)
)
alg2accuracies[alg] = accuracies
ax.plot([x / 100 for x in time_tickets], accuracies, c=colors[idx], label="{:}".format(alg))
ax.plot(
[x / 100 for x in time_tickets],
accuracies,
c=colors[idx],
label="{:}".format(alg),
)
ax.set_xlabel("Estimated wall-clock time (1e2 seconds)", fontsize=LabelSize)
ax.set_ylabel("Test accuracy on {:}".format(name2label[xdataset]), fontsize=LabelSize)
ax.set_title("Searching results on {:}".format(name2label[xdataset]), fontsize=LabelSize + 4)
ax.set_ylabel(
"Test accuracy on {:}".format(name2label[xdataset]), fontsize=LabelSize
)
ax.set_title(
"Searching results on {:}".format(name2label[xdataset]),
fontsize=LabelSize + 4,
)
ax.legend(loc=4, fontsize=LegendFontsize)
fig, axs = plt.subplots(1, 3, figsize=figsize)
@@ -174,9 +209,17 @@ if __name__ == "__main__":
formatter_class=argparse.ArgumentDefaultsHelpFormatter,
)
parser.add_argument(
"--save_dir", type=str, default="output/vis-nas-bench/nas-algos", help="Folder to save checkpoints and log."
"--save_dir",
type=str,
default="output/vis-nas-bench/nas-algos",
help="Folder to save checkpoints and log.",
)
parser.add_argument(
"--search_space",
type=str,
choices=["tss", "sss"],
help="Choose the search space.",
)
parser.add_argument("--search_space", type=str, choices=["tss", "sss"], help="Choose the search space.")
args = parser.parse_args()
save_dir = Path(args.save_dir)
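
The two-term expression in query_performance above is plain linear interpolation between the two recorded checkpoints closest to the requested time ticket; a standalone sketch:

def interpolate_accuracy(time_a, acc_a, time_b, acc_b, ticket):
    # Weights sum to 1; with time_a <= ticket <= time_b this linearly
    # blends the two bracketing (time, accuracy) samples.
    w_a = (time_b - ticket) / (time_b - time_a)
    w_b = (ticket - time_a) / (time_b - time_a)
    return w_a * acc_a + w_b * acc_b

assert abs(interpolate_accuracy(0.0, 90.0, 10.0, 92.0, 5.0) - 91.0) < 1e-9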

View File

@@ -31,18 +31,33 @@ from log_utils import time_string
def get_valid_test_acc(api, arch, dataset):
is_size_space = api.search_space_name == "size"
if dataset == "cifar10":
xinfo = api.get_more_info(arch, dataset=dataset, hp=90 if is_size_space else 200, is_random=False)
xinfo = api.get_more_info(
arch, dataset=dataset, hp=90 if is_size_space else 200, is_random=False
)
test_acc = xinfo["test-accuracy"]
xinfo = api.get_more_info(arch, dataset="cifar10-valid", hp=90 if is_size_space else 200, is_random=False)
xinfo = api.get_more_info(
arch,
dataset="cifar10-valid",
hp=90 if is_size_space else 200,
is_random=False,
)
valid_acc = xinfo["valid-accuracy"]
else:
xinfo = api.get_more_info(arch, dataset=dataset, hp=90 if is_size_space else 200, is_random=False)
xinfo = api.get_more_info(
arch, dataset=dataset, hp=90 if is_size_space else 200, is_random=False
)
valid_acc = xinfo["valid-accuracy"]
test_acc = xinfo["test-accuracy"]
return valid_acc, test_acc, "validation = {:.2f}, test = {:.2f}\n".format(valid_acc, test_acc)
return (
valid_acc,
test_acc,
"validation = {:.2f}, test = {:.2f}\n".format(valid_acc, test_acc),
)
def fetch_data(root_dir="./output/search", search_space="tss", dataset=None, suffix="-WARM0.3"):
def fetch_data(
root_dir="./output/search", search_space="tss", dataset=None, suffix="-WARM0.3"
):
ss_dir = "{:}-{:}".format(root_dir, search_space)
alg2name, alg2path = OrderedDict(), OrderedDict()
seeds = [777, 888, 999]
@@ -55,8 +70,12 @@ def fetch_data(root_dir="./output/search", search_space="tss", dataset=None, suf
alg2name["ENAS"] = "enas-affine0_BN0-None"
alg2name["SETN"] = "setn-affine0_BN0-None"
else:
alg2name["channel-wise interpolation"] = "tas-affine0_BN0-AWD0.001{:}".format(suffix)
alg2name["masking + Gumbel-Softmax"] = "mask_gumbel-affine0_BN0-AWD0.001{:}".format(suffix)
alg2name["channel-wise interpolation"] = "tas-affine0_BN0-AWD0.001{:}".format(
suffix
)
alg2name[
"masking + Gumbel-Softmax"
] = "mask_gumbel-affine0_BN0-AWD0.001{:}".format(suffix)
alg2name["masking + sampling"] = "mask_rl-affine0_BN0-AWD0.0{:}".format(suffix)
for alg, name in alg2name.items():
alg2path[alg] = os.path.join(ss_dir, dataset, name, "seed-{:}-last-info.pth")
@@ -72,7 +91,9 @@ def fetch_data(root_dir="./output/search", search_space="tss", dataset=None, suf
continue
data = torch.load(xpath, map_location=torch.device("cpu"))
try:
data = torch.load(data["last_checkpoint"], map_location=torch.device("cpu"))
data = torch.load(
data["last_checkpoint"], map_location=torch.device("cpu")
)
except:
xpath = str(data["last_checkpoint"]).split("E100-")
if len(xpath) == 2 and os.path.isfile(xpath[0] + xpath[1]):
@@ -82,7 +103,9 @@ def fetch_data(root_dir="./output/search", search_space="tss", dataset=None, suf
elif "tunas" in str(data["last_checkpoint"]):
xpath = str(data["last_checkpoint"]).replace("tunas", "mask_rl")
else:
raise ValueError("Invalid path: {:}".format(data["last_checkpoint"]))
raise ValueError(
"Invalid path: {:}".format(data["last_checkpoint"])
)
data = torch.load(xpath, map_location=torch.device("cpu"))
alg2data[alg].append(data["genotypes"])
print("This algorithm : {:} has {:} valid ckps.".format(alg, ok_num))
@@ -108,9 +131,18 @@ y_max_s = {
("ImageNet16-120", "sss"): 46,
}
name2label = {"cifar10": "CIFAR-10", "cifar100": "CIFAR-100", "ImageNet16-120": "ImageNet-16-120"}
name2label = {
"cifar10": "CIFAR-10",
"cifar100": "CIFAR-100",
"ImageNet16-120": "ImageNet-16-120",
}
name2suffix = {("sss", "warm"): "-WARM0.3", ("sss", "none"): "-WARMNone", ("tss", "none"): None, ("tss", None): None}
name2suffix = {
("sss", "warm"): "-WARM0.3",
("sss", "none"): "-WARMNone",
("tss", "none"): None,
("tss", None): None,
}
def visualize_curve(api, vis_save_dir, search_space, suffix):
@@ -123,7 +155,11 @@ def visualize_curve(api, vis_save_dir, search_space, suffix):
def sub_plot_fn(ax, dataset):
print("{:} plot {:10s}".format(time_string(), dataset))
alg2data = fetch_data(search_space=search_space, dataset=dataset, suffix=name2suffix[(search_space, suffix)])
alg2data = fetch_data(
search_space=search_space,
dataset=dataset,
suffix=name2suffix[(search_space, suffix)],
)
alg2accuracies = OrderedDict()
epochs = 100
colors = ["b", "g", "c", "m", "y", "r"]
@@ -135,10 +171,17 @@ def visualize_curve(api, vis_save_dir, search_space, suffix):
try:
structures, accs = [_[iepoch - 1] for _ in data], []
except:
raise ValueError("This alg {:} on {:} has invalid checkpoints.".format(alg, dataset))
raise ValueError(
"This alg {:} on {:} has invalid checkpoints.".format(
alg, dataset
)
)
for structure in structures:
info = api.get_more_info(
structure, dataset=dataset, hp=90 if api.search_space_name == "size" else 200, is_random=False
structure,
dataset=dataset,
hp=90 if api.search_space_name == "size" else 200,
is_random=False,
)
accs.append(info["test-accuracy"])
accuracies.append(sum(accs) / len(accs))
@@ -146,17 +189,31 @@ def visualize_curve(api, vis_save_dir, search_space, suffix):
alg2accuracies[alg] = accuracies
ax.plot(xs, accuracies, c=colors[idx], label="{:}".format(alg))
ax.set_xlabel("The searching epoch", fontsize=LabelSize)
ax.set_ylabel("Test accuracy on {:}".format(name2label[dataset]), fontsize=LabelSize)
ax.set_title("Searching results on {:}".format(name2label[dataset]), fontsize=LabelSize + 4)
ax.set_ylabel(
"Test accuracy on {:}".format(name2label[dataset]), fontsize=LabelSize
)
ax.set_title(
"Searching results on {:}".format(name2label[dataset]),
fontsize=LabelSize + 4,
)
structures, valid_accs, test_accs = [_[epochs - 1] for _ in data], [], []
print("{:} plot alg : {:} -- final {:} architectures.".format(time_string(), alg, len(structures)))
print(
"{:} plot alg : {:} -- final {:} architectures.".format(
time_string(), alg, len(structures)
)
)
for arch in structures:
valid_acc, test_acc, _ = get_valid_test_acc(api, arch, dataset)
test_accs.append(test_acc)
valid_accs.append(valid_acc)
print(
"{:} plot alg : {:} -- validation: {:.2f}$\pm${:.2f} -- test: {:.2f}$\pm${:.2f}".format(
time_string(), alg, np.mean(valid_accs), np.std(valid_accs), np.mean(test_accs), np.std(test_accs)
time_string(),
alg,
np.mean(valid_accs),
np.std(valid_accs),
np.mean(test_accs),
np.std(test_accs),
)
)
ax.legend(loc=4, fontsize=LegendFontsize)
@@ -166,16 +223,23 @@ def visualize_curve(api, vis_save_dir, search_space, suffix):
for dataset, ax in zip(datasets, axs):
sub_plot_fn(ax, dataset)
print("sub-plot {:} on {:} done.".format(dataset, search_space))
save_path = (vis_save_dir / "{:}-ws-{:}-curve.png".format(search_space, suffix)).resolve()
save_path = (
vis_save_dir / "{:}-ws-{:}-curve.png".format(search_space, suffix)
).resolve()
fig.savefig(save_path, dpi=dpi, bbox_inches="tight", format="png")
print("{:} save into {:}".format(time_string(), save_path))
plt.close("all")
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="NATS-Bench", formatter_class=argparse.ArgumentDefaultsHelpFormatter)
parser = argparse.ArgumentParser(
description="NATS-Bench", formatter_class=argparse.ArgumentDefaultsHelpFormatter
)
parser.add_argument(
"--save_dir", type=str, default="output/vis-nas-bench/nas-algos", help="Folder to save checkpoints and log."
"--save_dir",
type=str,
default="output/vis-nas-bench/nas-algos",
help="Folder to save checkpoints and log.",
)
args = parser.parse_args()
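
The inner loop above averages, at each searching epoch, the benchmark test accuracy of the architecture that each seed has selected so far; a compact sketch of that step, with a hypothetical function name:

def mean_acc_at_epoch(api, per_seed_genotypes, dataset, iepoch, hp):
    # per_seed_genotypes: for each seed, the genotype chosen at every search epoch.
    structures = [genotypes[iepoch - 1] for genotypes in per_seed_genotypes]
    accs = [
        api.get_more_info(s, dataset=dataset, hp=hp, is_random=False)["test-accuracy"]
        for s in structures
    ]
    return sum(accs) / len(accs)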

View File

@@ -28,7 +28,9 @@ from nats_bench import create
from log_utils import time_string
plt.rcParams.update({"text.usetex": True, "font.family": "sans-serif", "font.sans-serif": ["Helvetica"]})
plt.rcParams.update(
{"text.usetex": True, "font.family": "sans-serif", "font.sans-serif": ["Helvetica"]}
)
## for Palatino and other serif fonts use:
plt.rcParams.update(
{
@@ -57,16 +59,22 @@ def fetch_data(root_dir="./output/search", search_space="tss", dataset=None):
raise ValueError("Unkonwn search space: {:}".format(search_space))
alg2all[r"REA ($\mathcal{H}^{0}$)"] = dict(
path=os.path.join(ss_dir, dataset + suffixes[0], "R-EA-SS3", "results.pth"), color="b", linestyle="-"
path=os.path.join(ss_dir, dataset + suffixes[0], "R-EA-SS3", "results.pth"),
color="b",
linestyle="-",
)
alg2all[r"REA ({:})".format(hp)] = dict(
path=os.path.join(ss_dir, dataset + suffixes[1], "R-EA-SS3", "results.pth"), color="b", linestyle="--"
path=os.path.join(ss_dir, dataset + suffixes[1], "R-EA-SS3", "results.pth"),
color="b",
linestyle="--",
)
for alg, xdata in alg2all.items():
data = torch.load(xdata["path"])
for index, info in data.items():
info["time_w_arch"] = [(x, y) for x, y in zip(info["all_total_times"], info["all_archs"])]
info["time_w_arch"] = [
(x, y) for x, y in zip(info["all_total_times"], info["all_archs"])
]
for j, arch in enumerate(info["all_archs"]):
assert arch != -1, "invalid arch from {:} {:} {:} ({:}, {:})".format(
alg, search_space, dataset, index, j
@@ -81,12 +89,16 @@ def query_performance(api, data, dataset, ticket):
time_w_arch = sorted(info["time_w_arch"], key=lambda x: abs(x[0] - ticket))
time_a, arch_a = time_w_arch[0]
time_b, arch_b = time_w_arch[1]
info_a = api.get_more_info(arch_a, dataset=dataset, hp=90 if is_size_space else 200, is_random=False)
info_b = api.get_more_info(arch_b, dataset=dataset, hp=90 if is_size_space else 200, is_random=False)
info_a = api.get_more_info(
arch_a, dataset=dataset, hp=90 if is_size_space else 200, is_random=False
)
info_b = api.get_more_info(
arch_b, dataset=dataset, hp=90 if is_size_space else 200, is_random=False
)
accuracy_a, accuracy_b = info_a["test-accuracy"], info_b["test-accuracy"]
interpolate = (time_b - ticket) / (time_b - time_a) * accuracy_a + (ticket - time_a) / (
time_b - time_a
) * accuracy_b
interpolate = (time_b - ticket) / (time_b - time_a) * accuracy_a + (
ticket - time_a
) / (time_b - time_a) * accuracy_b
results.append(interpolate)
# return sum(results) / len(results)
return np.mean(results), np.std(results)
@@ -119,7 +131,11 @@ x_axis_s = {
("ImageNet16-120", "sss"): 600,
}
name2label = {"cifar10": "CIFAR-10", "cifar100": "CIFAR-100", "ImageNet16-120": "ImageNet-16-120"}
name2label = {
"cifar10": "CIFAR-10",
"cifar100": "CIFAR-100",
"ImageNet16-120": "ImageNet-16-120",
}
spaces2latex = {
"tss": r"$\mathcal{S}_{t}$",
@@ -149,7 +165,9 @@ def visualize_curve(api_dict, vis_save_dir):
alg2data = fetch_data(search_space=search_space, dataset=dataset)
alg2accuracies = OrderedDict()
total_tickets = 200
time_tickets = [float(i) / total_tickets * int(max_time) for i in range(total_tickets)]
time_tickets = [
float(i) / total_tickets * int(max_time) for i in range(total_tickets)
]
ax.set_xlim(0, x_axis_s[(dataset, search_space)])
ax.set_ylim(y_min_s[(dataset, search_space)], y_max_s[(dataset, search_space)])
for tick in ax.get_xticklabels():
@@ -162,16 +180,29 @@ def visualize_curve(api_dict, vis_save_dir):
accuracies = []
for ticket in time_tickets:
# import pdb; pdb.set_trace()
accuracy, accuracy_std = query_performance(api_dict[search_space], xdata["data"], dataset, ticket)
accuracy, accuracy_std = query_performance(
api_dict[search_space], xdata["data"], dataset, ticket
)
accuracies.append(accuracy)
# print('{:} plot alg : {:10s}, final accuracy = {:.2f}$\pm${:.2f}'.format(time_string(), alg, accuracy, accuracy_std))
print("{:} plot alg : {:10s} on {:}".format(time_string(), alg, search_space))
print(
"{:} plot alg : {:10s} on {:}".format(time_string(), alg, search_space)
)
alg2accuracies[alg] = accuracies
ax.plot(time_tickets, accuracies, c=xdata["color"], linestyle=xdata["linestyle"], label="{:}".format(alg))
ax.plot(
time_tickets,
accuracies,
c=xdata["color"],
linestyle=xdata["linestyle"],
label="{:}".format(alg),
)
ax.set_xlabel("Estimated wall-clock time", fontsize=LabelSize)
ax.set_ylabel("Test accuracy", fontsize=LabelSize)
ax.set_title(
r"Results on {:} over {:}".format(name2label[dataset], spaces2latex[search_space]), fontsize=LabelSize
r"Results on {:} over {:}".format(
name2label[dataset], spaces2latex[search_space]
),
fontsize=LabelSize,
)
ax.legend(loc=4, fontsize=LegendFontsize)
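
The time_w_arch bookkeeping above lets query_performance pick the two recorded samples nearest a time ticket by sorting on absolute distance; a minimal sketch with made-up entries:

# info["time_w_arch"] pairs each cumulative search time with the arch found so far.
time_w_arch = [(10.0, "arch-a"), (25.0, "arch-b"), (60.0, "arch-c")]
ticket = 30.0
nearest = sorted(time_w_arch, key=lambda pair: abs(pair[0] - ticket))
(time_a, arch_a), (time_b, arch_b) = nearest[0], nearest[1]
print(time_a, arch_a, time_b, arch_b)  # 25.0 arch-b 10.0 arch-a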

View File

@@ -30,12 +30,20 @@ from models import get_cell_based_tiny_net
from nats_bench import create
name2label = {"cifar10": "CIFAR-10", "cifar100": "CIFAR-100", "ImageNet16-120": "ImageNet-16-120"}
name2label = {
"cifar10": "CIFAR-10",
"cifar100": "CIFAR-100",
"ImageNet16-120": "ImageNet-16-120",
}
def visualize_relative_info(vis_save_dir, search_space, indicator, topk):
vis_save_dir = vis_save_dir.resolve()
print("{:} start to visualize {:} with top-{:} information".format(time_string(), search_space, topk))
print(
"{:} start to visualize {:} with top-{:} information".format(
time_string(), search_space, topk
)
)
vis_save_dir.mkdir(parents=True, exist_ok=True)
cache_file_path = vis_save_dir / "cache-{:}-info.pth".format(search_space)
datasets = ["cifar10", "cifar100", "ImageNet16-120"]
@@ -46,8 +54,12 @@ def visualize_relative_info(vis_save_dir, search_space, indicator, topk):
all_info = OrderedDict()
for dataset in datasets:
info_less = api.get_more_info(index, dataset, hp="12", is_random=False)
info_more = api.get_more_info(index, dataset, hp=api.full_train_epochs, is_random=False)
all_info[dataset] = dict(less=info_less["test-accuracy"], more=info_more["test-accuracy"])
info_more = api.get_more_info(
index, dataset, hp=api.full_train_epochs, is_random=False
)
all_info[dataset] = dict(
less=info_less["test-accuracy"], more=info_more["test-accuracy"]
)
all_infos[index] = all_info
torch.save(all_infos, cache_file_path)
print("{:} save all cache data into {:}".format(time_string(), cache_file_path))
@@ -80,12 +92,18 @@ def visualize_relative_info(vis_save_dir, search_space, indicator, topk):
for idx in selected_indexes:
standard_scores.append(
api.get_more_info(
idx, dataset, hp=api.full_train_epochs if indicator == "more" else "12", is_random=False
idx,
dataset,
hp=api.full_train_epochs if indicator == "more" else "12",
is_random=False,
)["test-accuracy"]
)
random_scores.append(
api.get_more_info(
idx, dataset, hp=api.full_train_epochs if indicator == "more" else "12", is_random=True
idx,
dataset,
hp=api.full_train_epochs if indicator == "more" else "12",
is_random=True,
)["test-accuracy"]
)
indexes = list(range(len(selected_indexes)))
@@ -105,11 +123,28 @@ def visualize_relative_info(vis_save_dir, search_space, indicator, topk):
ax.set_xticks(np.arange(min(indexes), max(indexes), max(indexes) // 5))
ax.scatter(indexes, random_labels, marker="^", s=0.5, c="tab:green", alpha=0.8)
ax.scatter(indexes, indexes, marker="o", s=0.5, c="tab:blue", alpha=0.8)
ax.scatter([-1], [-1], marker="o", s=100, c="tab:blue", label="Average Over Multi-Trials")
ax.scatter([-1], [-1], marker="^", s=100, c="tab:green", label="Randomly Selected Trial")
ax.scatter(
[-1],
[-1],
marker="o",
s=100,
c="tab:blue",
label="Average Over Multi-Trials",
)
ax.scatter(
[-1],
[-1],
marker="^",
s=100,
c="tab:green",
label="Randomly Selected Trial",
)
coef, p = scipy.stats.kendalltau(standard_scores, random_scores)
ax.set_xlabel("architecture ranking in {:}".format(name2label[dataset]), fontsize=LabelSize)
ax.set_xlabel(
"architecture ranking in {:}".format(name2label[dataset]),
fontsize=LabelSize,
)
if dataset == "cifar10":
ax.set_ylabel("architecture ranking", fontsize=LabelSize)
ax.legend(loc=4, fontsize=LegendFontsize)
@@ -117,17 +152,27 @@ def visualize_relative_info(vis_save_dir, search_space, indicator, topk):
for dataset, ax in zip(datasets, axs):
rank_coef = sub_plot_fn(ax, dataset, indicator)
print("sub-plot {:} on {:} done, the ranking coefficient is {:.4f}.".format(dataset, search_space, rank_coef))
print(
"sub-plot {:} on {:} done, the ranking coefficient is {:.4f}.".format(
dataset, search_space, rank_coef
)
)
save_path = (vis_save_dir / "{:}-rank-{:}-top{:}.pdf".format(search_space, indicator, topk)).resolve()
save_path = (
vis_save_dir / "{:}-rank-{:}-top{:}.pdf".format(search_space, indicator, topk)
).resolve()
fig.savefig(save_path, dpi=dpi, bbox_inches="tight", format="pdf")
save_path = (vis_save_dir / "{:}-rank-{:}-top{:}.png".format(search_space, indicator, topk)).resolve()
save_path = (
vis_save_dir / "{:}-rank-{:}-top{:}.png".format(search_space, indicator, topk)
).resolve()
fig.savefig(save_path, dpi=dpi, bbox_inches="tight", format="png")
print("Save into {:}".format(save_path))
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="NATS-Bench", formatter_class=argparse.ArgumentDefaultsHelpFormatter)
parser = argparse.ArgumentParser(
description="NATS-Bench", formatter_class=argparse.ArgumentDefaultsHelpFormatter
)
parser.add_argument(
"--save_dir",
type=str,
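
The topk argument in this file limits the plot to the best architectures under the chosen indicator; a minimal sketch of such a selection, assuming all_infos maps index -> {"less": acc, "more": acc} as cached above (selection logic hypothetical):

import collections

all_infos = collections.OrderedDict(
    [(0, dict(less=88.1, more=91.2)), (1, dict(less=89.0, more=92.5)), (2, dict(less=87.5, more=90.8))]
)
topk, indicator = 2, "more"
selected_indexes = sorted(all_infos, key=lambda idx: -all_infos[idx][indicator])[:topk]
print(selected_indexes)  # [1, 0]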

View File

@@ -42,7 +42,9 @@ def fetch_data(root_dir="./output/search", search_space="tss", dataset=None):
for alg, path in alg2path.items():
data = torch.load(path)
for index, info in data.items():
info["time_w_arch"] = [(x, y) for x, y in zip(info["all_total_times"], info["all_archs"])]
info["time_w_arch"] = [
(x, y) for x, y in zip(info["all_total_times"], info["all_archs"])
]
for j, arch in enumerate(info["all_archs"]):
assert arch != -1, "invalid arch from {:} {:} {:} ({:}, {:})".format(
alg, search_space, dataset, index, j
@@ -54,15 +56,28 @@ def fetch_data(root_dir="./output/search", search_space="tss", dataset=None):
def get_valid_test_acc(api, arch, dataset):
is_size_space = api.search_space_name == "size"
if dataset == "cifar10":
xinfo = api.get_more_info(arch, dataset=dataset, hp=90 if is_size_space else 200, is_random=False)
xinfo = api.get_more_info(
arch, dataset=dataset, hp=90 if is_size_space else 200, is_random=False
)
test_acc = xinfo["test-accuracy"]
xinfo = api.get_more_info(arch, dataset="cifar10-valid", hp=90 if is_size_space else 200, is_random=False)
xinfo = api.get_more_info(
arch,
dataset="cifar10-valid",
hp=90 if is_size_space else 200,
is_random=False,
)
valid_acc = xinfo["valid-accuracy"]
else:
xinfo = api.get_more_info(arch, dataset=dataset, hp=90 if is_size_space else 200, is_random=False)
xinfo = api.get_more_info(
arch, dataset=dataset, hp=90 if is_size_space else 200, is_random=False
)
valid_acc = xinfo["valid-accuracy"]
test_acc = xinfo["test-accuracy"]
return valid_acc, test_acc, "validation = {:.2f}, test = {:.2f}\n".format(valid_acc, test_acc)
return (
valid_acc,
test_acc,
"validation = {:.2f}, test = {:.2f}\n".format(valid_acc, test_acc),
)
def show_valid_test(api, arch):
@@ -84,8 +99,16 @@ def find_best_valid(api, dataset):
best_test_index = sorted(all_test_accs, key=lambda x: -x[1])[0][0]
print("-" * 50 + "{:10s}".format(dataset) + "-" * 50)
print("Best ({:}) architecture on validation: {:}".format(best_valid_index, api[best_valid_index]))
print("Best ({:}) architecture on test: {:}".format(best_test_index, api[best_test_index]))
print(
"Best ({:}) architecture on validation: {:}".format(
best_valid_index, api[best_valid_index]
)
)
print(
"Best ({:}) architecture on test: {:}".format(
best_test_index, api[best_test_index]
)
)
_, _, perf_str = get_valid_test_acc(api, best_valid_index, dataset)
print("using validation ::: {:}".format(perf_str))
_, _, perf_str = get_valid_test_acc(api, best_test_index, dataset)
@@ -130,10 +153,14 @@ def show_multi_trial(search_space):
v_acc, t_acc = query_performance(api, x, xdataset, float(max_time))
valid_accs.append(v_acc)
test_accs.append(t_acc)
valid_str = "{:.2f}$\pm${:.2f}".format(np.mean(valid_accs), np.std(valid_accs))
valid_str = "{:.2f}$\pm${:.2f}".format(
np.mean(valid_accs), np.std(valid_accs)
)
test_str = "{:.2f}$\pm${:.2f}".format(np.mean(test_accs), np.std(test_accs))
print(
"{:} plot alg : {:10s} | validation = {:} | test = {:}".format(time_string(), alg, valid_str, test_str)
"{:} plot alg : {:10s} | validation = {:} | test = {:}".format(
time_string(), alg, valid_str, test_str
)
)
if search_space == "tss":
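
find_best_valid above selects the top architecture separately by validation and by test accuracy from (index, accuracy) pairs; the selection itself is a one-line sort:

# (arch index, validation accuracy) pairs with made-up values:
all_valid_accs = [(0, 91.2), (1, 92.5), (2, 90.8)]
best_valid_index = sorted(all_valid_accs, key=lambda x: -x[1])[0][0]
print(best_valid_index)  # 1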

View File

@@ -51,23 +51,35 @@ def evaluate_all_datasets(
train_data, valid_data, xshape, class_num = get_datasets(dataset, xpath, -1)
# load the configuration
if dataset == "cifar10" or dataset == "cifar100":
split_info = load_config("configs/nas-benchmark/cifar-split.txt", None, None)
split_info = load_config(
"configs/nas-benchmark/cifar-split.txt", None, None
)
elif dataset.startswith("ImageNet16"):
split_info = load_config("configs/nas-benchmark/{:}-split.txt".format(dataset), None, None)
split_info = load_config(
"configs/nas-benchmark/{:}-split.txt".format(dataset), None, None
)
else:
raise ValueError("invalid dataset : {:}".format(dataset))
config = load_config(config_path, dict(class_num=class_num, xshape=xshape), logger)
config = load_config(
config_path, dict(class_num=class_num, xshape=xshape), logger
)
# check whether to use the split validation set
if bool(split):
assert dataset == "cifar10"
ValLoaders = {
"ori-test": torch.utils.data.DataLoader(
valid_data, batch_size=config.batch_size, shuffle=False, num_workers=workers, pin_memory=True
valid_data,
batch_size=config.batch_size,
shuffle=False,
num_workers=workers,
pin_memory=True,
)
}
assert len(train_data) == len(split_info.train) + len(
split_info.valid
), "invalid length : {:} vs {:} + {:}".format(len(train_data), len(split_info.train), len(split_info.valid))
), "invalid length : {:} vs {:} + {:}".format(
len(train_data), len(split_info.train), len(split_info.valid)
)
train_data_v2 = deepcopy(train_data)
train_data_v2.transform = valid_data.transform
valid_data = train_data_v2
@@ -90,47 +102,67 @@ def evaluate_all_datasets(
else:
# data loader
train_loader = torch.utils.data.DataLoader(
train_data, batch_size=config.batch_size, shuffle=True, num_workers=workers, pin_memory=True
train_data,
batch_size=config.batch_size,
shuffle=True,
num_workers=workers,
pin_memory=True,
)
valid_loader = torch.utils.data.DataLoader(
valid_data, batch_size=config.batch_size, shuffle=False, num_workers=workers, pin_memory=True
valid_data,
batch_size=config.batch_size,
shuffle=False,
num_workers=workers,
pin_memory=True,
)
if dataset == "cifar10":
ValLoaders = {"ori-test": valid_loader}
elif dataset == "cifar100":
cifar100_splits = load_config("configs/nas-benchmark/cifar100-test-split.txt", None, None)
cifar100_splits = load_config(
"configs/nas-benchmark/cifar100-test-split.txt", None, None
)
ValLoaders = {
"ori-test": valid_loader,
"x-valid": torch.utils.data.DataLoader(
valid_data,
batch_size=config.batch_size,
sampler=torch.utils.data.sampler.SubsetRandomSampler(cifar100_splits.xvalid),
sampler=torch.utils.data.sampler.SubsetRandomSampler(
cifar100_splits.xvalid
),
num_workers=workers,
pin_memory=True,
),
"x-test": torch.utils.data.DataLoader(
valid_data,
batch_size=config.batch_size,
sampler=torch.utils.data.sampler.SubsetRandomSampler(cifar100_splits.xtest),
sampler=torch.utils.data.sampler.SubsetRandomSampler(
cifar100_splits.xtest
),
num_workers=workers,
pin_memory=True,
),
}
elif dataset == "ImageNet16-120":
imagenet16_splits = load_config("configs/nas-benchmark/imagenet-16-120-test-split.txt", None, None)
imagenet16_splits = load_config(
"configs/nas-benchmark/imagenet-16-120-test-split.txt", None, None
)
ValLoaders = {
"ori-test": valid_loader,
"x-valid": torch.utils.data.DataLoader(
valid_data,
batch_size=config.batch_size,
sampler=torch.utils.data.sampler.SubsetRandomSampler(imagenet16_splits.xvalid),
sampler=torch.utils.data.sampler.SubsetRandomSampler(
imagenet16_splits.xvalid
),
num_workers=workers,
pin_memory=True,
),
"x-test": torch.utils.data.DataLoader(
valid_data,
batch_size=config.batch_size,
sampler=torch.utils.data.sampler.SubsetRandomSampler(imagenet16_splits.xtest),
sampler=torch.utils.data.sampler.SubsetRandomSampler(
imagenet16_splits.xtest
),
num_workers=workers,
pin_memory=True,
),
@@ -143,19 +175,36 @@ def evaluate_all_datasets(
dataset_key = dataset_key + "-valid"
logger.log(
"Evaluate ||||||| {:10s} ||||||| Train-Num={:}, Valid-Num={:}, Train-Loader-Num={:}, Valid-Loader-Num={:}, batch size={:}".format(
dataset_key, len(train_data), len(valid_data), len(train_loader), len(valid_loader), config.batch_size
dataset_key,
len(train_data),
len(valid_data),
len(train_loader),
len(valid_loader),
config.batch_size,
)
)
logger.log("Evaluate ||||||| {:10s} ||||||| Config={:}".format(dataset_key, config))
logger.log(
"Evaluate ||||||| {:10s} ||||||| Config={:}".format(dataset_key, config)
)
for key, value in ValLoaders.items():
logger.log("Evaluate ---->>>> {:10s} with {:} batchs".format(key, len(value)))
logger.log(
"Evaluate ---->>>> {:10s} with {:} batchs".format(key, len(value))
)
# arch-index= 9930, arch=|nor_conv_3x3~0|+|nor_conv_3x3~0|nor_conv_3x3~1|+|skip_connect~0|nor_conv_3x3~1|nor_conv_3x3~2|
# this genotype is the architecture with the highest accuracy on CIFAR-100 validation set
genotype = "|nor_conv_3x3~0|+|nor_conv_3x3~0|nor_conv_3x3~1|+|skip_connect~0|nor_conv_3x3~1|nor_conv_3x3~2|"
arch_config = dict2config(
dict(name="infer.shape.tiny", channels=channels, genotype=genotype, num_classes=class_num), None
dict(
name="infer.shape.tiny",
channels=channels,
genotype=genotype,
num_classes=class_num,
),
None,
)
results = bench_evaluate_for_seed(
arch_config, config, train_loader, ValLoaders, seed, logger
)
results = bench_evaluate_for_seed(arch_config, config, train_loader, ValLoaders, seed, logger)
all_infos[dataset_key] = results
all_dataset_keys.append(dataset_key)
all_infos["all_dataset_keys"] = all_dataset_keys
@@ -183,8 +232,12 @@ def main(
logger.log("xargs : cover_mode = {:}".format(cover_mode))
logger.log("-" * 100)
logger.log(
"Start evaluating range =: {:06d} - {:06d}".format(min(to_evaluate_indexes), max(to_evaluate_indexes))
+ "({:} in total) / {:06d} with cover-mode={:}".format(len(to_evaluate_indexes), len(nets), cover_mode)
"Start evaluating range =: {:06d} - {:06d}".format(
min(to_evaluate_indexes), max(to_evaluate_indexes)
)
+ "({:} in total) / {:06d} with cover-mode={:}".format(
len(to_evaluate_indexes), len(nets), cover_mode
)
)
for i, (dataset, xpath, split) in enumerate(zip(datasets, xpaths, splits)):
logger.log(
@@ -199,7 +252,13 @@ def main(
channelstr = nets[index]
logger.log(
"\n{:} evaluate {:06d}/{:06d} ({:06d}/{:06d})-th arch [seeds={:}] {:}".format(
time_string(), i, len(to_evaluate_indexes), index, len(nets), seeds, "-" * 15
time_string(),
i,
len(to_evaluate_indexes),
index,
len(nets),
seeds,
"-" * 15,
)
)
logger.log("{:} {:} {:}".format("-" * 15, channelstr, "-" * 15))
@@ -210,17 +269,33 @@ def main(
to_save_name = save_dir / "arch-{:06d}-seed-{:04d}.pth".format(index, seed)
if to_save_name.exists():
if cover_mode:
logger.log("Find existing file : {:}, remove it before evaluation".format(to_save_name))
logger.log(
"Find existing file : {:}, remove it before evaluation".format(
to_save_name
)
)
os.remove(str(to_save_name))
else:
logger.log("Find existing file : {:}, skip this evaluation".format(to_save_name))
logger.log(
"Find existing file : {:}, skip this evaluation".format(
to_save_name
)
)
has_continue = True
continue
results = evaluate_all_datasets(channelstr, datasets, xpaths, splits, opt_config, seed, workers, logger)
results = evaluate_all_datasets(
channelstr, datasets, xpaths, splits, opt_config, seed, workers, logger
)
torch.save(results, to_save_name)
logger.log(
"\n{:} evaluate {:06d}/{:06d} ({:06d}/{:06d})-th arch [seeds={:}] ===>>> {:}".format(
time_string(), i, len(to_evaluate_indexes), index, len(nets), seeds, to_save_name
time_string(),
i,
len(to_evaluate_indexes),
index,
len(nets),
seeds,
to_save_name,
)
)
# measure elapsed time
@@ -230,7 +305,9 @@ def main(
need_time = "Time Left: {:}".format(
convert_secs2time(epoch_time.avg * (len(to_evaluate_indexes) - i - 1), True)
)
logger.log("This arch costs : {:}".format(convert_secs2time(epoch_time.val, True)))
logger.log(
"This arch costs : {:}".format(convert_secs2time(epoch_time.val, True))
)
logger.log("{:}".format("*" * 100))
logger.log(
"{:} {:74s} {:}".format(
@@ -277,16 +354,24 @@ def filter_indexes(xlist, mode, save_dir, seeds):
SLURM_PROCID, SLURM_NTASKS = "SLURM_PROCID", "SLURM_NTASKS"
if SLURM_PROCID in os.environ and SLURM_NTASKS in os.environ:  # running under SLURM
proc_id, ntasks = int(os.environ[SLURM_PROCID]), int(os.environ[SLURM_NTASKS])
assert 0 <= proc_id < ntasks, "invalid proc_id {:} vs ntasks {:}".format(proc_id, ntasks)
scales = [int(float(i) / ntasks * len(all_indexes)) for i in range(ntasks)] + [len(all_indexes)]
assert 0 <= proc_id < ntasks, "invalid proc_id {:} vs ntasks {:}".format(
proc_id, ntasks
)
scales = [int(float(i) / ntasks * len(all_indexes)) for i in range(ntasks)] + [
len(all_indexes)
]
per_job = []
for i in range(ntasks):
xs, xe = min(max(scales[i], 0), len(all_indexes) - 1), min(max(scales[i + 1] - 1, 0), len(all_indexes) - 1)
xs, xe = min(max(scales[i], 0), len(all_indexes) - 1), min(
max(scales[i + 1] - 1, 0), len(all_indexes) - 1
)
per_job.append((xs, xe))
for i, srange in enumerate(per_job):
print(" -->> {:2d}/{:02d} : {:}".format(i, ntasks, srange))
current_range = per_job[proc_id]
all_indexes = [all_indexes[i] for i in range(current_range[0], current_range[1] + 1)]
all_indexes = [
all_indexes[i] for i in range(current_range[0], current_range[1] + 1)
]
# set the device id
device = proc_id % torch.cuda.device_count()
torch.cuda.set_device(device)
@@ -301,30 +386,67 @@ def filter_indexes(xlist, mode, save_dir, seeds):
if __name__ == "__main__":
parser = argparse.ArgumentParser(
description="NATS-Bench (size search space)", formatter_class=argparse.ArgumentDefaultsHelpFormatter
description="NATS-Bench (size search space)",
formatter_class=argparse.ArgumentDefaultsHelpFormatter,
)
parser.add_argument("--mode", type=str, required=True, choices=["new", "cover"], help="The script mode.")
parser.add_argument(
"--save_dir", type=str, default="output/NATS-Bench-size", help="Folder to save checkpoints and log."
"--mode",
type=str,
required=True,
choices=["new", "cover"],
help="The script mode.",
)
parser.add_argument(
"--save_dir",
type=str,
default="output/NATS-Bench-size",
help="Folder to save checkpoints and log.",
)
parser.add_argument(
"--candidateC",
type=int,
nargs="+",
default=[8, 16, 24, 32, 40, 48, 56, 64],
help=".",
)
parser.add_argument(
"--num_layers", type=int, default=5, help="The number of layers in a network."
)
parser.add_argument("--candidateC", type=int, nargs="+", default=[8, 16, 24, 32, 40, 48, 56, 64], help=".")
parser.add_argument("--num_layers", type=int, default=5, help="The number of layers in a network.")
parser.add_argument("--check_N", type=int, default=32768, help="For safety.")
# used for training the model
parser.add_argument("--workers", type=int, default=8, help="The number of data loading workers (default: 8)")
parser.add_argument("--srange", type=str, required=True, help="The range of models to be evaluated")
parser.add_argument("--datasets", type=str, nargs="+", help="The applied datasets.")
parser.add_argument("--xpaths", type=str, nargs="+", help="The root path for this dataset.")
parser.add_argument("--splits", type=int, nargs="+", help="The root path for this dataset.")
parser.add_argument(
"--hyper", type=str, default="12", choices=["01", "12", "90"], help="The tag for hyper-parameters."
"--workers",
type=int,
default=8,
help="The number of data loading workers (default: 2)",
)
parser.add_argument(
"--srange", type=str, required=True, help="The range of models to be evaluated"
)
parser.add_argument("--datasets", type=str, nargs="+", help="The applied datasets.")
parser.add_argument(
"--xpaths", type=str, nargs="+", help="The root path for this dataset."
)
parser.add_argument(
"--splits", type=int, nargs="+", help="The root path for this dataset."
)
parser.add_argument(
"--hyper",
type=str,
default="12",
choices=["01", "12", "90"],
help="The tag for hyper-parameters.",
)
parser.add_argument(
"--seeds", type=int, nargs="+", help="The range of models to be evaluated"
)
parser.add_argument("--seeds", type=int, nargs="+", help="The range of models to be evaluated")
args = parser.parse_args()
nets = traverse_net(args.candidateC, args.num_layers)
if len(nets) != args.check_N:
raise ValueError("Pre-num-check failed : {:} vs {:}".format(len(nets), args.check_N))
raise ValueError(
"Pre-num-check failed : {:} vs {:}".format(len(nets), args.check_N)
)
opt_config = "./configs/nas-benchmark/hyper-opts/{:}E.config".format(args.hyper)
if not os.path.isfile(opt_config):
@@ -337,12 +459,16 @@ if __name__ == "__main__":
raise ValueError("invalid length of seeds args: {:}".format(args.seeds))
if not (len(args.datasets) == len(args.xpaths) == len(args.splits)):
raise ValueError(
"invalid infos : {:} vs {:} vs {:}".format(len(args.datasets), len(args.xpaths), len(args.splits))
"invalid infos : {:} vs {:} vs {:}".format(
len(args.datasets), len(args.xpaths), len(args.splits)
)
)
if args.workers <= 0:
raise ValueError("invalid number of workers : {:}".format(args.workers))
target_indexes = filter_indexes(to_evaluate_indexes, args.mode, save_dir, args.seeds)
target_indexes = filter_indexes(
to_evaluate_indexes, args.mode, save_dir, args.seeds
)
assert torch.cuda.is_available(), "CUDA is not available."
torch.backends.cudnn.enabled = True
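
The SLURM branch above shards the architecture indexes into one contiguous (start, end) range per task; extracted as a standalone helper (name hypothetical), it behaves like this:

def split_ranges(num_items, ntasks):
    # Proportional split: task i gets indexes scales[i] .. scales[i + 1] - 1.
    scales = [int(float(i) / ntasks * num_items) for i in range(ntasks)] + [num_items]
    per_job = []
    for i in range(ntasks):
        xs = min(max(scales[i], 0), num_items - 1)
        xe = min(max(scales[i + 1] - 1, 0), num_items - 1)
        per_job.append((xs, xe))
    return per_job

print(split_ranges(10, 3))  # [(0, 2), (3, 5), (6, 9)]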

View File

@@ -57,23 +57,35 @@ def evaluate_all_datasets(
train_data, valid_data, xshape, class_num = get_datasets(dataset, xpath, -1)
# load the configuration
if dataset == "cifar10" or dataset == "cifar100":
split_info = load_config("configs/nas-benchmark/cifar-split.txt", None, None)
split_info = load_config(
"configs/nas-benchmark/cifar-split.txt", None, None
)
elif dataset.startswith("ImageNet16"):
split_info = load_config("configs/nas-benchmark/{:}-split.txt".format(dataset), None, None)
split_info = load_config(
"configs/nas-benchmark/{:}-split.txt".format(dataset), None, None
)
else:
raise ValueError("invalid dataset : {:}".format(dataset))
config = load_config(config_path, dict(class_num=class_num, xshape=xshape), logger)
config = load_config(
config_path, dict(class_num=class_num, xshape=xshape), logger
)
# check whether to use the split validation set
if bool(split):
assert dataset == "cifar10"
ValLoaders = {
"ori-test": torch.utils.data.DataLoader(
valid_data, batch_size=config.batch_size, shuffle=False, num_workers=workers, pin_memory=True
valid_data,
batch_size=config.batch_size,
shuffle=False,
num_workers=workers,
pin_memory=True,
)
}
assert len(train_data) == len(split_info.train) + len(
split_info.valid
), "invalid length : {:} vs {:} + {:}".format(len(train_data), len(split_info.train), len(split_info.valid))
), "invalid length : {:} vs {:} + {:}".format(
len(train_data), len(split_info.train), len(split_info.valid)
)
train_data_v2 = deepcopy(train_data)
train_data_v2.transform = valid_data.transform
valid_data = train_data_v2
@@ -96,47 +108,67 @@ def evaluate_all_datasets(
else:
# data loader
train_loader = torch.utils.data.DataLoader(
train_data, batch_size=config.batch_size, shuffle=True, num_workers=workers, pin_memory=True
train_data,
batch_size=config.batch_size,
shuffle=True,
num_workers=workers,
pin_memory=True,
)
valid_loader = torch.utils.data.DataLoader(
valid_data, batch_size=config.batch_size, shuffle=False, num_workers=workers, pin_memory=True
valid_data,
batch_size=config.batch_size,
shuffle=False,
num_workers=workers,
pin_memory=True,
)
if dataset == "cifar10":
ValLoaders = {"ori-test": valid_loader}
elif dataset == "cifar100":
cifar100_splits = load_config("configs/nas-benchmark/cifar100-test-split.txt", None, None)
cifar100_splits = load_config(
"configs/nas-benchmark/cifar100-test-split.txt", None, None
)
ValLoaders = {
"ori-test": valid_loader,
"x-valid": torch.utils.data.DataLoader(
valid_data,
batch_size=config.batch_size,
sampler=torch.utils.data.sampler.SubsetRandomSampler(cifar100_splits.xvalid),
sampler=torch.utils.data.sampler.SubsetRandomSampler(
cifar100_splits.xvalid
),
num_workers=workers,
pin_memory=True,
),
"x-test": torch.utils.data.DataLoader(
valid_data,
batch_size=config.batch_size,
sampler=torch.utils.data.sampler.SubsetRandomSampler(cifar100_splits.xtest),
sampler=torch.utils.data.sampler.SubsetRandomSampler(
cifar100_splits.xtest
),
num_workers=workers,
pin_memory=True,
),
}
elif dataset == "ImageNet16-120":
imagenet16_splits = load_config("configs/nas-benchmark/imagenet-16-120-test-split.txt", None, None)
imagenet16_splits = load_config(
"configs/nas-benchmark/imagenet-16-120-test-split.txt", None, None
)
ValLoaders = {
"ori-test": valid_loader,
"x-valid": torch.utils.data.DataLoader(
valid_data,
batch_size=config.batch_size,
sampler=torch.utils.data.sampler.SubsetRandomSampler(imagenet16_splits.xvalid),
sampler=torch.utils.data.sampler.SubsetRandomSampler(
imagenet16_splits.xvalid
),
num_workers=workers,
pin_memory=True,
),
"x-test": torch.utils.data.DataLoader(
valid_data,
batch_size=config.batch_size,
sampler=torch.utils.data.sampler.SubsetRandomSampler(imagenet16_splits.xtest),
sampler=torch.utils.data.sampler.SubsetRandomSampler(
imagenet16_splits.xtest
),
num_workers=workers,
pin_memory=True,
),
@@ -149,12 +181,21 @@ def evaluate_all_datasets(
dataset_key = dataset_key + "-valid"
logger.log(
"Evaluate ||||||| {:10s} ||||||| Train-Num={:}, Valid-Num={:}, Train-Loader-Num={:}, Valid-Loader-Num={:}, batch size={:}".format(
dataset_key, len(train_data), len(valid_data), len(train_loader), len(valid_loader), config.batch_size
dataset_key,
len(train_data),
len(valid_data),
len(train_loader),
len(valid_loader),
config.batch_size,
)
)
logger.log("Evaluate ||||||| {:10s} ||||||| Config={:}".format(dataset_key, config))
logger.log(
"Evaluate ||||||| {:10s} ||||||| Config={:}".format(dataset_key, config)
)
for key, value in ValLoaders.items():
logger.log("Evaluate ---->>>> {:10s} with {:} batchs".format(key, len(value)))
logger.log(
"Evaluate ---->>>> {:10s} with {:} batchs".format(key, len(value))
)
arch_config = dict2config(
dict(
name="infer.tiny",
@@ -165,7 +206,9 @@ def evaluate_all_datasets(
),
None,
)
results = bench_evaluate_for_seed(arch_config, config, train_loader, ValLoaders, seed, logger)
results = bench_evaluate_for_seed(
arch_config, config, train_loader, ValLoaders, seed, logger
)
all_infos[dataset_key] = results
all_dataset_keys.append(dataset_key)
all_infos["all_dataset_keys"] = all_dataset_keys
@@ -194,8 +237,12 @@ def main(
logger.log("xargs : cover_mode = {:}".format(cover_mode))
logger.log("-" * 100)
logger.log(
"Start evaluating range =: {:06d} - {:06d}".format(min(to_evaluate_indexes), max(to_evaluate_indexes))
+ "({:} in total) / {:06d} with cover-mode={:}".format(len(to_evaluate_indexes), len(nets), cover_mode)
"Start evaluating range =: {:06d} - {:06d}".format(
min(to_evaluate_indexes), max(to_evaluate_indexes)
)
+ "({:} in total) / {:06d} with cover-mode={:}".format(
len(to_evaluate_indexes), len(nets), cover_mode
)
)
for i, (dataset, xpath, split) in enumerate(zip(datasets, xpaths, splits)):
logger.log(
@@ -210,7 +257,13 @@ def main(
arch = nets[index]
logger.log(
"\n{:} evaluate {:06d}/{:06d} ({:06d}/{:06d})-th arch [seeds={:}] {:}".format(
time_string(), i, len(to_evaluate_indexes), index, len(nets), seeds, "-" * 15
time_string(),
i,
len(to_evaluate_indexes),
index,
len(nets),
seeds,
"-" * 15,
)
)
logger.log("{:} {:} {:}".format("-" * 15, arch, "-" * 15))
@@ -221,10 +274,18 @@ def main(
to_save_name = save_dir / "arch-{:06d}-seed-{:04d}.pth".format(index, seed)
if to_save_name.exists():
if cover_mode:
logger.log("Find existing file : {:}, remove it before evaluation".format(to_save_name))
logger.log(
"Find existing file : {:}, remove it before evaluation".format(
to_save_name
)
)
os.remove(str(to_save_name))
else:
logger.log("Find existing file : {:}, skip this evaluation".format(to_save_name))
logger.log(
"Find existing file : {:}, skip this evaluation".format(
to_save_name
)
)
has_continue = True
continue
results = evaluate_all_datasets(
@@ -241,7 +302,13 @@ def main(
torch.save(results, to_save_name)
logger.log(
"\n{:} evaluate {:06d}/{:06d} ({:06d}/{:06d})-th arch [seeds={:}] ===>>> {:}".format(
time_string(), i, len(to_evaluate_indexes), index, len(nets), seeds, to_save_name
time_string(),
i,
len(to_evaluate_indexes),
index,
len(nets),
seeds,
to_save_name,
)
)
# measure elapsed time
@@ -251,7 +318,9 @@ def main(
need_time = "Time Left: {:}".format(
convert_secs2time(epoch_time.avg * (len(to_evaluate_indexes) - i - 1), True)
)
logger.log("This arch costs : {:}".format(convert_secs2time(epoch_time.val, True)))
logger.log(
"This arch costs : {:}".format(convert_secs2time(epoch_time.val, True))
)
logger.log("{:}".format("*" * 100))
logger.log(
"{:} {:74s} {:}".format(
@@ -267,7 +336,9 @@ def main(
logger.close()
def train_single_model(save_dir, workers, datasets, xpaths, splits, use_less, seeds, model_str, arch_config):
def train_single_model(
save_dir, workers, datasets, xpaths, splits, use_less, seeds, model_str, arch_config
):
assert torch.cuda.is_available(), "CUDA is not available."
torch.backends.cudnn.enabled = True
torch.backends.cudnn.deterministic = True
@@ -278,19 +349,32 @@ def train_single_model(save_dir, workers, datasets, xpaths, splits, use_less, se
Path(save_dir)
/ "specifics"
/ "{:}-{:}-{:}-{:}".format(
"LESS" if use_less else "FULL", model_str, arch_config["channel"], arch_config["num_cells"]
"LESS" if use_less else "FULL",
model_str,
arch_config["channel"],
arch_config["num_cells"],
)
)
logger = Logger(str(save_dir), 0, False)
if model_str in CellArchitectures:
arch = CellArchitectures[model_str]
logger.log("The model string is found in pre-defined architecture dict : {:}".format(model_str))
logger.log(
"The model string is found in pre-defined architecture dict : {:}".format(
model_str
)
)
else:
try:
arch = CellStructure.str2structure(model_str)
except Exception:
raise ValueError("Invalid model string : {:}. It cannot be found or parsed.".format(model_str))
assert arch.check_valid_op(get_search_spaces("cell", "full")), "{:} has an invalid op.".format(arch)
raise ValueError(
"Invalid model string : {:}. It cannot be found or parsed.".format(
model_str
)
)
assert arch.check_valid_op(
get_search_spaces("cell", "full")
), "{:} has an invalid op.".format(arch)
logger.log("Start train-evaluate {:}".format(arch.tostr()))
logger.log("arch_config : {:}".format(arch_config))
@@ -303,27 +387,55 @@ def train_single_model(save_dir, workers, datasets, xpaths, splits, use_less, se
)
to_save_name = save_dir / "seed-{:04d}.pth".format(seed)
if to_save_name.exists():
logger.log("Find the existing file {:}, directly load!".format(to_save_name))
logger.log(
"Find the existing file {:}, directly load!".format(to_save_name)
)
checkpoint = torch.load(to_save_name)
else:
logger.log("Does not find the existing file {:}, train and evaluate!".format(to_save_name))
logger.log(
"Does not find the existing file {:}, train and evaluate!".format(
to_save_name
)
)
checkpoint = evaluate_all_datasets(
arch, datasets, xpaths, splits, use_less, seed, arch_config, workers, logger
arch,
datasets,
xpaths,
splits,
use_less,
seed,
arch_config,
workers,
logger,
)
torch.save(checkpoint, to_save_name)
# log information
logger.log("{:}".format(checkpoint["info"]))
all_dataset_keys = checkpoint["all_dataset_keys"]
for dataset_key in all_dataset_keys:
logger.log("\n{:} dataset : {:} {:}".format("-" * 15, dataset_key, "-" * 15))
logger.log(
"\n{:} dataset : {:} {:}".format("-" * 15, dataset_key, "-" * 15)
)
dataset_info = checkpoint[dataset_key]
# logger.log('Network ==>\n{:}'.format( dataset_info['net_string'] ))
logger.log("Flops = {:} MB, Params = {:} MB".format(dataset_info["flop"], dataset_info["param"]))
logger.log(
"Flops = {:} MB, Params = {:} MB".format(
dataset_info["flop"], dataset_info["param"]
)
)
logger.log("config : {:}".format(dataset_info["config"]))
logger.log("Training State (finish) = {:}".format(dataset_info["finish-train"]))
logger.log(
"Training State (finish) = {:}".format(dataset_info["finish-train"])
)
last_epoch = dataset_info["total_epoch"] - 1
train_acc1es, train_acc5es = dataset_info["train_acc1es"], dataset_info["train_acc5es"]
valid_acc1es, valid_acc5es = dataset_info["valid_acc1es"], dataset_info["valid_acc5es"]
train_acc1es, train_acc5es = (
dataset_info["train_acc1es"],
dataset_info["train_acc5es"],
)
valid_acc1es, valid_acc5es = (
dataset_info["valid_acc1es"],
dataset_info["valid_acc5es"],
)
logger.log(
"Last Info : Train = Acc@1 {:.2f}% Acc@5 {:.2f}% Error@1 {:.2f}%, Test = Acc@1 {:.2f}% Acc@5 {:.2f}% Error@1 {:.2f}%".format(
train_acc1es[last_epoch],
@@ -337,7 +449,9 @@ def train_single_model(save_dir, workers, datasets, xpaths, splits, use_less, se
# measure elapsed time
seed_time.update(time.time() - start_time)
start_time = time.time()
need_time = "Time Left: {:}".format(convert_secs2time(seed_time.avg * (len(seeds) - _is - 1), True))
need_time = "Time Left: {:}".format(
convert_secs2time(seed_time.avg * (len(seeds) - _is - 1), True)
)
logger.log(
"\n<<<***>>> The {:02d}/{:02d}-th seed is {:} <finish> other procedures need {:}".format(
_is, len(seeds), seed, need_time
@@ -349,7 +463,11 @@ def train_single_model(save_dir, workers, datasets, xpaths, splits, use_less, se
def generate_meta_info(save_dir, max_node, divide=40):
aa_nas_bench_ss = get_search_spaces("cell", "nas-bench-201")
archs = CellStructure.gen_all(aa_nas_bench_ss, max_node, False)
print("There are {:} archs vs {:}.".format(len(archs), len(aa_nas_bench_ss) ** ((max_node - 1) * max_node / 2)))
print(
"There are {:} archs vs {:}.".format(
len(archs), len(aa_nas_bench_ss) ** ((max_node - 1) * max_node / 2)
)
)
random.seed(88) # please do not change this line for reproducibility
random.shuffle(archs)
@@ -361,10 +479,12 @@ def generate_meta_info(save_dir, max_node, divide=40):
== "|avg_pool_3x3~0|+|nor_conv_1x1~0|skip_connect~1|+|nor_conv_1x1~0|skip_connect~1|skip_connect~2|"
), "please check the 0-th architecture : {:}".format(archs[0])
assert (
archs[9].tostr() == "|avg_pool_3x3~0|+|none~0|none~1|+|skip_connect~0|none~1|nor_conv_3x3~2|"
archs[9].tostr()
== "|avg_pool_3x3~0|+|none~0|none~1|+|skip_connect~0|none~1|nor_conv_3x3~2|"
), "please check the 9-th architecture : {:}".format(archs[9])
assert (
archs[123].tostr() == "|avg_pool_3x3~0|+|avg_pool_3x3~0|nor_conv_1x1~1|+|none~0|avg_pool_3x3~1|nor_conv_3x3~2|"
archs[123].tostr()
== "|avg_pool_3x3~0|+|avg_pool_3x3~0|nor_conv_1x1~1|+|none~0|avg_pool_3x3~1|nor_conv_3x3~2|"
), "please check the 123-th architecture : {:}".format(archs[123])
total_arch = len(archs)
@@ -383,11 +503,21 @@ def generate_meta_info(save_dir, max_node, divide=40):
and valid_split[10] == 18
and valid_split[111] == 242
), "{:} {:} {:} - {:} {:} {:}".format(
train_split[0], train_split[10], train_split[111], valid_split[0], valid_split[10], valid_split[111]
train_split[0],
train_split[10],
train_split[111],
valid_split[0],
valid_split[10],
valid_split[111],
)
splits = {num: {"train": train_split, "valid": valid_split}}
info = {"archs": [x.tostr() for x in archs], "total": total_arch, "max_node": max_node, "splits": splits}
info = {
"archs": [x.tostr() for x in archs],
"total": total_arch,
"max_node": max_node,
"splits": splits,
}
save_dir = Path(save_dir)
save_dir.mkdir(parents=True, exist_ok=True)
@@ -400,7 +530,11 @@ def generate_meta_info(save_dir, max_node, divide=40):
def traverse_net(max_node):
aa_nas_bench_ss = get_search_spaces("cell", "nats-bench")
archs = CellStructure.gen_all(aa_nas_bench_ss, max_node, False)
print("There are {:} archs vs {:}.".format(len(archs), len(aa_nas_bench_ss) ** ((max_node - 1) * max_node / 2)))
print(
"There are {:} archs vs {:}.".format(
len(archs), len(aa_nas_bench_ss) ** ((max_node - 1) * max_node / 2)
)
)
random.seed(88) # please do not change this line for reproducibility
random.shuffle(archs)
@@ -409,10 +543,12 @@ def traverse_net(max_node):
== "|avg_pool_3x3~0|+|nor_conv_1x1~0|skip_connect~1|+|nor_conv_1x1~0|skip_connect~1|skip_connect~2|"
), "please check the 0-th architecture : {:}".format(archs[0])
assert (
archs[9].tostr() == "|avg_pool_3x3~0|+|none~0|none~1|+|skip_connect~0|none~1|nor_conv_3x3~2|"
archs[9].tostr()
== "|avg_pool_3x3~0|+|none~0|none~1|+|skip_connect~0|none~1|nor_conv_3x3~2|"
), "please check the 9-th architecture : {:}".format(archs[9])
assert (
archs[123].tostr() == "|avg_pool_3x3~0|+|avg_pool_3x3~0|nor_conv_1x1~1|+|none~0|avg_pool_3x3~1|nor_conv_3x3~2|"
archs[123].tostr()
== "|avg_pool_3x3~0|+|avg_pool_3x3~0|nor_conv_1x1~1|+|none~0|avg_pool_3x3~1|nor_conv_3x3~2|"
), "please check the 123-th architecture : {:}".format(archs[123])
return [x.tostr() for x in archs]
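
The assertions above pin down the `|op~input|` cell encoding: nodes joined by "+", edges by "|", each edge an operation plus the index of its source node. As a reading aid, a small parser for that format (it mirrors what CellStructure.str2structure must do; the helper name is ours):

    def parse_cell_string(arch_str):
        """'|avg_pool_3x3~0|+|none~0|none~1|+|...' -> [[(op, source_node), ...], ...]"""
        nodes = []
        for node_str in arch_str.split("+"):
            edges = []
            for token in node_str.split("|"):
                if token.strip():
                    op, source = token.split("~")
                    edges.append((op, int(source)))
            nodes.append(edges)
        return nodes

    # e.g. the 9-th architecture checked above:
    print(parse_cell_string("|avg_pool_3x3~0|+|none~0|none~1|+|skip_connect~0|none~1|nor_conv_3x3~2|"))
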
@@ -439,32 +575,62 @@ def filter_indexes(xlist, mode, save_dir, seeds):
if __name__ == "__main__":
# mode_choices = ['meta', 'new', 'cover'] + ['specific-{:}'.format(_) for _ in CellArchitectures.keys()]
parser = argparse.ArgumentParser(
description="NATS-Bench (topology search space)", formatter_class=argparse.ArgumentDefaultsHelpFormatter
description="NATS-Bench (topology search space)",
formatter_class=argparse.ArgumentDefaultsHelpFormatter,
)
parser.add_argument("--mode", type=str, required=True, help="The script mode.")
parser.add_argument(
"--save_dir", type=str, default="output/NATS-Bench-topology", help="Folder to save checkpoints and log."
"--save_dir",
type=str,
default="output/NATS-Bench-topology",
help="Folder to save checkpoints and log.",
)
parser.add_argument("--max_node", type=int, default=4, help="The maximum node in a cell (please do not change it).")
# used for training the model
parser.add_argument("--workers", type=int, default=8, help="number of data loading workers (default: 8)")
parser.add_argument("--srange", type=str, required=True, help="The range of models to be evaluated")
parser.add_argument("--datasets", type=str, nargs="+", help="The applied datasets.")
parser.add_argument("--xpaths", type=str, nargs="+", help="The root path for this dataset.")
parser.add_argument("--splits", type=int, nargs="+", help="The root path for this dataset.")
parser.add_argument(
"--hyper", type=str, default="12", choices=["01", "12", "200"], help="The tag for hyper-parameters."
"--max_node",
type=int,
default=4,
help="The maximum node in a cell (please do not change it).",
)
# used for training the model
parser.add_argument(
"--workers",
type=int,
default=8,
help="number of data loading workers (default: 8)",
)
parser.add_argument(
"--srange", type=str, required=True, help="The range of models to be evaluated"
)
parser.add_argument("--datasets", type=str, nargs="+", help="The applied datasets.")
parser.add_argument(
"--xpaths", type=str, nargs="+", help="The root path for this dataset."
)
parser.add_argument(
"--splits", type=int, nargs="+", help="The root path for this dataset."
)
parser.add_argument(
"--hyper",
type=str,
default="12",
choices=["01", "12", "200"],
help="The tag for hyper-parameters.",
)
parser.add_argument("--seeds", type=int, nargs="+", help="The range of models to be evaluated")
parser.add_argument("--channel", type=int, default=16, help="The number of channels.")
parser.add_argument("--num_cells", type=int, default=5, help="The number of cells in one stage.")
parser.add_argument(
"--seeds", type=int, nargs="+", help="The range of models to be evaluated"
)
parser.add_argument(
"--channel", type=int, default=16, help="The number of channels."
)
parser.add_argument(
"--num_cells", type=int, default=5, help="The number of cells in one stage."
)
parser.add_argument("--check_N", type=int, default=15625, help="For safety.")
args = parser.parse_args()
assert args.mode in ["meta", "new", "cover"] or args.mode.startswith("specific-"), "invalid mode : {:}".format(
args.mode
)
assert args.mode in ["meta", "new", "cover"] or args.mode.startswith(
"specific-"
), "invalid mode : {:}".format(args.mode)
if args.mode == "meta":
generate_meta_info(args.save_dir, args.max_node)
@@ -485,7 +651,9 @@ if __name__ == "__main__":
else:
nets = traverse_net(args.max_node)
if len(nets) != args.check_N:
raise ValueError("Pre-num-check failed : {:} vs {:}".format(len(nets), args.check_N))
raise ValueError(
"Pre-num-check failed : {:} vs {:}".format(len(nets), args.check_N)
)
opt_config = "./configs/nas-benchmark/hyper-opts/{:}E.config".format(args.hyper)
if not os.path.isfile(opt_config):
raise ValueError("{:} is not a file.".format(opt_config))
@@ -496,12 +664,16 @@ if __name__ == "__main__":
raise ValueError("invalid length of seeds args: {:}".format(args.seeds))
if not (len(args.datasets) == len(args.xpaths) == len(args.splits)):
raise ValueError(
"invalid infos : {:} vs {:} vs {:}".format(len(args.datasets), len(args.xpaths), len(args.splits))
"invalid infos : {:} vs {:} vs {:}".format(
len(args.datasets), len(args.xpaths), len(args.splits)
)
)
if args.workers <= 0:
raise ValueError("invalid number of workers : {:}".format(args.workers))
target_indexes = filter_indexes(to_evaluate_indexes, args.mode, save_dir, args.seeds)
target_indexes = filter_indexes(
to_evaluate_indexes, args.mode, save_dir, args.seeds
)
assert torch.cuda.is_available(), "CUDA is not available."
torch.backends.cudnn.enabled = True
@@ -519,5 +691,9 @@ if __name__ == "__main__":
opt_config,
target_indexes,
args.mode == "cover",
{"name": "infer.tiny", "channel": args.channel, "num_cells": args.num_cells},
{
"name": "infer.tiny",
"channel": args.channel,
"num_cells": args.num_cells,
},
)
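
One flag worth a note: --srange selects which architectures this worker evaluates, and its parsing sits outside this hunk. Under the usual inclusive start-end convention, a hypothetical parser (the helper name is ours, not the script's) would be:

    def parse_srange(srange):
        # hypothetical sketch: "0-390" -> [0, 1, ..., 390]
        start, end = srange.split("-")
        return list(range(int(start), int(end) + 1))
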

View File

@@ -31,24 +31,34 @@ from utils import get_md5_file
NATS_SSS_BASE_NAME = "NATS-sss-v1_0" # 2020.08.28
def account_one_arch(arch_index: int, arch_str: Text, checkpoints: List[Text], datasets: List[Text]) -> ArchResults:
def account_one_arch(
arch_index: int, arch_str: Text, checkpoints: List[Text], datasets: List[Text]
) -> ArchResults:
information = ArchResults(arch_index, arch_str)
for checkpoint_path in checkpoints:
try:
checkpoint = torch.load(checkpoint_path, map_location="cpu")
except Exception:
raise ValueError("This checkpoint failed to be loaded : {:}".format(checkpoint_path))
raise ValueError(
"This checkpoint failed to be loaded : {:}".format(checkpoint_path)
)
used_seed = checkpoint_path.name.split("-")[-1].split(".")[0]
ok_dataset = 0
for dataset in datasets:
if dataset not in checkpoint:
print("Can not find {:} in arch-{:} from {:}".format(dataset, arch_index, checkpoint_path))
print(
"Can not find {:} in arch-{:} from {:}".format(
dataset, arch_index, checkpoint_path
)
)
continue
else:
ok_dataset += 1
results = checkpoint[dataset]
assert results["finish-train"], "This {:} arch seed={:} does not finish train on {:} ::: {:}".format(
assert results[
"finish-train"
], "This {:} arch seed={:} does not finish train on {:} ::: {:}".format(
arch_index, used_seed, dataset, checkpoint_path
)
arch_config = {
@@ -71,13 +81,20 @@ def account_one_arch(arch_index: int, arch_str: Text, checkpoints: List[Text], d
None,
)
xresult.update_train_info(
results["train_acc1es"], results["train_acc5es"], results["train_losses"], results["train_times"]
results["train_acc1es"],
results["train_acc5es"],
results["train_losses"],
results["train_times"],
)
xresult.update_eval(
results["valid_acc1es"], results["valid_losses"], results["valid_times"]
)
xresult.update_eval(results["valid_acc1es"], results["valid_losses"], results["valid_times"])
information.update(dataset, int(used_seed), xresult)
if ok_dataset < len(datasets):
raise ValueError(
"{:} does find enought data : {:} vs {:}".format(checkpoint_path, ok_dataset, len(datasets))
"{:} does find enought data : {:} vs {:}".format(
checkpoint_path, ok_dataset, len(datasets)
)
)
return information
@@ -107,7 +124,9 @@ def correct_time_related_info(hp2info: Dict[Text, ArchResults]):
arch_info.reset_latency("ImageNet16-120", None, image_latency)
# CIFAR10 VALID
train_per_epoch_time = list(hp2info["01"].query("cifar10-valid", 777).train_times.values())
train_per_epoch_time = list(
hp2info["01"].query("cifar10-valid", 777).train_times.values()
)
train_per_epoch_time = sum(train_per_epoch_time) / len(train_per_epoch_time)
eval_ori_test_time, eval_x_valid_time = [], []
for key, value in hp2info["01"].query("cifar10-valid", 777).eval_times.items():
@@ -121,11 +140,17 @@ def correct_time_related_info(hp2info: Dict[Text, ArchResults]):
eval_x_valid_time = sum(eval_x_valid_time) / len(eval_x_valid_time)
for hp, arch_info in hp2info.items():
arch_info.reset_pseudo_train_times("cifar10-valid", None, train_per_epoch_time)
arch_info.reset_pseudo_eval_times("cifar10-valid", None, "x-valid", eval_x_valid_time)
arch_info.reset_pseudo_eval_times("cifar10-valid", None, "ori-test", eval_ori_test_time)
arch_info.reset_pseudo_eval_times(
"cifar10-valid", None, "x-valid", eval_x_valid_time
)
arch_info.reset_pseudo_eval_times(
"cifar10-valid", None, "ori-test", eval_ori_test_time
)
# CIFAR10
train_per_epoch_time = list(hp2info["01"].query("cifar10", 777).train_times.values())
train_per_epoch_time = list(
hp2info["01"].query("cifar10", 777).train_times.values()
)
train_per_epoch_time = sum(train_per_epoch_time) / len(train_per_epoch_time)
eval_ori_test_time = []
for key, value in hp2info["01"].query("cifar10", 777).eval_times.items():
@@ -136,10 +161,14 @@ def correct_time_related_info(hp2info: Dict[Text, ArchResults]):
eval_ori_test_time = sum(eval_ori_test_time) / len(eval_ori_test_time)
for hp, arch_info in hp2info.items():
arch_info.reset_pseudo_train_times("cifar10", None, train_per_epoch_time)
arch_info.reset_pseudo_eval_times("cifar10", None, "ori-test", eval_ori_test_time)
arch_info.reset_pseudo_eval_times(
"cifar10", None, "ori-test", eval_ori_test_time
)
# CIFAR100
train_per_epoch_time = list(hp2info["01"].query("cifar100", 777).train_times.values())
train_per_epoch_time = list(
hp2info["01"].query("cifar100", 777).train_times.values()
)
train_per_epoch_time = sum(train_per_epoch_time) / len(train_per_epoch_time)
eval_ori_test_time, eval_x_valid_time, eval_x_test_time = [], [], []
for key, value in hp2info["01"].query("cifar100", 777).eval_times.items():
@@ -156,12 +185,18 @@ def correct_time_related_info(hp2info: Dict[Text, ArchResults]):
eval_x_test_time = sum(eval_x_test_time) / len(eval_x_test_time)
for hp, arch_info in hp2info.items():
arch_info.reset_pseudo_train_times("cifar100", None, train_per_epoch_time)
arch_info.reset_pseudo_eval_times("cifar100", None, "x-valid", eval_x_valid_time)
arch_info.reset_pseudo_eval_times(
"cifar100", None, "x-valid", eval_x_valid_time
)
arch_info.reset_pseudo_eval_times("cifar100", None, "x-test", eval_x_test_time)
arch_info.reset_pseudo_eval_times("cifar100", None, "ori-test", eval_ori_test_time)
arch_info.reset_pseudo_eval_times(
"cifar100", None, "ori-test", eval_ori_test_time
)
# ImageNet16-120
train_per_epoch_time = list(hp2info["01"].query("ImageNet16-120", 777).train_times.values())
train_per_epoch_time = list(
hp2info["01"].query("ImageNet16-120", 777).train_times.values()
)
train_per_epoch_time = sum(train_per_epoch_time) / len(train_per_epoch_time)
eval_ori_test_time, eval_x_valid_time, eval_x_test_time = [], [], []
for key, value in hp2info["01"].query("ImageNet16-120", 777).eval_times.items():
@@ -178,9 +213,15 @@ def correct_time_related_info(hp2info: Dict[Text, ArchResults]):
eval_x_test_time = sum(eval_x_test_time) / len(eval_x_test_time)
for hp, arch_info in hp2info.items():
arch_info.reset_pseudo_train_times("ImageNet16-120", None, train_per_epoch_time)
arch_info.reset_pseudo_eval_times("ImageNet16-120", None, "x-valid", eval_x_valid_time)
arch_info.reset_pseudo_eval_times("ImageNet16-120", None, "x-test", eval_x_test_time)
arch_info.reset_pseudo_eval_times("ImageNet16-120", None, "ori-test", eval_ori_test_time)
arch_info.reset_pseudo_eval_times(
"ImageNet16-120", None, "x-valid", eval_x_valid_time
)
arch_info.reset_pseudo_eval_times(
"ImageNet16-120", None, "x-test", eval_x_test_time
)
arch_info.reset_pseudo_eval_times(
"ImageNet16-120", None, "ori-test", eval_ori_test_time
)
return hp2info
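
Every dataset block above applies the same calibration: average the per-epoch train times and per-split eval times measured once (hp="01", seed 777), then stamp those averages onto all hp settings via the reset_pseudo_*_times calls. The averaging itself, as a sketch (the dict shape is assumed from the queries above):

    def mean_seconds(times):
        # train_times / eval_times map epoch-style keys to seconds
        values = list(times.values())
        return sum(values) / len(values)

    train_per_epoch_time = mean_seconds({0: 12.3, 1: 12.1, 2: 12.4})  # toy values
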
@@ -200,7 +241,9 @@ def simplify(save_dir, save_name, nets, total):
seeds.add(seed)
nums.append(len(xlist))
print(" [seed={:}] there are {:} checkpoints.".format(seed, len(xlist)))
assert len(nets) == total == max(nums), "there are some missing files : {:} vs {:}".format(max(nums), total)
assert (
len(nets) == total == max(nums)
), "there are some missing files : {:} vs {:}".format(max(nums), total)
print("{:} start simplify the checkpoint.".format(time_string()))
datasets = ("cifar10-valid", "cifar10", "cifar100", "ImageNet16-120")
@@ -225,7 +268,10 @@ def simplify(save_dir, save_name, nets, total):
for hp in hps:
sub_save_dir = save_dir / "raw-data-{:}".format(hp)
ckps = [sub_save_dir / "arch-{:06d}-seed-{:}.pth".format(index, seed) for seed in seeds]
ckps = [
sub_save_dir / "arch-{:06d}-seed-{:}.pth".format(index, seed)
for seed in seeds
]
ckps = [x for x in ckps if x.exists()]
if len(ckps) == 0:
raise ValueError("Invalid data : index={:}, hp={:}".format(index, hp))
@@ -238,21 +284,31 @@ def simplify(save_dir, save_name, nets, total):
hp2info["01"].clear_params() # to save some spaces...
to_save_data = OrderedDict(
{"01": hp2info["01"].state_dict(), "12": hp2info["12"].state_dict(), "90": hp2info["90"].state_dict()}
{
"01": hp2info["01"].state_dict(),
"12": hp2info["12"].state_dict(),
"90": hp2info["90"].state_dict(),
}
)
pickle_save(to_save_data, str(full_save_path))
for hp in hps:
hp2info[hp].clear_params()
to_save_data = OrderedDict(
{"01": hp2info["01"].state_dict(), "12": hp2info["12"].state_dict(), "90": hp2info["90"].state_dict()}
{
"01": hp2info["01"].state_dict(),
"12": hp2info["12"].state_dict(),
"90": hp2info["90"].state_dict(),
}
)
pickle_save(to_save_data, str(simple_save_path))
arch2infos[index] = to_save_data
# measure elapsed time
arch_time.update(time.time() - end_time)
end_time = time.time()
need_time = "{:}".format(convert_secs2time(arch_time.avg * (total - index - 1), True))
need_time = "{:}".format(
convert_secs2time(arch_time.avg * (total - index - 1), True)
)
# print('{:} {:06d}/{:06d} : still need {:}'.format(time_string(), index, total, need_time))
print("{:} {:} done.".format(time_string(), save_name))
final_infos = {
@@ -297,7 +353,8 @@ def traverse_net(candidates: List[int], N: int):
if __name__ == "__main__":
parser = argparse.ArgumentParser(
description="NATS-Bench (size search space)", formatter_class=argparse.ArgumentDefaultsHelpFormatter
description="NATS-Bench (size search space)",
formatter_class=argparse.ArgumentDefaultsHelpFormatter,
)
parser.add_argument(
"--base_save_dir",
@@ -305,15 +362,27 @@ if __name__ == "__main__":
default="./output/NATS-Bench-size",
help="The base-name of folder to save checkpoints and log.",
)
parser.add_argument("--candidateC", type=int, nargs="+", default=[8, 16, 24, 32, 40, 48, 56, 64], help=".")
parser.add_argument("--num_layers", type=int, default=5, help="The number of layers in a network.")
parser.add_argument(
"--candidateC",
type=int,
nargs="+",
default=[8, 16, 24, 32, 40, 48, 56, 64],
help=".",
)
parser.add_argument(
"--num_layers", type=int, default=5, help="The number of layers in a network."
)
parser.add_argument("--check_N", type=int, default=32768, help="For safety.")
parser.add_argument("--save_name", type=str, default="process", help="The save directory.")
parser.add_argument(
"--save_name", type=str, default="process", help="The save directory."
)
args = parser.parse_args()
nets = traverse_net(args.candidateC, args.num_layers)
if len(nets) != args.check_N:
raise ValueError("Pre-num-check failed : {:} vs {:}".format(len(nets), args.check_N))
raise ValueError(
"Pre-num-check failed : {:} vs {:}".format(len(nets), args.check_N)
)
save_dir = Path(args.base_save_dir)
simplify(save_dir, args.save_name, nets, args.check_N)
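
Each per-architecture file written by this pipeline bundles the three hyper-parameter settings of the size space. Assuming pickle_save is a thin wrapper over pickle.dump, reading one back is just (path illustrative):

    import pickle

    with open("output/NATS-Bench-size/process-simple/000001.pickle", "rb") as f:
        data = pickle.load(f)
    # an OrderedDict keyed by "01", "12", "90", each value an ArchResults.state_dict()
    print(sorted(data.keys()))
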

View File

@@ -54,7 +54,11 @@ def copy_data(source_dir, target_dir, meta_path):
target_path = os.path.join(target_dir, file_name)
if os.path.exists(source_path):
s2t[source_path] = target_path
print("Map from {:} to {:}, find {:} missed ckps.".format(source_dir, target_dir, len(s2t)))
print(
"Map from {:} to {:}, find {:} missed ckps.".format(
source_dir, target_dir, len(s2t)
)
)
for s, t in s2t.items():
copyfile(s, t)
@@ -64,9 +68,18 @@ if __name__ == "__main__":
description="NATS-Bench (size search space) file manager.",
formatter_class=argparse.ArgumentDefaultsHelpFormatter,
)
parser.add_argument("--mode", type=str, required=True, choices=["check", "copy"], help="The script mode.")
parser.add_argument(
"--save_dir", type=str, default="output/NATS-Bench-size", help="Folder to save checkpoints and log."
"--mode",
type=str,
required=True,
choices=["check", "copy"],
help="The script mode.",
)
parser.add_argument(
"--save_dir",
type=str,
default="output/NATS-Bench-size",
help="Folder to save checkpoints and log.",
)
parser.add_argument("--check_N", type=int, default=32768, help="For safety.")
# used for training the model
@@ -76,7 +89,10 @@ if __name__ == "__main__":
for config in possible_configs:
cur_save_dir = "{:}/raw-data-{:}".format(args.save_dir, config)
seed2ckps, miss2ckps = obtain_valid_ckp(cur_save_dir, args.check_N)
torch.save(dict(seed2ckps=seed2ckps, miss2ckps=miss2ckps), "{:}/meta-{:}.pth".format(args.save_dir, config))
torch.save(
dict(seed2ckps=seed2ckps, miss2ckps=miss2ckps),
"{:}/meta-{:}.pth".format(args.save_dir, config),
)
elif args.mode == "copy":
for config in possible_configs:
cur_save_dir = "{:}/raw-data-{:}".format(args.save_dir, config)
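
The "check" branch above serializes two maps per hyper-parameter config; a sketch of consuming them, assuming "90" is one of the size-space config tags and that miss2ckps maps each seed to its missing files:

    import torch

    meta = torch.load("output/NATS-Bench-size/meta-90.pth")  # written by the "check" branch
    seed2ckps, miss2ckps = meta["seed2ckps"], meta["miss2ckps"]
    for seed, missed in miss2ckps.items():
        print("seed {:} is missing {:} checkpoints".format(seed, len(missed)))
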

View File

@@ -91,14 +91,22 @@ if __name__ == "__main__":
for fast_mode in [True, False]:
for verbose in [True, False]:
api_nats_tss = create(None, "tss", fast_mode=fast_mode, verbose=verbose)
print("{:} create with fast_mode={:} and verbose={:}".format(time_string(), fast_mode, verbose))
print(
"{:} create with fast_mode={:} and verbose={:}".format(
time_string(), fast_mode, verbose
)
)
test_api(api_nats_tss, False)
del api_nats_tss
gc.collect()
for fast_mode in [True, False]:
for verbose in [True, False]:
print("{:} create with fast_mode={:} and verbose={:}".format(time_string(), fast_mode, verbose))
print(
"{:} create with fast_mode={:} and verbose={:}".format(
time_string(), fast_mode, verbose
)
)
api_nats_sss = create(None, "size", fast_mode=fast_mode, verbose=verbose)
print("{:} --->>> {:}".format(time_string(), api_nats_sss))
test_api(api_nats_sss, True)
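
For context on what these loops exercise: create instantiates the NATS-Bench API for either search space, and the returned object answers the same get_more_info queries used elsewhere in this commit. A minimal sketch, assuming the nats_bench package and its default benchmark-file location:

    from nats_bench import create

    api = create(None, "tss", fast_mode=True, verbose=False)  # None falls back to the default location
    info = api.get_more_info(0, dataset="cifar10", hp=200, is_random=False)
    print(info["test-accuracy"])
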

View File

@@ -50,7 +50,9 @@ def simplify(save_dir, save_name, nets, total, sup_config):
seeds.add(seed)
nums.append(len(xlist))
print(" [seed={:}] there are {:} checkpoints.".format(seed, len(xlist)))
assert len(nets) == total == max(nums), "there are some missing files : {:} vs {:}".format(max(nums), total)
assert (
len(nets) == total == max(nums)
), "there are some missing files : {:} vs {:}".format(max(nums), total)
print("{:} start simplify the checkpoint.".format(time_string()))
datasets = ("cifar10-valid", "cifar10", "cifar100", "ImageNet16-120")
@@ -78,7 +80,9 @@ def simplify(save_dir, save_name, nets, total, sup_config):
# measure elapsed time
arch_time.update(time.time() - end_time)
end_time = time.time()
need_time = "{:}".format(convert_secs2time(arch_time.avg * (total - index - 1), True))
need_time = "{:}".format(
convert_secs2time(arch_time.avg * (total - index - 1), True)
)
# print('{:} {:06d}/{:06d} : still need {:}'.format(time_string(), index, total, need_time))
print("{:} {:} done.".format(time_string(), save_name))
final_infos = {
@@ -108,7 +112,11 @@ def simplify(save_dir, save_name, nets, total, sup_config):
def traverse_net(max_node):
aa_nas_bench_ss = get_search_spaces("cell", "nats-bench")
archs = CellStructure.gen_all(aa_nas_bench_ss, max_node, False)
print("There are {:} archs vs {:}.".format(len(archs), len(aa_nas_bench_ss) ** ((max_node - 1) * max_node / 2)))
print(
"There are {:} archs vs {:}.".format(
len(archs), len(aa_nas_bench_ss) ** ((max_node - 1) * max_node / 2)
)
)
random.seed(88) # please do not change this line for reproducibility
random.shuffle(archs)
@@ -117,10 +125,12 @@ def traverse_net(max_node):
== "|avg_pool_3x3~0|+|nor_conv_1x1~0|skip_connect~1|+|nor_conv_1x1~0|skip_connect~1|skip_connect~2|"
), "please check the 0-th architecture : {:}".format(archs[0])
assert (
archs[9].tostr() == "|avg_pool_3x3~0|+|none~0|none~1|+|skip_connect~0|none~1|nor_conv_3x3~2|"
archs[9].tostr()
== "|avg_pool_3x3~0|+|none~0|none~1|+|skip_connect~0|none~1|nor_conv_3x3~2|"
), "please check the 9-th architecture : {:}".format(archs[9])
assert (
archs[123].tostr() == "|avg_pool_3x3~0|+|avg_pool_3x3~0|nor_conv_1x1~1|+|none~0|avg_pool_3x3~1|nor_conv_3x3~2|"
archs[123].tostr()
== "|avg_pool_3x3~0|+|avg_pool_3x3~0|nor_conv_1x1~1|+|none~0|avg_pool_3x3~1|nor_conv_3x3~2|"
), "please check the 123-th architecture : {:}".format(archs[123])
return [x.tostr() for x in archs]
@@ -128,7 +138,8 @@ def traverse_net(max_node):
if __name__ == "__main__":
parser = argparse.ArgumentParser(
description="NATS-Bench (topology search space)", formatter_class=argparse.ArgumentDefaultsHelpFormatter
description="NATS-Bench (topology search space)",
formatter_class=argparse.ArgumentDefaultsHelpFormatter,
)
parser.add_argument(
"--base_save_dir",
@@ -136,16 +147,26 @@ if __name__ == "__main__":
default="./output/NATS-Bench-topology",
help="The base-name of folder to save checkpoints and log.",
)
parser.add_argument("--max_node", type=int, default=4, help="The maximum node in a cell.")
parser.add_argument("--channel", type=int, default=16, help="The number of channels.")
parser.add_argument("--num_cells", type=int, default=5, help="The number of cells in one stage.")
parser.add_argument(
"--max_node", type=int, default=4, help="The maximum node in a cell."
)
parser.add_argument(
"--channel", type=int, default=16, help="The number of channels."
)
parser.add_argument(
"--num_cells", type=int, default=5, help="The number of cells in one stage."
)
parser.add_argument("--check_N", type=int, default=15625, help="For safety.")
parser.add_argument("--save_name", type=str, default="process", help="The save directory.")
parser.add_argument(
"--save_name", type=str, default="process", help="The save directory."
)
args = parser.parse_args()
nets = traverse_net(args.max_node)
if len(nets) != args.check_N:
raise ValueError("Pre-num-check failed : {:} vs {:}".format(len(nets), args.check_N))
raise ValueError(
"Pre-num-check failed : {:} vs {:}".format(len(nets), args.check_N)
)
save_dir = Path(args.base_save_dir)
simplify(

View File

@@ -32,7 +32,9 @@ from utils import get_md5_file
from nas_201_api import NASBench201API
api = NASBench201API("{:}/.torch/NAS-Bench-201-v1_0-e61699.pth".format(os.environ["HOME"]))
api = NASBench201API(
"{:}/.torch/NAS-Bench-201-v1_0-e61699.pth".format(os.environ["HOME"])
)
NATS_TSS_BASE_NAME = "NATS-tss-v1_0" # 2020.08.28
@@ -68,35 +70,58 @@ def create_result_count(
)
if "train_times" in results: # new version
xresult.update_train_info(
results["train_acc1es"], results["train_acc5es"], results["train_losses"], results["train_times"]
results["train_acc1es"],
results["train_acc5es"],
results["train_losses"],
results["train_times"],
)
xresult.update_eval(
results["valid_acc1es"], results["valid_losses"], results["valid_times"]
)
xresult.update_eval(results["valid_acc1es"], results["valid_losses"], results["valid_times"])
else:
network = get_cell_based_tiny_net(net_config)
network.load_state_dict(xresult.get_net_param())
if dataset == "cifar10-valid":
xresult.update_OLD_eval("x-valid", results["valid_acc1es"], results["valid_losses"])
xresult.update_OLD_eval(
"x-valid", results["valid_acc1es"], results["valid_losses"]
)
loss, top1, top5, latencies = pure_evaluate(
dataloader_dict["{:}@{:}".format("cifar10", "test")], network.cuda()
)
xresult.update_OLD_eval("ori-test", {results["total_epoch"] - 1: top1}, {results["total_epoch"] - 1: loss})
xresult.update_OLD_eval(
"ori-test",
{results["total_epoch"] - 1: top1},
{results["total_epoch"] - 1: loss},
)
xresult.update_latency(latencies)
elif dataset == "cifar10":
xresult.update_OLD_eval("ori-test", results["valid_acc1es"], results["valid_losses"])
xresult.update_OLD_eval(
"ori-test", results["valid_acc1es"], results["valid_losses"]
)
loss, top1, top5, latencies = pure_evaluate(
dataloader_dict["{:}@{:}".format(dataset, "test")], network.cuda()
)
xresult.update_latency(latencies)
elif dataset == "cifar100" or dataset == "ImageNet16-120":
xresult.update_OLD_eval("ori-test", results["valid_acc1es"], results["valid_losses"])
xresult.update_OLD_eval(
"ori-test", results["valid_acc1es"], results["valid_losses"]
)
loss, top1, top5, latencies = pure_evaluate(
dataloader_dict["{:}@{:}".format(dataset, "valid")], network.cuda()
)
xresult.update_OLD_eval("x-valid", {results["total_epoch"] - 1: top1}, {results["total_epoch"] - 1: loss})
xresult.update_OLD_eval(
"x-valid",
{results["total_epoch"] - 1: top1},
{results["total_epoch"] - 1: loss},
)
loss, top1, top5, latencies = pure_evaluate(
dataloader_dict["{:}@{:}".format(dataset, "test")], network.cuda()
)
xresult.update_OLD_eval("x-test", {results["total_epoch"] - 1: top1}, {results["total_epoch"] - 1: loss})
xresult.update_OLD_eval(
"x-test",
{results["total_epoch"] - 1: top1},
{results["total_epoch"] - 1: loss},
)
xresult.update_latency(latencies)
else:
raise ValueError("invalid dataset name : {:}".format(dataset))
@@ -112,12 +137,18 @@ def account_one_arch(arch_index, arch_str, checkpoints, datasets, dataloader_dic
ok_dataset = 0
for dataset in datasets:
if dataset not in checkpoint:
print("Can not find {:} in arch-{:} from {:}".format(dataset, arch_index, checkpoint_path))
print(
"Can not find {:} in arch-{:} from {:}".format(
dataset, arch_index, checkpoint_path
)
)
continue
else:
ok_dataset += 1
results = checkpoint[dataset]
assert results["finish-train"], "This {:} arch seed={:} does not finish train on {:} ::: {:}".format(
assert results[
"finish-train"
], "This {:} arch seed={:} does not finish train on {:} ::: {:}".format(
arch_index, used_seed, dataset, checkpoint_path
)
arch_config = {
@@ -127,7 +158,9 @@ def account_one_arch(arch_index, arch_str, checkpoints, datasets, dataloader_dic
"class_num": results["config"]["class_num"],
}
xresult = create_result_count(used_seed, dataset, arch_config, results, dataloader_dict)
xresult = create_result_count(
used_seed, dataset, arch_config, results, dataloader_dict
)
information.update(dataset, int(used_seed), xresult)
if ok_dataset == 0:
raise ValueError("{:} does not find any data".format(checkpoint_path))
@@ -137,7 +170,8 @@ def account_one_arch(arch_index, arch_str, checkpoints, datasets, dataloader_dic
def correct_time_related_info(arch_index: int, arch_infos: Dict[Text, ArchResults]):
# calibrate the latency based on NAS-Bench-201-v1_0-e61699.pth
cifar010_latency = (
api.get_latency(arch_index, "cifar10-valid", hp="200") + api.get_latency(arch_index, "cifar10", hp="200")
api.get_latency(arch_index, "cifar10-valid", hp="200")
+ api.get_latency(arch_index, "cifar10", hp="200")
) / 2
cifar100_latency = api.get_latency(arch_index, "cifar100", hp="200")
image_latency = api.get_latency(arch_index, "ImageNet16-120", hp="200")
@@ -147,7 +181,9 @@ def correct_time_related_info(arch_index: int, arch_infos: Dict[Text, ArchResult
arch_info.reset_latency("cifar100", None, cifar100_latency)
arch_info.reset_latency("ImageNet16-120", None, image_latency)
train_per_epoch_time = list(arch_infos["12"].query("cifar10-valid", 777).train_times.values())
train_per_epoch_time = list(
arch_infos["12"].query("cifar10-valid", 777).train_times.values()
)
train_per_epoch_time = sum(train_per_epoch_time) / len(train_per_epoch_time)
eval_ori_test_time, eval_x_valid_time = [], []
for key, value in arch_infos["12"].query("cifar10-valid", 777).eval_times.items():
@@ -157,7 +193,9 @@ def correct_time_related_info(arch_index: int, arch_infos: Dict[Text, ArchResult
eval_x_valid_time.append(value)
else:
raise ValueError("-- {:} --".format(key))
eval_ori_test_time, eval_x_valid_time = float(np.mean(eval_ori_test_time)), float(np.mean(eval_x_valid_time))
eval_ori_test_time, eval_x_valid_time = float(np.mean(eval_ori_test_time)), float(
np.mean(eval_x_valid_time)
)
nums = {
"ImageNet16-120-train": 151700,
"ImageNet16-120-valid": 3000,
@@ -170,36 +208,72 @@ def correct_time_related_info(arch_index: int, arch_infos: Dict[Text, ArchResult
"cifar100-test": 10000,
"cifar100-valid": 5000,
}
eval_per_sample = (eval_ori_test_time + eval_x_valid_time) / (nums["cifar10-valid-valid"] + nums["cifar10-test"])
eval_per_sample = (eval_ori_test_time + eval_x_valid_time) / (
nums["cifar10-valid-valid"] + nums["cifar10-test"]
)
for hp, arch_info in arch_infos.items():
arch_info.reset_pseudo_train_times(
"cifar10-valid", None, train_per_epoch_time / nums["cifar10-valid-train"] * nums["cifar10-valid-train"]
"cifar10-valid",
None,
train_per_epoch_time
/ nums["cifar10-valid-train"]
* nums["cifar10-valid-train"],
)
arch_info.reset_pseudo_train_times(
"cifar10", None, train_per_epoch_time / nums["cifar10-valid-train"] * nums["cifar10-train"]
"cifar10",
None,
train_per_epoch_time / nums["cifar10-valid-train"] * nums["cifar10-train"],
)
arch_info.reset_pseudo_train_times(
"cifar100", None, train_per_epoch_time / nums["cifar10-valid-train"] * nums["cifar100-train"]
"cifar100",
None,
train_per_epoch_time / nums["cifar10-valid-train"] * nums["cifar100-train"],
)
arch_info.reset_pseudo_train_times(
"ImageNet16-120", None, train_per_epoch_time / nums["cifar10-valid-train"] * nums["ImageNet16-120-train"]
"ImageNet16-120",
None,
train_per_epoch_time
/ nums["cifar10-valid-train"]
* nums["ImageNet16-120-train"],
)
arch_info.reset_pseudo_eval_times(
"cifar10-valid", None, "x-valid", eval_per_sample * nums["cifar10-valid-valid"]
)
arch_info.reset_pseudo_eval_times("cifar10-valid", None, "ori-test", eval_per_sample * nums["cifar10-test"])
arch_info.reset_pseudo_eval_times("cifar10", None, "ori-test", eval_per_sample * nums["cifar10-test"])
arch_info.reset_pseudo_eval_times("cifar100", None, "x-valid", eval_per_sample * nums["cifar100-valid"])
arch_info.reset_pseudo_eval_times("cifar100", None, "x-test", eval_per_sample * nums["cifar100-valid"])
arch_info.reset_pseudo_eval_times("cifar100", None, "ori-test", eval_per_sample * nums["cifar100-test"])
arch_info.reset_pseudo_eval_times(
"ImageNet16-120", None, "x-valid", eval_per_sample * nums["ImageNet16-120-valid"]
"cifar10-valid",
None,
"x-valid",
eval_per_sample * nums["cifar10-valid-valid"],
)
arch_info.reset_pseudo_eval_times(
"ImageNet16-120", None, "x-test", eval_per_sample * nums["ImageNet16-120-valid"]
"cifar10-valid", None, "ori-test", eval_per_sample * nums["cifar10-test"]
)
arch_info.reset_pseudo_eval_times(
"ImageNet16-120", None, "ori-test", eval_per_sample * nums["ImageNet16-120-test"]
"cifar10", None, "ori-test", eval_per_sample * nums["cifar10-test"]
)
arch_info.reset_pseudo_eval_times(
"cifar100", None, "x-valid", eval_per_sample * nums["cifar100-valid"]
)
arch_info.reset_pseudo_eval_times(
"cifar100", None, "x-test", eval_per_sample * nums["cifar100-valid"]
)
arch_info.reset_pseudo_eval_times(
"cifar100", None, "ori-test", eval_per_sample * nums["cifar100-test"]
)
arch_info.reset_pseudo_eval_times(
"ImageNet16-120",
None,
"x-valid",
eval_per_sample * nums["ImageNet16-120-valid"],
)
arch_info.reset_pseudo_eval_times(
"ImageNet16-120",
None,
"x-test",
eval_per_sample * nums["ImageNet16-120-valid"],
)
arch_info.reset_pseudo_eval_times(
"ImageNet16-120",
None,
"ori-test",
eval_per_sample * nums["ImageNet16-120-test"],
)
return arch_infos
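
All of the reset_pseudo_eval_times calls above apply one per-sample rate to one sample count. Spelled out with the counts from the nums table (timing values here are toy numbers):

    # measured once on cifar10-valid: seconds for a full pass over each split
    eval_ori_test_time, eval_x_valid_time = 2.0, 1.0
    # cifar10-valid evaluates 5,000 validation plus 10,000 test images (the nums dict above)
    eval_per_sample = (eval_ori_test_time + eval_x_valid_time) / (5000 + 10000)
    # pseudo eval time for any other split = per-sample rate * that split's sample count,
    # e.g. cifar100's 5,000-image x-valid split:
    pseudo_cifar100_valid_time = eval_per_sample * 5000
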
@@ -220,7 +294,9 @@ def simplify(save_dir, save_name, nets, total, sup_config):
seeds.add(seed)
nums.append(len(xlist))
print(" [seed={:}] there are {:} checkpoints.".format(seed, len(xlist)))
assert len(nets) == total == max(nums), "there are some missing files : {:} vs {:}".format(max(nums), total)
assert (
len(nets) == total == max(nums)
), "there are some missing files : {:} vs {:}".format(max(nums), total)
print("{:} start simplify the checkpoint.".format(time_string()))
datasets = ("cifar10-valid", "cifar10", "cifar100", "ImageNet16-120")
@@ -236,7 +312,12 @@ def simplify(save_dir, save_name, nets, total, sup_config):
arch2infos, evaluated_indexes = dict(), set()
end_time, arch_time = time.time(), AverageMeter()
# save the meta information
temp_final_infos = {"meta_archs": nets, "total_archs": total, "arch2infos": None, "evaluated_indexes": set()}
temp_final_infos = {
"meta_archs": nets,
"total_archs": total,
"arch2infos": None,
"evaluated_indexes": set(),
}
pickle_save(temp_final_infos, str(full_save_dir / "meta.pickle"))
pickle_save(temp_final_infos, str(simple_save_dir / "meta.pickle"))
@@ -248,29 +329,40 @@ def simplify(save_dir, save_name, nets, total, sup_config):
simple_save_path = simple_save_dir / "{:06d}.pickle".format(index)
for hp in hps:
sub_save_dir = save_dir / "raw-data-{:}".format(hp)
ckps = [sub_save_dir / "arch-{:06d}-seed-{:}.pth".format(index, seed) for seed in seeds]
ckps = [
sub_save_dir / "arch-{:06d}-seed-{:}.pth".format(index, seed)
for seed in seeds
]
ckps = [x for x in ckps if x.exists()]
if len(ckps) == 0:
raise ValueError("Invalid data : index={:}, hp={:}".format(index, hp))
arch_info = account_one_arch(index, arch_str, ckps, datasets, dataloader_dict)
arch_info = account_one_arch(
index, arch_str, ckps, datasets, dataloader_dict
)
hp2info[hp] = arch_info
hp2info = correct_time_related_info(index, hp2info)
evaluated_indexes.add(index)
to_save_data = OrderedDict({"12": hp2info["12"].state_dict(), "200": hp2info["200"].state_dict()})
to_save_data = OrderedDict(
{"12": hp2info["12"].state_dict(), "200": hp2info["200"].state_dict()}
)
pickle_save(to_save_data, str(full_save_path))
for hp in hps:
hp2info[hp].clear_params()
to_save_data = OrderedDict({"12": hp2info["12"].state_dict(), "200": hp2info["200"].state_dict()})
to_save_data = OrderedDict(
{"12": hp2info["12"].state_dict(), "200": hp2info["200"].state_dict()}
)
pickle_save(to_save_data, str(simple_save_path))
arch2infos[index] = to_save_data
# measure elapsed time
arch_time.update(time.time() - end_time)
end_time = time.time()
need_time = "{:}".format(convert_secs2time(arch_time.avg * (total - index - 1), True))
need_time = "{:}".format(
convert_secs2time(arch_time.avg * (total - index - 1), True)
)
# print('{:} {:06d}/{:06d} : still need {:}'.format(time_string(), index, total, need_time))
print("{:} {:} done.".format(time_string(), save_name))
final_infos = {
@@ -303,7 +395,11 @@ def simplify(save_dir, save_name, nets, total, sup_config):
def traverse_net(max_node):
aa_nas_bench_ss = get_search_spaces("cell", "nats-bench")
archs = CellStructure.gen_all(aa_nas_bench_ss, max_node, False)
print("There are {:} archs vs {:}.".format(len(archs), len(aa_nas_bench_ss) ** ((max_node - 1) * max_node / 2)))
print(
"There are {:} archs vs {:}.".format(
len(archs), len(aa_nas_bench_ss) ** ((max_node - 1) * max_node / 2)
)
)
random.seed(88) # please do not change this line for reproducibility
random.shuffle(archs)
@@ -312,10 +408,12 @@ def traverse_net(max_node):
== "|avg_pool_3x3~0|+|nor_conv_1x1~0|skip_connect~1|+|nor_conv_1x1~0|skip_connect~1|skip_connect~2|"
), "please check the 0-th architecture : {:}".format(archs[0])
assert (
archs[9].tostr() == "|avg_pool_3x3~0|+|none~0|none~1|+|skip_connect~0|none~1|nor_conv_3x3~2|"
archs[9].tostr()
== "|avg_pool_3x3~0|+|none~0|none~1|+|skip_connect~0|none~1|nor_conv_3x3~2|"
), "please check the 9-th architecture : {:}".format(archs[9])
assert (
archs[123].tostr() == "|avg_pool_3x3~0|+|avg_pool_3x3~0|nor_conv_1x1~1|+|none~0|avg_pool_3x3~1|nor_conv_3x3~2|"
archs[123].tostr()
== "|avg_pool_3x3~0|+|avg_pool_3x3~0|nor_conv_1x1~1|+|none~0|avg_pool_3x3~1|nor_conv_3x3~2|"
), "please check the 123-th architecture : {:}".format(archs[123])
return [x.tostr() for x in archs]
@@ -323,7 +421,8 @@ def traverse_net(max_node):
if __name__ == "__main__":
parser = argparse.ArgumentParser(
description="NATS-Bench (topology search space)", formatter_class=argparse.ArgumentDefaultsHelpFormatter
description="NATS-Bench (topology search space)",
formatter_class=argparse.ArgumentDefaultsHelpFormatter,
)
parser.add_argument(
"--base_save_dir",
@@ -331,16 +430,26 @@ if __name__ == "__main__":
default="./output/NATS-Bench-topology",
help="The base-name of folder to save checkpoints and log.",
)
parser.add_argument("--max_node", type=int, default=4, help="The maximum node in a cell.")
parser.add_argument("--channel", type=int, default=16, help="The number of channels.")
parser.add_argument("--num_cells", type=int, default=5, help="The number of cells in one stage.")
parser.add_argument(
"--max_node", type=int, default=4, help="The maximum node in a cell."
)
parser.add_argument(
"--channel", type=int, default=16, help="The number of channels."
)
parser.add_argument(
"--num_cells", type=int, default=5, help="The number of cells in one stage."
)
parser.add_argument("--check_N", type=int, default=15625, help="For safety.")
parser.add_argument("--save_name", type=str, default="process", help="The save directory.")
parser.add_argument(
"--save_name", type=str, default="process", help="The save directory."
)
args = parser.parse_args()
nets = traverse_net(args.max_node)
if len(nets) != args.check_N:
raise ValueError("Pre-num-check failed : {:} vs {:}".format(len(nets), args.check_N))
raise ValueError(
"Pre-num-check failed : {:} vs {:}".format(len(nets), args.check_N)
)
save_dir = Path(args.base_save_dir)
simplify(

View File

@@ -53,7 +53,11 @@ def copy_data(source_dir, target_dir, meta_path):
target_path = os.path.join(target_dir, file_name)
if os.path.exists(source_path):
s2t[source_path] = target_path
print("Map from {:} to {:}, find {:} missed ckps.".format(source_dir, target_dir, len(s2t)))
print(
"Map from {:} to {:}, find {:} missed ckps.".format(
source_dir, target_dir, len(s2t)
)
)
for s, t in s2t.items():
copyfile(s, t)
@@ -63,9 +67,18 @@ if __name__ == "__main__":
description="NATS-Bench (topology search space) file manager.",
formatter_class=argparse.ArgumentDefaultsHelpFormatter,
)
parser.add_argument("--mode", type=str, required=True, choices=["check", "copy"], help="The script mode.")
parser.add_argument(
"--save_dir", type=str, default="output/NATS-Bench-topology", help="Folder to save checkpoints and log."
"--mode",
type=str,
required=True,
choices=["check", "copy"],
help="The script mode.",
)
parser.add_argument(
"--save_dir",
type=str,
default="output/NATS-Bench-topology",
help="Folder to save checkpoints and log.",
)
parser.add_argument("--check_N", type=int, default=15625, help="For safety.")
# used for training the model
@@ -75,8 +88,13 @@ if __name__ == "__main__":
if args.mode == "check":
for config, possible_seeds in zip(possible_configs, possible_seedss):
cur_save_dir = "{:}/raw-data-{:}".format(args.save_dir, config)
seed2ckps, miss2ckps = obtain_valid_ckp(cur_save_dir, args.check_N, possible_seeds)
torch.save(dict(seed2ckps=seed2ckps, miss2ckps=miss2ckps), "{:}/meta-{:}.pth".format(args.save_dir, config))
seed2ckps, miss2ckps = obtain_valid_ckp(
cur_save_dir, args.check_N, possible_seeds
)
torch.save(
dict(seed2ckps=seed2ckps, miss2ckps=miss2ckps),
"{:}/meta-{:}.pth".format(args.save_dir, config),
)
elif args.mode == "copy":
for config in possible_configs:
cur_save_dir = "{:}/raw-data-{:}".format(args.save_dir, config)
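
obtain_valid_ckp itself is outside this diff; given the arch-{index:06d}-seed-{seed}.pth naming used throughout the commit, a hypothetical sketch of what the "check" mode needs from it:

    import os
    from collections import defaultdict

    def obtain_valid_ckp_sketch(save_dir, check_N, possible_seeds):
        # hypothetical re-implementation for illustration only
        seed2ckps, miss2ckps = defaultdict(list), defaultdict(list)
        for seed in possible_seeds:
            for index in range(check_N):
                name = "arch-{:06d}-seed-{:04d}.pth".format(index, seed)  # seed padding assumed
                if os.path.exists(os.path.join(save_dir, name)):
                    seed2ckps[seed].append(name)
                else:
                    miss2ckps[seed].append(name)
        return seed2ckps, miss2ckps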