add autodl

Author: mhz
Date: 2024-08-25 18:02:31 +02:00
parent 192f286cfb
commit a0a25f291c
431 changed files with 50646 additions and 8 deletions

exps/NATS-Bench/Analyze-time.py

@@ -0,0 +1,53 @@
##############################################################################
# NATS-Bench: Benchmarking NAS Algorithms for Architecture Topology and Size #
##############################################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2020.07 #
##############################################################################
# python ./exps/NATS-Bench/Analyze-time.py #
##############################################################################
import os, sys, time, tqdm, argparse
from pathlib import Path
from xautodl.config_utils import dict2config, load_config
from xautodl.datasets import get_datasets
from nats_bench import create
def show_time(api, epoch=12):
print("Show the time for {:} with {:}-epoch-training".format(api, epoch))
all_cifar10_time, all_cifar100_time, all_imagenet_time = 0, 0, 0
for index in tqdm.tqdm(range(len(api))):
info = api.get_more_info(index, "ImageNet16-120", hp=epoch)
imagenet_time = info["train-all-time"]
info = api.get_more_info(index, "cifar10-valid", hp=epoch)
cifar10_time = info["train-all-time"]
info = api.get_more_info(index, "cifar100", hp=epoch)
cifar100_time = info["train-all-time"]
# accumulate the time
all_cifar10_time += cifar10_time
all_cifar100_time += cifar100_time
all_imagenet_time += imagenet_time
print(
"The total training time for CIFAR-10 (held-out train set) is {:} seconds".format(
all_cifar10_time
)
)
print(
"The total training time for CIFAR-100 (held-out train set) is {:} seconds, {:.2f} times longer than that on CIFAR-10".format(
all_cifar100_time, all_cifar100_time / all_cifar10_time
)
)
print(
"The total training time for ImageNet-16-120 (held-out train set) is {:} seconds, {:.2f} times longer than that on CIFAR-10".format(
all_imagenet_time, all_imagenet_time / all_cifar10_time
)
)
if __name__ == "__main__":
api_nats_tss = create(None, "tss", fast_mode=True, verbose=False)
show_time(api_nats_tss, 12)
api_nats_sss = create(None, "sss", fast_mode=True, verbose=False)
show_time(api_nats_sss, 12)
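    # Illustrative extra query, not part of the original script: a minimal
    # sketch of the raw record behind the numbers printed above. It assumes the
    # NATS-Bench benchmark files are already downloaded so create(None, ...)
    # can locate them.
    info = api_nats_tss.get_more_info(0, "cifar10-valid", hp=12)
    print(
        "arch-0 on cifar10-valid: train-time = {:} s, valid-accuracy = {:}".format(
            info["train-all-time"], info.get("valid-accuracy")
        )
    )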

exps/NATS-Bench/draw-correlations.py

@@ -0,0 +1,123 @@
###############################################################
# NATS-Bench (arxiv.org/pdf/2009.00437.pdf), IEEE TPAMI 2021 #
###############################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2020.06 #
###############################################################
# Usage: python exps/NATS-Bench/draw-correlations.py #
###############################################################
import os, gc, sys, time, scipy, torch, argparse
import numpy as np
from typing import List, Text, Dict, Any
from shutil import copyfile
from collections import defaultdict, OrderedDict
from copy import deepcopy
from pathlib import Path
import matplotlib
import seaborn as sns
matplotlib.use("agg")
import matplotlib.pyplot as plt
import matplotlib.ticker as ticker
from xautodl.config_utils import dict2config, load_config
from xautodl.log_utils import time_string
from nats_bench import create
def get_valid_test_acc(api, arch, dataset):
is_size_space = api.search_space_name == "size"
if dataset == "cifar10":
xinfo = api.get_more_info(
arch, dataset=dataset, hp=90 if is_size_space else 200, is_random=False
)
test_acc = xinfo["test-accuracy"]
xinfo = api.get_more_info(
arch,
dataset="cifar10-valid",
hp=90 if is_size_space else 200,
is_random=False,
)
valid_acc = xinfo["valid-accuracy"]
else:
xinfo = api.get_more_info(
arch, dataset=dataset, hp=90 if is_size_space else 200, is_random=False
)
valid_acc = xinfo["valid-accuracy"]
test_acc = xinfo["test-accuracy"]
return (
valid_acc,
test_acc,
"validation = {:.2f}, test = {:.2f}\n".format(valid_acc, test_acc),
)
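# Note: in NATS-Bench the full training schedule is 200 epochs for the topology
# search space (tss) and 90 epochs for the size search space (sss), which is why
# `hp` switches on `is_size_space` above; for cifar10, the held-out validation
# accuracy lives under the separate "cifar10-valid" entry.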
def compute_kendalltau(vectori, vectorj):
    # Kendall's tau rank-correlation coefficient between two score vectors.
    coef, p = scipy.stats.kendalltau(vectori, vectorj)
    return coef
def compute_spearmanr(vectori, vectorj):
coef, p = scipy.stats.spearmanr(vectori, vectorj)
return coef
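# Quick sanity check (illustrative values only): both coefficients are +1.0 for
# perfectly concordant rankings and -1.0 for fully reversed ones, e.g.
#   compute_kendalltau([1, 2, 3], [10, 20, 30]) -> 1.0
#   compute_spearmanr([1, 2, 3], [30, 20, 10]) -> -1.0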
if __name__ == "__main__":
parser = argparse.ArgumentParser(
description="NATS-Bench: Benchmarking NAS Algorithms for Architecture Topology and Size",
formatter_class=argparse.ArgumentDefaultsHelpFormatter,
)
parser.add_argument(
"--save_dir",
type=str,
default="output/vis-nas-bench/nas-algos",
help="Folder to save checkpoints and log.",
)
parser.add_argument(
"--search_space",
type=str,
choices=["tss", "sss"],
help="Choose the search space.",
)
args = parser.parse_args()
save_dir = Path(args.save_dir)
    api = create(None, args.search_space, fast_mode=True, verbose=False)
indexes = list(range(1, 10000, 300))
scores_1 = []
scores_2 = []
for index in indexes:
valid_acc, test_acc, _ = get_valid_test_acc(api, index, "cifar10")
scores_1.append(valid_acc)
scores_2.append(test_acc)
correlation = compute_kendalltau(scores_1, scores_2)
print(
"The kendall tau correlation of {:} samples : {:}".format(
len(indexes), correlation
)
)
correlation = compute_spearmanr(scores_1, scores_2)
print(
"The spearmanr correlation of {:} samples : {:}".format(
len(indexes), correlation
)
)
dpi, width, height = 250, 1000, 1000
figsize = width / float(dpi), height / float(dpi)
LabelSize, LegendFontsize = 14, 14
fig, ax = plt.subplots(1, 1, figsize=figsize)
ax.scatter(scores_1, scores_2, marker="^", s=0.5, c="tab:green", alpha=0.8)
save_path = "/Users/xuanyidong/Desktop/test-temp-rank.png"
fig.savefig(save_path, dpi=dpi, bbox_inches="tight", format="png")
plt.close("all")

exps/NATS-Bench/draw-fig2_5.py

@@ -0,0 +1,651 @@
###############################################################
# NATS-Bench (arxiv.org/pdf/2009.00437.pdf), IEEE TPAMI 2021 #
# The code to draw Figure 2 / 3 / 4 / 5 in our paper. #
###############################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2020.06 #
###############################################################
# Usage: python exps/NATS-Bench/draw-fig2_5.py #
###############################################################
import os, sys, time, torch, argparse
import scipy
import numpy as np
from typing import List, Text, Dict, Any
from shutil import copyfile
from collections import defaultdict
from copy import deepcopy
from pathlib import Path
import matplotlib
import seaborn as sns
matplotlib.use("agg")
import matplotlib.pyplot as plt
import matplotlib.ticker as ticker
from xautodl.config_utils import dict2config, load_config
from xautodl.log_utils import time_string
from xautodl.models import get_cell_based_tiny_net
from nats_bench import create
def visualize_relative_info(api, vis_save_dir, indicator):
vis_save_dir = vis_save_dir.resolve()
# print ('{:} start to visualize {:} information'.format(time_string(), api))
vis_save_dir.mkdir(parents=True, exist_ok=True)
cifar010_cache_path = vis_save_dir / "{:}-cache-{:}-info.pth".format(
"cifar10", indicator
)
cifar100_cache_path = vis_save_dir / "{:}-cache-{:}-info.pth".format(
"cifar100", indicator
)
imagenet_cache_path = vis_save_dir / "{:}-cache-{:}-info.pth".format(
"ImageNet16-120", indicator
)
cifar010_info = torch.load(cifar010_cache_path)
cifar100_info = torch.load(cifar100_cache_path)
imagenet_info = torch.load(imagenet_cache_path)
indexes = list(range(len(cifar010_info["params"])))
print("{:} start to visualize relative ranking".format(time_string()))
cifar010_ord_indexes = sorted(indexes, key=lambda i: cifar010_info["test_accs"][i])
cifar100_ord_indexes = sorted(indexes, key=lambda i: cifar100_info["test_accs"][i])
imagenet_ord_indexes = sorted(indexes, key=lambda i: imagenet_info["test_accs"][i])
cifar100_labels, imagenet_labels = [], []
for idx in cifar010_ord_indexes:
cifar100_labels.append(cifar100_ord_indexes.index(idx))
imagenet_labels.append(imagenet_ord_indexes.index(idx))
print("{:} prepare data done.".format(time_string()))
dpi, width, height = 200, 1400, 800
figsize = width / float(dpi), height / float(dpi)
LabelSize, LegendFontsize = 18, 12
resnet_scale, resnet_alpha = 120, 0.5
fig = plt.figure(figsize=figsize)
ax = fig.add_subplot(111)
plt.xlim(min(indexes), max(indexes))
plt.ylim(min(indexes), max(indexes))
# plt.ylabel('y').set_rotation(30)
plt.yticks(
np.arange(min(indexes), max(indexes), max(indexes) // 3),
fontsize=LegendFontsize,
rotation="vertical",
)
plt.xticks(
np.arange(min(indexes), max(indexes), max(indexes) // 5),
fontsize=LegendFontsize,
)
ax.scatter(indexes, cifar100_labels, marker="^", s=0.5, c="tab:green", alpha=0.8)
ax.scatter(indexes, imagenet_labels, marker="*", s=0.5, c="tab:red", alpha=0.8)
ax.scatter(indexes, indexes, marker="o", s=0.5, c="tab:blue", alpha=0.8)
ax.scatter([-1], [-1], marker="o", s=100, c="tab:blue", label="CIFAR-10")
ax.scatter([-1], [-1], marker="^", s=100, c="tab:green", label="CIFAR-100")
ax.scatter([-1], [-1], marker="*", s=100, c="tab:red", label="ImageNet-16-120")
plt.grid(zorder=0)
ax.set_axisbelow(True)
plt.legend(loc=0, fontsize=LegendFontsize)
ax.set_xlabel("architecture ranking in CIFAR-10", fontsize=LabelSize)
ax.set_ylabel("architecture ranking", fontsize=LabelSize)
save_path = (vis_save_dir / "{:}-relative-rank.pdf".format(indicator)).resolve()
fig.savefig(save_path, dpi=dpi, bbox_inches="tight", format="pdf")
save_path = (vis_save_dir / "{:}-relative-rank.png".format(indicator)).resolve()
fig.savefig(save_path, dpi=dpi, bbox_inches="tight", format="png")
print("{:} save into {:}".format(time_string(), save_path))
def visualize_sss_info(api, dataset, vis_save_dir):
vis_save_dir = vis_save_dir.resolve()
print("{:} start to visualize {:} information".format(time_string(), dataset))
vis_save_dir.mkdir(parents=True, exist_ok=True)
cache_file_path = vis_save_dir / "{:}-cache-sss-info.pth".format(dataset)
if not cache_file_path.exists():
print("Do not find cache file : {:}".format(cache_file_path))
params, flops, train_accs, valid_accs, test_accs = [], [], [], [], []
for index in range(len(api)):
cost_info = api.get_cost_info(index, dataset, hp="90")
params.append(cost_info["params"])
flops.append(cost_info["flops"])
# accuracy
info = api.get_more_info(index, dataset, hp="90", is_random=False)
train_accs.append(info["train-accuracy"])
test_accs.append(info["test-accuracy"])
if dataset == "cifar10":
info = api.get_more_info(
index, "cifar10-valid", hp="90", is_random=False
)
valid_accs.append(info["valid-accuracy"])
else:
valid_accs.append(info["valid-accuracy"])
info = {
"params": params,
"flops": flops,
"train_accs": train_accs,
"valid_accs": valid_accs,
"test_accs": test_accs,
}
torch.save(info, cache_file_path)
else:
print("Find cache file : {:}".format(cache_file_path))
info = torch.load(cache_file_path)
params, flops, train_accs, valid_accs, test_accs = (
info["params"],
info["flops"],
info["train_accs"],
info["valid_accs"],
info["test_accs"],
)
print("{:} collect data done.".format(time_string()))
# pyramid = ['8:16:32:48:64', '8:8:16:32:48', '8:8:16:16:32', '8:8:16:16:48', '8:8:16:16:64', '16:16:32:32:64', '32:32:64:64:64']
pyramid = ["8:16:24:32:40", "8:16:32:48:64", "32:40:48:56:64"]
pyramid_indexes = [api.query_index_by_arch(x) for x in pyramid]
largest_indexes = [api.query_index_by_arch("64:64:64:64:64")]
indexes = list(range(len(params)))
dpi, width, height = 250, 8500, 1300
figsize = width / float(dpi), height / float(dpi)
LabelSize, LegendFontsize = 24, 24
# resnet_scale, resnet_alpha = 120, 0.5
xscale, xalpha = 120, 0.8
fig, axs = plt.subplots(1, 4, figsize=figsize)
# ax1, ax2, ax3, ax4, ax5 = axs
for ax in axs:
for tick in ax.xaxis.get_major_ticks():
tick.label.set_fontsize(LabelSize)
ax.yaxis.set_major_formatter(ticker.FormatStrFormatter("%.0f"))
for tick in ax.yaxis.get_major_ticks():
tick.label.set_fontsize(LabelSize)
ax1, ax2, ax3, ax4 = axs
ax1.scatter(params, train_accs, marker="o", s=0.5, c="tab:blue")
ax1.scatter(
[params[x] for x in pyramid_indexes],
[train_accs[x] for x in pyramid_indexes],
marker="*",
s=xscale,
c="tab:orange",
label="Pyramid Structure",
alpha=xalpha,
)
ax1.scatter(
[params[x] for x in largest_indexes],
[train_accs[x] for x in largest_indexes],
marker="x",
s=xscale,
c="tab:green",
label="Largest Candidate",
alpha=xalpha,
)
ax1.set_xlabel("#parameters (MB)", fontsize=LabelSize)
ax1.set_ylabel("train accuracy (%)", fontsize=LabelSize)
ax1.legend(loc=4, fontsize=LegendFontsize)
ax2.scatter(flops, train_accs, marker="o", s=0.5, c="tab:blue")
ax2.scatter(
[flops[x] for x in pyramid_indexes],
[train_accs[x] for x in pyramid_indexes],
marker="*",
s=xscale,
c="tab:orange",
label="Pyramid Structure",
alpha=xalpha,
)
ax2.scatter(
[flops[x] for x in largest_indexes],
[train_accs[x] for x in largest_indexes],
marker="x",
s=xscale,
c="tab:green",
label="Largest Candidate",
alpha=xalpha,
)
ax2.set_xlabel("#FLOPs (M)", fontsize=LabelSize)
# ax2.set_ylabel('train accuracy (%)', fontsize=LabelSize)
ax2.legend(loc=4, fontsize=LegendFontsize)
ax3.scatter(params, test_accs, marker="o", s=0.5, c="tab:blue")
ax3.scatter(
[params[x] for x in pyramid_indexes],
[test_accs[x] for x in pyramid_indexes],
marker="*",
s=xscale,
c="tab:orange",
label="Pyramid Structure",
alpha=xalpha,
)
ax3.scatter(
[params[x] for x in largest_indexes],
[test_accs[x] for x in largest_indexes],
marker="x",
s=xscale,
c="tab:green",
label="Largest Candidate",
alpha=xalpha,
)
ax3.set_xlabel("#parameters (MB)", fontsize=LabelSize)
ax3.set_ylabel("test accuracy (%)", fontsize=LabelSize)
ax3.legend(loc=4, fontsize=LegendFontsize)
ax4.scatter(flops, test_accs, marker="o", s=0.5, c="tab:blue")
ax4.scatter(
[flops[x] for x in pyramid_indexes],
[test_accs[x] for x in pyramid_indexes],
marker="*",
s=xscale,
c="tab:orange",
label="Pyramid Structure",
alpha=xalpha,
)
ax4.scatter(
[flops[x] for x in largest_indexes],
[test_accs[x] for x in largest_indexes],
marker="x",
s=xscale,
c="tab:green",
label="Largest Candidate",
alpha=xalpha,
)
ax4.set_xlabel("#FLOPs (M)", fontsize=LabelSize)
# ax4.set_ylabel('test accuracy (%)', fontsize=LabelSize)
ax4.legend(loc=4, fontsize=LegendFontsize)
save_path = vis_save_dir / "sss-{:}.png".format(dataset.lower())
fig.savefig(save_path, dpi=dpi, bbox_inches="tight", format="png")
print("{:} save into {:}".format(time_string(), save_path))
plt.close("all")
def visualize_tss_info(api, dataset, vis_save_dir):
vis_save_dir = vis_save_dir.resolve()
print("{:} start to visualize {:} information".format(time_string(), dataset))
vis_save_dir.mkdir(parents=True, exist_ok=True)
cache_file_path = vis_save_dir / "{:}-cache-tss-info.pth".format(dataset)
if not cache_file_path.exists():
print("Do not find cache file : {:}".format(cache_file_path))
params, flops, train_accs, valid_accs, test_accs = [], [], [], [], []
for index in range(len(api)):
cost_info = api.get_cost_info(index, dataset, hp="12")
params.append(cost_info["params"])
flops.append(cost_info["flops"])
# accuracy
info = api.get_more_info(index, dataset, hp="200", is_random=False)
train_accs.append(info["train-accuracy"])
test_accs.append(info["test-accuracy"])
if dataset == "cifar10":
info = api.get_more_info(
index, "cifar10-valid", hp="200", is_random=False
)
valid_accs.append(info["valid-accuracy"])
else:
valid_accs.append(info["valid-accuracy"])
print("")
info = {
"params": params,
"flops": flops,
"train_accs": train_accs,
"valid_accs": valid_accs,
"test_accs": test_accs,
}
torch.save(info, cache_file_path)
else:
print("Find cache file : {:}".format(cache_file_path))
info = torch.load(cache_file_path)
params, flops, train_accs, valid_accs, test_accs = (
info["params"],
info["flops"],
info["train_accs"],
info["valid_accs"],
info["test_accs"],
)
print("{:} collect data done.".format(time_string()))
resnet = [
"|nor_conv_3x3~0|+|none~0|nor_conv_3x3~1|+|skip_connect~0|none~1|skip_connect~2|"
]
resnet_indexes = [api.query_index_by_arch(x) for x in resnet]
largest_indexes = [
api.query_index_by_arch(
"|nor_conv_3x3~0|+|nor_conv_3x3~0|nor_conv_3x3~1|+|nor_conv_3x3~0|nor_conv_3x3~1|nor_conv_3x3~2|"
)
]
indexes = list(range(len(params)))
dpi, width, height = 250, 8500, 1300
figsize = width / float(dpi), height / float(dpi)
LabelSize, LegendFontsize = 24, 24
# resnet_scale, resnet_alpha = 120, 0.5
xscale, xalpha = 120, 0.8
fig, axs = plt.subplots(1, 4, figsize=figsize)
for ax in axs:
for tick in ax.xaxis.get_major_ticks():
tick.label.set_fontsize(LabelSize)
ax.yaxis.set_major_formatter(ticker.FormatStrFormatter("%.0f"))
for tick in ax.yaxis.get_major_ticks():
tick.label.set_fontsize(LabelSize)
ax1, ax2, ax3, ax4 = axs
ax1.scatter(params, train_accs, marker="o", s=0.5, c="tab:blue")
ax1.scatter(
[params[x] for x in resnet_indexes],
[train_accs[x] for x in resnet_indexes],
marker="*",
s=xscale,
c="tab:orange",
label="ResNet",
alpha=xalpha,
)
ax1.scatter(
[params[x] for x in largest_indexes],
[train_accs[x] for x in largest_indexes],
marker="x",
s=xscale,
c="tab:green",
label="Largest Candidate",
alpha=xalpha,
)
ax1.set_xlabel("#parameters (MB)", fontsize=LabelSize)
ax1.set_ylabel("train accuracy (%)", fontsize=LabelSize)
ax1.legend(loc=4, fontsize=LegendFontsize)
ax2.scatter(flops, train_accs, marker="o", s=0.5, c="tab:blue")
ax2.scatter(
[flops[x] for x in resnet_indexes],
[train_accs[x] for x in resnet_indexes],
marker="*",
s=xscale,
c="tab:orange",
label="ResNet",
alpha=xalpha,
)
ax2.scatter(
[flops[x] for x in largest_indexes],
[train_accs[x] for x in largest_indexes],
marker="x",
s=xscale,
c="tab:green",
label="Largest Candidate",
alpha=xalpha,
)
ax2.set_xlabel("#FLOPs (M)", fontsize=LabelSize)
# ax2.set_ylabel('train accuracy (%)', fontsize=LabelSize)
ax2.legend(loc=4, fontsize=LegendFontsize)
ax3.scatter(params, test_accs, marker="o", s=0.5, c="tab:blue")
ax3.scatter(
[params[x] for x in resnet_indexes],
[test_accs[x] for x in resnet_indexes],
marker="*",
s=xscale,
c="tab:orange",
label="ResNet",
alpha=xalpha,
)
ax3.scatter(
[params[x] for x in largest_indexes],
[test_accs[x] for x in largest_indexes],
marker="x",
s=xscale,
c="tab:green",
label="Largest Candidate",
alpha=xalpha,
)
ax3.set_xlabel("#parameters (MB)", fontsize=LabelSize)
ax3.set_ylabel("test accuracy (%)", fontsize=LabelSize)
ax3.legend(loc=4, fontsize=LegendFontsize)
ax4.scatter(flops, test_accs, marker="o", s=0.5, c="tab:blue")
ax4.scatter(
[flops[x] for x in resnet_indexes],
[test_accs[x] for x in resnet_indexes],
marker="*",
s=xscale,
c="tab:orange",
label="ResNet",
alpha=xalpha,
)
ax4.scatter(
[flops[x] for x in largest_indexes],
[test_accs[x] for x in largest_indexes],
marker="x",
s=xscale,
c="tab:green",
label="Largest Candidate",
alpha=xalpha,
)
ax4.set_xlabel("#FLOPs (M)", fontsize=LabelSize)
# ax4.set_ylabel('test accuracy (%)', fontsize=LabelSize)
ax4.legend(loc=4, fontsize=LegendFontsize)
save_path = vis_save_dir / "tss-{:}.png".format(dataset.lower())
fig.savefig(save_path, dpi=dpi, bbox_inches="tight", format="png")
print("{:} save into {:}".format(time_string(), save_path))
plt.close("all")
def visualize_rank_info(api, vis_save_dir, indicator):
vis_save_dir = vis_save_dir.resolve()
# print ('{:} start to visualize {:} information'.format(time_string(), api))
vis_save_dir.mkdir(parents=True, exist_ok=True)
cifar010_cache_path = vis_save_dir / "{:}-cache-{:}-info.pth".format(
"cifar10", indicator
)
cifar100_cache_path = vis_save_dir / "{:}-cache-{:}-info.pth".format(
"cifar100", indicator
)
imagenet_cache_path = vis_save_dir / "{:}-cache-{:}-info.pth".format(
"ImageNet16-120", indicator
)
cifar010_info = torch.load(cifar010_cache_path)
cifar100_info = torch.load(cifar100_cache_path)
imagenet_info = torch.load(imagenet_cache_path)
indexes = list(range(len(cifar010_info["params"])))
print("{:} start to visualize relative ranking".format(time_string()))
dpi, width, height = 250, 3800, 1200
figsize = width / float(dpi), height / float(dpi)
LabelSize, LegendFontsize = 14, 14
fig, axs = plt.subplots(1, 3, figsize=figsize)
ax1, ax2, ax3 = axs
def get_labels(info):
ord_test_indexes = sorted(indexes, key=lambda i: info["test_accs"][i])
ord_valid_indexes = sorted(indexes, key=lambda i: info["valid_accs"][i])
labels = []
for idx in ord_test_indexes:
labels.append(ord_valid_indexes.index(idx))
return labels
def plot_ax(labels, ax, name):
for tick in ax.xaxis.get_major_ticks():
tick.label.set_fontsize(LabelSize)
for tick in ax.yaxis.get_major_ticks():
tick.label.set_fontsize(LabelSize)
tick.label.set_rotation(90)
ax.set_xlim(min(indexes), max(indexes))
ax.set_ylim(min(indexes), max(indexes))
ax.yaxis.set_ticks(np.arange(min(indexes), max(indexes), max(indexes) // 3))
ax.xaxis.set_ticks(np.arange(min(indexes), max(indexes), max(indexes) // 5))
ax.scatter(indexes, labels, marker="^", s=0.5, c="tab:green", alpha=0.8)
ax.scatter(indexes, indexes, marker="o", s=0.5, c="tab:blue", alpha=0.8)
ax.scatter(
[-1], [-1], marker="^", s=100, c="tab:green", label="{:} test".format(name)
)
ax.scatter(
[-1],
[-1],
marker="o",
s=100,
c="tab:blue",
label="{:} validation".format(name),
)
ax.legend(loc=4, fontsize=LegendFontsize)
ax.set_xlabel("ranking on the {:} validation".format(name), fontsize=LabelSize)
ax.set_ylabel("architecture ranking", fontsize=LabelSize)
labels = get_labels(cifar010_info)
plot_ax(labels, ax1, "CIFAR-10")
labels = get_labels(cifar100_info)
plot_ax(labels, ax2, "CIFAR-100")
labels = get_labels(imagenet_info)
plot_ax(labels, ax3, "ImageNet-16-120")
save_path = (
vis_save_dir / "{:}-same-relative-rank.pdf".format(indicator)
).resolve()
fig.savefig(save_path, dpi=dpi, bbox_inches="tight", format="pdf")
save_path = (
vis_save_dir / "{:}-same-relative-rank.png".format(indicator)
).resolve()
fig.savefig(save_path, dpi=dpi, bbox_inches="tight", format="png")
print("{:} save into {:}".format(time_string(), save_path))
plt.close("all")
def compute_kendalltau(vectori, vectorj):
    # Kendall's tau between the two ranking vectors.
    return scipy.stats.kendalltau(vectori, vectorj).correlation
def calculate_correlation(*vectors):
matrix = []
for i, vectori in enumerate(vectors):
x = []
for j, vectorj in enumerate(vectors):
# x.append(np.corrcoef(vectori, vectorj)[0,1])
x.append(compute_kendalltau(vectori, vectorj))
matrix.append(x)
return np.array(matrix)
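# Illustrative behaviour (hypothetical inputs): the returned matrix holds the
# pairwise Kendall's tau between all input vectors, e.g.
#   calculate_correlation([1, 2, 3], [3, 2, 1])
#   -> [[ 1.0, -1.0],
#       [-1.0,  1.0]]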
def visualize_all_rank_info(api, vis_save_dir, indicator):
vis_save_dir = vis_save_dir.resolve()
# print ('{:} start to visualize {:} information'.format(time_string(), api))
vis_save_dir.mkdir(parents=True, exist_ok=True)
cifar010_cache_path = vis_save_dir / "{:}-cache-{:}-info.pth".format(
"cifar10", indicator
)
cifar100_cache_path = vis_save_dir / "{:}-cache-{:}-info.pth".format(
"cifar100", indicator
)
imagenet_cache_path = vis_save_dir / "{:}-cache-{:}-info.pth".format(
"ImageNet16-120", indicator
)
cifar010_info = torch.load(cifar010_cache_path)
cifar100_info = torch.load(cifar100_cache_path)
imagenet_info = torch.load(imagenet_cache_path)
indexes = list(range(len(cifar010_info["params"])))
print("{:} start to visualize relative ranking".format(time_string()))
dpi, width, height = 250, 3200, 1400
figsize = width / float(dpi), height / float(dpi)
LabelSize, LegendFontsize = 14, 14
fig, axs = plt.subplots(1, 2, figsize=figsize)
ax1, ax2 = axs
sns_size, xformat = 15, ".2f"
CoRelMatrix = calculate_correlation(
cifar010_info["valid_accs"],
cifar010_info["test_accs"],
cifar100_info["valid_accs"],
cifar100_info["test_accs"],
imagenet_info["valid_accs"],
imagenet_info["test_accs"],
)
sns.heatmap(
CoRelMatrix,
annot=True,
annot_kws={"size": sns_size},
fmt=xformat,
linewidths=0.5,
ax=ax1,
xticklabels=["C10-V", "C10-T", "C100-V", "C100-T", "I120-V", "I120-T"],
yticklabels=["C10-V", "C10-T", "C100-V", "C100-T", "I120-V", "I120-T"],
)
selected_indexes, acc_bar = [], 92
for i, acc in enumerate(cifar010_info["test_accs"]):
if acc > acc_bar:
selected_indexes.append(i)
cifar010_valid_accs = np.array(cifar010_info["valid_accs"])[selected_indexes]
cifar010_test_accs = np.array(cifar010_info["test_accs"])[selected_indexes]
cifar100_valid_accs = np.array(cifar100_info["valid_accs"])[selected_indexes]
cifar100_test_accs = np.array(cifar100_info["test_accs"])[selected_indexes]
imagenet_valid_accs = np.array(imagenet_info["valid_accs"])[selected_indexes]
imagenet_test_accs = np.array(imagenet_info["test_accs"])[selected_indexes]
CoRelMatrix = calculate_correlation(
cifar010_valid_accs,
cifar010_test_accs,
cifar100_valid_accs,
cifar100_test_accs,
imagenet_valid_accs,
imagenet_test_accs,
)
sns.heatmap(
CoRelMatrix,
annot=True,
annot_kws={"size": sns_size},
fmt=xformat,
linewidths=0.5,
ax=ax2,
xticklabels=["C10-V", "C10-T", "C100-V", "C100-T", "I120-V", "I120-T"],
yticklabels=["C10-V", "C10-T", "C100-V", "C100-T", "I120-V", "I120-T"],
)
ax1.set_title("Correlation coefficient over ALL candidates")
ax2.set_title(
"Correlation coefficient over candidates with accuracy > {:}%".format(acc_bar)
)
save_path = (vis_save_dir / "{:}-all-relative-rank.png".format(indicator)).resolve()
fig.savefig(save_path, dpi=dpi, bbox_inches="tight", format="png")
print("{:} save into {:}".format(time_string(), save_path))
plt.close("all")
if __name__ == "__main__":
parser = argparse.ArgumentParser(
description="NATS-Bench", formatter_class=argparse.ArgumentDefaultsHelpFormatter
)
parser.add_argument(
"--save_dir",
type=str,
default="output/vis-nas-bench",
help="Folder to save checkpoints and log.",
)
# use for train the model
args = parser.parse_args()
to_save_dir = Path(args.save_dir)
datasets = ["cifar10", "cifar100", "ImageNet16-120"]
# Figure 3 (a-c)
api_tss = create(None, "tss", verbose=True)
for xdata in datasets:
visualize_tss_info(api_tss, xdata, to_save_dir)
# Figure 3 (d-f)
api_sss = create(None, "size", verbose=True)
for xdata in datasets:
visualize_sss_info(api_sss, xdata, to_save_dir)
# Figure 2
visualize_relative_info(None, to_save_dir, "tss")
visualize_relative_info(None, to_save_dir, "sss")
# Figure 4
visualize_rank_info(None, to_save_dir, "tss")
visualize_rank_info(None, to_save_dir, "sss")
# Figure 5
visualize_all_rank_info(None, to_save_dir, "tss")
visualize_all_rank_info(None, to_save_dir, "sss")

exps/NATS-Bench/draw-fig6.py

@@ -0,0 +1,225 @@
###############################################################
# NATS-Bench (arxiv.org/pdf/2009.00437.pdf), IEEE TPAMI 2021 #
# The code to draw Figure 6 in our paper. #
###############################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2020.06 #
###############################################################
# Usage: python exps/NATS-Bench/draw-fig6.py --search_space tss
# Usage: python exps/NATS-Bench/draw-fig6.py --search_space sss
###############################################################
import os, gc, sys, time, torch, argparse
import numpy as np
from typing import List, Text, Dict, Any
from shutil import copyfile
from collections import defaultdict, OrderedDict
from copy import deepcopy
from pathlib import Path
import matplotlib
import seaborn as sns
matplotlib.use("agg")
import matplotlib.pyplot as plt
import matplotlib.ticker as ticker
from xautodl.config_utils import dict2config, load_config
from xautodl.log_utils import time_string
from nats_bench import create
def fetch_data(root_dir="./output/search", search_space="tss", dataset=None):
ss_dir = "{:}-{:}".format(root_dir, search_space)
alg2name, alg2path = OrderedDict(), OrderedDict()
alg2name["REA"] = "R-EA-SS3"
alg2name["REINFORCE"] = "REINFORCE-0.01"
alg2name["RANDOM"] = "RANDOM"
alg2name["BOHB"] = "BOHB"
for alg, name in alg2name.items():
alg2path[alg] = os.path.join(ss_dir, dataset, name, "results.pth")
assert os.path.isfile(alg2path[alg]), "invalid path : {:}".format(alg2path[alg])
alg2data = OrderedDict()
for alg, path in alg2path.items():
data = torch.load(path)
for index, info in data.items():
info["time_w_arch"] = [
(x, y) for x, y in zip(info["all_total_times"], info["all_archs"])
]
for j, arch in enumerate(info["all_archs"]):
assert arch != -1, "invalid arch from {:} {:} {:} ({:}, {:})".format(
alg, search_space, dataset, index, j
)
alg2data[alg] = data
return alg2data
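# Each entry of info["time_w_arch"] pairs a cumulative search-time stamp with
# the architecture sampled at that point, so query_performance below can
# interpolate the achievable accuracy at an arbitrary time ticket.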
def query_performance(api, data, dataset, ticket):
results, is_size_space = [], api.search_space_name == "size"
for i, info in data.items():
time_w_arch = sorted(info["time_w_arch"], key=lambda x: abs(x[0] - ticket))
time_a, arch_a = time_w_arch[0]
time_b, arch_b = time_w_arch[1]
info_a = api.get_more_info(
arch_a, dataset=dataset, hp=90 if is_size_space else 200, is_random=False
)
info_b = api.get_more_info(
arch_b, dataset=dataset, hp=90 if is_size_space else 200, is_random=False
)
accuracy_a, accuracy_b = info_a["test-accuracy"], info_b["test-accuracy"]
        interpolate = (time_b - ticket) / (time_b - time_a) * accuracy_a + (
            ticket - time_a
        ) / (time_b - time_a) * accuracy_b
        results.append(interpolate)
return np.mean(results), np.std(results)
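# Worked example (hypothetical numbers): if the two records closest to the
# ticket are (time_a=100, acc_a=90.0) and (time_b=200, acc_b=92.0), then for
# ticket=150 the interpolated accuracy is
#   (200 - 150) / (200 - 100) * 90.0 + (150 - 100) / (200 - 100) * 92.0 = 91.0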
def show_valid_test(api, data, dataset):
valid_accs, test_accs, is_size_space = [], [], api.search_space_name == "size"
for i, info in data.items():
time, arch = info["time_w_arch"][-1]
if dataset == "cifar10":
xinfo = api.get_more_info(
arch, dataset=dataset, hp=90 if is_size_space else 200, is_random=False
)
test_accs.append(xinfo["test-accuracy"])
xinfo = api.get_more_info(
arch,
dataset="cifar10-valid",
hp=90 if is_size_space else 200,
is_random=False,
)
valid_accs.append(xinfo["valid-accuracy"])
else:
xinfo = api.get_more_info(
arch, dataset=dataset, hp=90 if is_size_space else 200, is_random=False
)
valid_accs.append(xinfo["valid-accuracy"])
test_accs.append(xinfo["test-accuracy"])
    valid_str = r"{:.2f}$\pm${:.2f}".format(np.mean(valid_accs), np.std(valid_accs))
    test_str = r"{:.2f}$\pm${:.2f}".format(np.mean(test_accs), np.std(test_accs))
return valid_str, test_str
y_min_s = {
("cifar10", "tss"): 90,
("cifar10", "sss"): 92,
("cifar100", "tss"): 65,
("cifar100", "sss"): 65,
("ImageNet16-120", "tss"): 36,
("ImageNet16-120", "sss"): 40,
}
y_max_s = {
("cifar10", "tss"): 94.3,
("cifar10", "sss"): 93.3,
("cifar100", "tss"): 72.5,
("cifar100", "sss"): 70.5,
("ImageNet16-120", "tss"): 46,
("ImageNet16-120", "sss"): 46,
}
x_axis_s = {
("cifar10", "tss"): 200,
("cifar10", "sss"): 200,
("cifar100", "tss"): 400,
("cifar100", "sss"): 400,
("ImageNet16-120", "tss"): 1200,
("ImageNet16-120", "sss"): 600,
}
name2label = {
"cifar10": "CIFAR-10",
"cifar100": "CIFAR-100",
"ImageNet16-120": "ImageNet-16-120",
}
def visualize_curve(api, vis_save_dir, search_space):
vis_save_dir = vis_save_dir.resolve()
vis_save_dir.mkdir(parents=True, exist_ok=True)
dpi, width, height = 250, 5200, 1400
figsize = width / float(dpi), height / float(dpi)
LabelSize, LegendFontsize = 16, 16
def sub_plot_fn(ax, dataset):
xdataset, max_time = dataset.split("-T")
alg2data = fetch_data(search_space=search_space, dataset=dataset)
alg2accuracies = OrderedDict()
total_tickets = 150
time_tickets = [
float(i) / total_tickets * int(max_time) for i in range(total_tickets)
]
colors = ["b", "g", "c", "m", "y"]
ax.set_xlim(0, x_axis_s[(xdataset, search_space)])
ax.set_ylim(
y_min_s[(xdataset, search_space)], y_max_s[(xdataset, search_space)]
)
for idx, (alg, data) in enumerate(alg2data.items()):
accuracies = []
for ticket in time_tickets:
accuracy, accuracy_std = query_performance(api, data, xdataset, ticket)
accuracies.append(accuracy)
valid_str, test_str = show_valid_test(api, data, xdataset)
print(
"{:} plot alg : {:10s} | validation = {:} | test = {:}".format(
time_string(), alg, valid_str, test_str
)
)
alg2accuracies[alg] = accuracies
ax.plot(
[x / 100 for x in time_tickets],
accuracies,
c=colors[idx],
label="{:}".format(alg),
)
ax.set_xlabel("Estimated wall-clock time (1e2 seconds)", fontsize=LabelSize)
ax.set_ylabel(
"Test accuracy on {:}".format(name2label[xdataset]), fontsize=LabelSize
)
ax.set_title(
"Searching results on {:}".format(name2label[xdataset]),
fontsize=LabelSize + 4,
)
ax.legend(loc=4, fontsize=LegendFontsize)
fig, axs = plt.subplots(1, 3, figsize=figsize)
# datasets = ['cifar10', 'cifar100', 'ImageNet16-120']
if search_space == "tss":
datasets = ["cifar10-T20000", "cifar100-T40000", "ImageNet16-120-T120000"]
elif search_space == "sss":
datasets = ["cifar10-T20000", "cifar100-T40000", "ImageNet16-120-T60000"]
else:
raise ValueError("Unknown search space: {:}".format(search_space))
for dataset, ax in zip(datasets, axs):
sub_plot_fn(ax, dataset)
print("sub-plot {:} on {:} done.".format(dataset, search_space))
save_path = (vis_save_dir / "{:}-curve.png".format(search_space)).resolve()
fig.savefig(save_path, dpi=dpi, bbox_inches="tight", format="png")
print("{:} save into {:}".format(time_string(), save_path))
plt.close("all")
if __name__ == "__main__":
parser = argparse.ArgumentParser(
description="NATS-Bench: Benchmarking NAS Algorithms for Architecture Topology and Size",
formatter_class=argparse.ArgumentDefaultsHelpFormatter,
)
parser.add_argument(
"--save_dir",
type=str,
default="output/vis-nas-bench/nas-algos",
help="Folder to save checkpoints and log.",
)
parser.add_argument(
"--search_space",
type=str,
choices=["tss", "sss"],
help="Choose the search space.",
)
args = parser.parse_args()
save_dir = Path(args.save_dir)
api = create(None, args.search_space, fast_mode=True, verbose=False)
visualize_curve(api, save_dir, args.search_space)

exps/NATS-Bench/draw-fig7.py

@@ -0,0 +1,250 @@
###############################################################
# NATS-Bench (arxiv.org/pdf/2009.00437.pdf), IEEE TPAMI 2021 #
# The code to draw Figure 7 in our paper. #
###############################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2020.06 #
###############################################################
# Usage: python exps/NATS-Bench/draw-fig7.py #
###############################################################
import os, gc, sys, time, torch, argparse
import numpy as np
from typing import List, Text, Dict, Any
from shutil import copyfile
from collections import defaultdict, OrderedDict
from copy import deepcopy
from pathlib import Path
import matplotlib
import seaborn as sns
matplotlib.use("agg")
import matplotlib.pyplot as plt
import matplotlib.ticker as ticker
from xautodl.config_utils import dict2config, load_config
from xautodl.log_utils import time_string
from nats_bench import create
def get_valid_test_acc(api, arch, dataset):
is_size_space = api.search_space_name == "size"
if dataset == "cifar10":
xinfo = api.get_more_info(
arch, dataset=dataset, hp=90 if is_size_space else 200, is_random=False
)
test_acc = xinfo["test-accuracy"]
xinfo = api.get_more_info(
arch,
dataset="cifar10-valid",
hp=90 if is_size_space else 200,
is_random=False,
)
valid_acc = xinfo["valid-accuracy"]
else:
xinfo = api.get_more_info(
arch, dataset=dataset, hp=90 if is_size_space else 200, is_random=False
)
valid_acc = xinfo["valid-accuracy"]
test_acc = xinfo["test-accuracy"]
return (
valid_acc,
test_acc,
"validation = {:.2f}, test = {:.2f}\n".format(valid_acc, test_acc),
)
def fetch_data(
root_dir="./output/search", search_space="tss", dataset=None, suffix="-WARM0.3"
):
ss_dir = "{:}-{:}".format(root_dir, search_space)
alg2name, alg2path = OrderedDict(), OrderedDict()
seeds = [777, 888, 999]
print("\n[fetch data] from {:} on {:}".format(search_space, dataset))
if search_space == "tss":
alg2name["GDAS"] = "gdas-affine0_BN0-None"
alg2name["RSPS"] = "random-affine0_BN0-None"
alg2name["DARTS (1st)"] = "darts-v1-affine0_BN0-None"
alg2name["DARTS (2nd)"] = "darts-v2-affine0_BN0-None"
alg2name["ENAS"] = "enas-affine0_BN0-None"
alg2name["SETN"] = "setn-affine0_BN0-None"
else:
alg2name["channel-wise interpolation"] = "tas-affine0_BN0-AWD0.001{:}".format(
suffix
)
alg2name[
"masking + Gumbel-Softmax"
] = "mask_gumbel-affine0_BN0-AWD0.001{:}".format(suffix)
alg2name["masking + sampling"] = "mask_rl-affine0_BN0-AWD0.0{:}".format(suffix)
for alg, name in alg2name.items():
alg2path[alg] = os.path.join(ss_dir, dataset, name, "seed-{:}-last-info.pth")
alg2data = OrderedDict()
for alg, path in alg2path.items():
alg2data[alg], ok_num = [], 0
for seed in seeds:
xpath = path.format(seed)
if os.path.isfile(xpath):
ok_num += 1
else:
print("This is an invalid path : {:}".format(xpath))
continue
data = torch.load(xpath, map_location=torch.device("cpu"))
try:
data = torch.load(
data["last_checkpoint"], map_location=torch.device("cpu")
)
            except Exception:
xpath = str(data["last_checkpoint"]).split("E100-")
if len(xpath) == 2 and os.path.isfile(xpath[0] + xpath[1]):
xpath = xpath[0] + xpath[1]
elif "fbv2" in str(data["last_checkpoint"]):
xpath = str(data["last_checkpoint"]).replace("fbv2", "mask_gumbel")
elif "tunas" in str(data["last_checkpoint"]):
xpath = str(data["last_checkpoint"]).replace("tunas", "mask_rl")
else:
raise ValueError(
"Invalid path: {:}".format(data["last_checkpoint"])
)
data = torch.load(xpath, map_location=torch.device("cpu"))
alg2data[alg].append(data["genotypes"])
print("This algorithm : {:} has {:} valid ckps.".format(alg, ok_num))
assert ok_num > 0, "Must have at least 1 valid ckps."
return alg2data
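# The try/except fallback in fetch_data above recovers from stale absolute
# paths stored inside old checkpoints: it first retries with the "E100-" infix
# stripped, then maps the renamed algorithm directories ("fbv2" ->
# "mask_gumbel", "tunas" -> "mask_rl") before giving up.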
y_min_s = {
("cifar10", "tss"): 90,
("cifar10", "sss"): 92,
("cifar100", "tss"): 65,
("cifar100", "sss"): 65,
("ImageNet16-120", "tss"): 36,
("ImageNet16-120", "sss"): 40,
}
y_max_s = {
("cifar10", "tss"): 94.5,
("cifar10", "sss"): 93.3,
("cifar100", "tss"): 72,
("cifar100", "sss"): 70,
("ImageNet16-120", "tss"): 44,
("ImageNet16-120", "sss"): 46,
}
name2label = {
"cifar10": "CIFAR-10",
"cifar100": "CIFAR-100",
"ImageNet16-120": "ImageNet-16-120",
}
name2suffix = {
("sss", "warm"): "-WARM0.3",
("sss", "none"): "-WARMNone",
("tss", "none"): None,
("tss", None): None,
}
def visualize_curve(api, vis_save_dir, search_space, suffix):
vis_save_dir = vis_save_dir.resolve()
vis_save_dir.mkdir(parents=True, exist_ok=True)
dpi, width, height = 250, 5200, 1400
figsize = width / float(dpi), height / float(dpi)
LabelSize, LegendFontsize = 16, 16
def sub_plot_fn(ax, dataset):
print("{:} plot {:10s}".format(time_string(), dataset))
alg2data = fetch_data(
search_space=search_space,
dataset=dataset,
suffix=name2suffix[(search_space, suffix)],
)
alg2accuracies = OrderedDict()
epochs = 100
colors = ["b", "g", "c", "m", "y", "r"]
ax.set_xlim(0, epochs)
# ax.set_ylim(y_min_s[(dataset, search_space)], y_max_s[(dataset, search_space)])
for idx, (alg, data) in enumerate(alg2data.items()):
xs, accuracies = [], []
            for iepoch in range(1, epochs + 1):
                try:
                    structures, accs = [_[iepoch - 1] for _ in data], []
                except IndexError:
                    raise ValueError(
                        "This alg {:} on {:} has invalid checkpoints.".format(
                            alg, dataset
                        )
                    )
for structure in structures:
info = api.get_more_info(
structure,
dataset=dataset,
hp=90 if api.search_space_name == "size" else 200,
is_random=False,
)
accs.append(info["test-accuracy"])
accuracies.append(sum(accs) / len(accs))
xs.append(iepoch)
alg2accuracies[alg] = accuracies
ax.plot(xs, accuracies, c=colors[idx], label="{:}".format(alg))
ax.set_xlabel("The searching epoch", fontsize=LabelSize)
ax.set_ylabel(
"Test accuracy on {:}".format(name2label[dataset]), fontsize=LabelSize
)
ax.set_title(
"Searching results on {:}".format(name2label[dataset]),
fontsize=LabelSize + 4,
)
structures, valid_accs, test_accs = [_[epochs - 1] for _ in data], [], []
print(
"{:} plot alg : {:} -- final {:} architectures.".format(
time_string(), alg, len(structures)
)
)
for arch in structures:
valid_acc, test_acc, _ = get_valid_test_acc(api, arch, dataset)
test_accs.append(test_acc)
valid_accs.append(valid_acc)
print(
"{:} plot alg : {:} -- validation: {:.2f}$\pm${:.2f} -- test: {:.2f}$\pm${:.2f}".format(
time_string(),
alg,
np.mean(valid_accs),
np.std(valid_accs),
np.mean(test_accs),
np.std(test_accs),
)
)
ax.legend(loc=4, fontsize=LegendFontsize)
fig, axs = plt.subplots(1, 3, figsize=figsize)
datasets = ["cifar10", "cifar100", "ImageNet16-120"]
for dataset, ax in zip(datasets, axs):
sub_plot_fn(ax, dataset)
print("sub-plot {:} on {:} done.".format(dataset, search_space))
save_path = (
vis_save_dir / "{:}-ws-{:}-curve.png".format(search_space, suffix)
).resolve()
fig.savefig(save_path, dpi=dpi, bbox_inches="tight", format="png")
print("{:} save into {:}".format(time_string(), save_path))
plt.close("all")
if __name__ == "__main__":
parser = argparse.ArgumentParser(
description="NATS-Bench", formatter_class=argparse.ArgumentDefaultsHelpFormatter
)
parser.add_argument(
"--save_dir",
type=str,
default="output/vis-nas-bench/nas-algos",
help="Folder to save checkpoints and log.",
)
args = parser.parse_args()
save_dir = Path(args.save_dir)
api_tss = create(None, "tss", fast_mode=True, verbose=False)
visualize_curve(api_tss, save_dir, "tss", None)
api_sss = create(None, "sss", fast_mode=True, verbose=False)
visualize_curve(api_sss, save_dir, "sss", "warm")
visualize_curve(api_sss, save_dir, "sss", "none")

exps/NATS-Bench/draw-fig8.py

@@ -0,0 +1,232 @@
###############################################################
# NATS-Bench (arxiv.org/pdf/2009.00437.pdf), IEEE TPAMI 2021 #
# The code to draw Figure 6 in our paper. #
###############################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2020.06 #
###############################################################
# Usage: python exps/NATS-Bench/draw-fig8.py #
###############################################################
import os, gc, sys, time, torch, argparse
import numpy as np
from typing import List, Text, Dict, Any
from shutil import copyfile
from collections import defaultdict, OrderedDict
from copy import deepcopy
from pathlib import Path
import matplotlib
import seaborn as sns
matplotlib.use("agg")
import matplotlib.pyplot as plt
import matplotlib.ticker as ticker
from xautodl.config_utils import dict2config, load_config
from xautodl.log_utils import time_string
from nats_bench import create
plt.rcParams.update(
    {
        "text.usetex": True,
        "font.family": "serif",
        "font.serif": ["Palatino"],
    }
)
def fetch_data(root_dir="./output/search", search_space="tss", dataset=None):
ss_dir = "{:}-{:}".format(root_dir, search_space)
alg2all = OrderedDict()
# alg2name['REINFORCE'] = 'REINFORCE-0.01'
# alg2name['RANDOM'] = 'RANDOM'
# alg2name['BOHB'] = 'BOHB'
    if search_space == "tss":
        hp = r"$\mathcal{H}^{1}$"
        if dataset == "cifar10":
            suffixes = ["-T1200000", "-T1200000-FULL"]
    elif search_space == "sss":
        hp = r"$\mathcal{H}^{2}$"
        if dataset == "cifar10":
            suffixes = ["-T200000", "-T200000-FULL"]
else:
raise ValueError("Unkonwn search space: {:}".format(search_space))
alg2all[r"REA ($\mathcal{H}^{0}$)"] = dict(
path=os.path.join(ss_dir, dataset + suffixes[0], "R-EA-SS3", "results.pth"),
color="b",
linestyle="-",
)
alg2all[r"REA ({:})".format(hp)] = dict(
path=os.path.join(ss_dir, dataset + suffixes[1], "R-EA-SS3", "results.pth"),
color="b",
linestyle="--",
)
for alg, xdata in alg2all.items():
data = torch.load(xdata["path"])
for index, info in data.items():
info["time_w_arch"] = [
(x, y) for x, y in zip(info["all_total_times"], info["all_archs"])
]
for j, arch in enumerate(info["all_archs"]):
assert arch != -1, "invalid arch from {:} {:} {:} ({:}, {:})".format(
alg, search_space, dataset, index, j
)
xdata["data"] = data
return alg2all
def query_performance(api, data, dataset, ticket):
results, is_size_space = [], api.search_space_name == "size"
for i, info in data.items():
time_w_arch = sorted(info["time_w_arch"], key=lambda x: abs(x[0] - ticket))
time_a, arch_a = time_w_arch[0]
time_b, arch_b = time_w_arch[1]
info_a = api.get_more_info(
arch_a, dataset=dataset, hp=90 if is_size_space else 200, is_random=False
)
info_b = api.get_more_info(
arch_b, dataset=dataset, hp=90 if is_size_space else 200, is_random=False
)
accuracy_a, accuracy_b = info_a["test-accuracy"], info_b["test-accuracy"]
        interpolate = (time_b - ticket) / (time_b - time_a) * accuracy_a + (
            ticket - time_a
        ) / (time_b - time_a) * accuracy_b
        results.append(interpolate)
return np.mean(results), np.std(results)
y_min_s = {
("cifar10", "tss"): 91,
("cifar10", "sss"): 91,
("cifar100", "tss"): 65,
("cifar100", "sss"): 65,
("ImageNet16-120", "tss"): 36,
("ImageNet16-120", "sss"): 40,
}
y_max_s = {
("cifar10", "tss"): 94.5,
("cifar10", "sss"): 93.5,
("cifar100", "tss"): 72.5,
("cifar100", "sss"): 70.5,
("ImageNet16-120", "tss"): 46,
("ImageNet16-120", "sss"): 46,
}
x_axis_s = {
("cifar10", "tss"): 1200000,
("cifar10", "sss"): 200000,
("cifar100", "tss"): 400,
("cifar100", "sss"): 400,
("ImageNet16-120", "tss"): 1200,
("ImageNet16-120", "sss"): 600,
}
name2label = {
"cifar10": "CIFAR-10",
"cifar100": "CIFAR-100",
"ImageNet16-120": "ImageNet-16-120",
}
spaces2latex = {
"tss": r"$\mathcal{S}_{t}$",
"sss": r"$\mathcal{S}_{s}$",
}
# FuncFormatter can be used as a decorator
@ticker.FuncFormatter
def major_formatter(x, pos):
if x == 0:
return "0"
else:
return "{:.2f}e5".format(x / 1e5)
def visualize_curve(api_dict, vis_save_dir):
vis_save_dir = vis_save_dir.resolve()
vis_save_dir.mkdir(parents=True, exist_ok=True)
dpi, width, height = 250, 5000, 2000
figsize = width / float(dpi), height / float(dpi)
LabelSize, LegendFontsize = 28, 28
def sub_plot_fn(ax, search_space, dataset):
max_time = x_axis_s[(dataset, search_space)]
alg2data = fetch_data(search_space=search_space, dataset=dataset)
alg2accuracies = OrderedDict()
total_tickets = 200
time_tickets = [
float(i) / total_tickets * int(max_time) for i in range(total_tickets)
]
ax.set_xlim(0, x_axis_s[(dataset, search_space)])
ax.set_ylim(y_min_s[(dataset, search_space)], y_max_s[(dataset, search_space)])
for tick in ax.get_xticklabels():
tick.set_rotation(25)
tick.set_fontsize(LabelSize - 6)
for tick in ax.get_yticklabels():
tick.set_fontsize(LabelSize - 6)
ax.xaxis.set_major_formatter(major_formatter)
for idx, (alg, xdata) in enumerate(alg2data.items()):
accuracies = []
for ticket in time_tickets:
accuracy, accuracy_std = query_performance(
api_dict[search_space], xdata["data"], dataset, ticket
)
accuracies.append(accuracy)
print(
"{:} plot alg : {:10s} on {:}".format(time_string(), alg, search_space)
)
alg2accuracies[alg] = accuracies
ax.plot(
time_tickets,
accuracies,
c=xdata["color"],
linestyle=xdata["linestyle"],
label="{:}".format(alg),
)
ax.set_xlabel("Estimated wall-clock time", fontsize=LabelSize)
ax.set_ylabel("Test accuracy", fontsize=LabelSize)
ax.set_title(
r"Results on {:} over {:}".format(
name2label[dataset], spaces2latex[search_space]
),
fontsize=LabelSize,
)
ax.legend(loc=4, fontsize=LegendFontsize)
fig, axs = plt.subplots(1, 2, figsize=figsize)
sub_plot_fn(axs[0], "tss", "cifar10")
sub_plot_fn(axs[1], "sss", "cifar10")
save_path = (vis_save_dir / "full-curve.png").resolve()
fig.savefig(save_path, dpi=dpi, bbox_inches="tight", format="png")
print("{:} save into {:}".format(time_string(), save_path))
plt.close("all")
if __name__ == "__main__":
parser = argparse.ArgumentParser(
description="NATS-Bench: Benchmarking NAS Algorithms for Architecture Topology and Size",
formatter_class=argparse.ArgumentDefaultsHelpFormatter,
)
parser.add_argument(
"--save_dir",
type=str,
default="output/vis-nas-bench/nas-algos-vs-h",
help="Folder to save checkpoints and log.",
)
args = parser.parse_args()
save_dir = Path(args.save_dir)
api_tss = create(None, "tss", fast_mode=True, verbose=False)
api_sss = create(None, "sss", fast_mode=True, verbose=False)
visualize_curve(dict(tss=api_tss, sss=api_sss), save_dir)

exps/NATS-Bench/draw-ranks.py

@@ -0,0 +1,185 @@
###############################################################
# NATS-Bench (arxiv.org/pdf/2009.00437.pdf), IEEE TPAMI 2021 #
# The code to draw Figure 2 / 3 / 4 / 5 in our paper. #
###############################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2020.06 #
###############################################################
# Usage: python exps/NATS-Bench/draw-ranks.py #
###############################################################
import os, sys, time, torch, argparse
import scipy
import numpy as np
from typing import List, Text, Dict, Any
from shutil import copyfile
from collections import defaultdict, OrderedDict
from copy import deepcopy
from pathlib import Path
import matplotlib
import seaborn as sns
matplotlib.use("agg")
import matplotlib.pyplot as plt
import matplotlib.ticker as ticker
from xautodl.config_utils import dict2config, load_config
from xautodl.log_utils import time_string
from xautodl.models import get_cell_based_tiny_net
from nats_bench import create
name2label = {
"cifar10": "CIFAR-10",
"cifar100": "CIFAR-100",
"ImageNet16-120": "ImageNet-16-120",
}
def visualize_relative_info(vis_save_dir, search_space, indicator, topk):
vis_save_dir = vis_save_dir.resolve()
print(
"{:} start to visualize {:} with top-{:} information".format(
time_string(), search_space, topk
)
)
vis_save_dir.mkdir(parents=True, exist_ok=True)
cache_file_path = vis_save_dir / "cache-{:}-info.pth".format(search_space)
datasets = ["cifar10", "cifar100", "ImageNet16-120"]
if not cache_file_path.exists():
api = create(None, search_space, fast_mode=False, verbose=False)
all_infos = OrderedDict()
for index in range(len(api)):
all_info = OrderedDict()
for dataset in datasets:
info_less = api.get_more_info(index, dataset, hp="12", is_random=False)
info_more = api.get_more_info(
index, dataset, hp=api.full_train_epochs, is_random=False
)
all_info[dataset] = dict(
less=info_less["test-accuracy"], more=info_more["test-accuracy"]
)
all_infos[index] = all_info
torch.save(all_infos, cache_file_path)
print("{:} save all cache data into {:}".format(time_string(), cache_file_path))
else:
api = create(None, search_space, fast_mode=True, verbose=False)
all_infos = torch.load(cache_file_path)
dpi, width, height = 250, 5000, 1300
figsize = width / float(dpi), height / float(dpi)
LabelSize, LegendFontsize = 16, 16
fig, axs = plt.subplots(1, 3, figsize=figsize)
datasets = ["cifar10", "cifar100", "ImageNet16-120"]
def sub_plot_fn(ax, dataset, indicator):
performances = []
# pickup top 10% architectures
for _index in range(len(api)):
performances.append((all_infos[_index][dataset][indicator], _index))
performances = sorted(performances, reverse=True)
performances = performances[: int(len(api) * topk * 0.01)]
selected_indexes = [x[1] for x in performances]
print(
"{:} plot {:10s} with {:}, {:} architectures".format(
time_string(), dataset, indicator, len(selected_indexes)
)
)
standard_scores = []
random_scores = []
for idx in selected_indexes:
standard_scores.append(
api.get_more_info(
idx,
dataset,
hp=api.full_train_epochs if indicator == "more" else "12",
is_random=False,
)["test-accuracy"]
)
random_scores.append(
api.get_more_info(
idx,
dataset,
hp=api.full_train_epochs if indicator == "more" else "12",
is_random=True,
)["test-accuracy"]
)
indexes = list(range(len(selected_indexes)))
standard_indexes = sorted(indexes, key=lambda i: standard_scores[i])
random_indexes = sorted(indexes, key=lambda i: random_scores[i])
random_labels = []
for idx in standard_indexes:
random_labels.append(random_indexes.index(idx))
for tick in ax.get_xticklabels():
tick.set_fontsize(LabelSize - 3)
for tick in ax.get_yticklabels():
tick.set_rotation(25)
tick.set_fontsize(LabelSize - 3)
ax.set_xlim(0, len(indexes))
ax.set_ylim(0, len(indexes))
ax.set_yticks(np.arange(min(indexes), max(indexes), max(indexes) // 3))
ax.set_xticks(np.arange(min(indexes), max(indexes), max(indexes) // 5))
ax.scatter(indexes, random_labels, marker="^", s=0.5, c="tab:green", alpha=0.8)
ax.scatter(indexes, indexes, marker="o", s=0.5, c="tab:blue", alpha=0.8)
ax.scatter(
[-1],
[-1],
marker="o",
s=100,
c="tab:blue",
label="Average Over Multi-Trials",
)
ax.scatter(
[-1],
[-1],
marker="^",
s=100,
c="tab:green",
label="Randomly Selected Trial",
)
coef, p = scipy.stats.kendalltau(standard_scores, random_scores)
ax.set_xlabel(
"architecture ranking in {:}".format(name2label[dataset]),
fontsize=LabelSize,
)
if dataset == "cifar10":
ax.set_ylabel("architecture ranking", fontsize=LabelSize)
ax.legend(loc=4, fontsize=LegendFontsize)
return coef
for dataset, ax in zip(datasets, axs):
rank_coef = sub_plot_fn(ax, dataset, indicator)
print(
"sub-plot {:} on {:} done, the ranking coefficient is {:.4f}.".format(
dataset, search_space, rank_coef
)
)
save_path = (
vis_save_dir / "{:}-rank-{:}-top{:}.pdf".format(search_space, indicator, topk)
).resolve()
fig.savefig(save_path, dpi=dpi, bbox_inches="tight", format="pdf")
save_path = (
vis_save_dir / "{:}-rank-{:}-top{:}.png".format(search_space, indicator, topk)
).resolve()
fig.savefig(save_path, dpi=dpi, bbox_inches="tight", format="png")
print("Save into {:}".format(save_path))
if __name__ == "__main__":
parser = argparse.ArgumentParser(
description="NATS-Bench", formatter_class=argparse.ArgumentDefaultsHelpFormatter
)
parser.add_argument(
"--save_dir",
type=str,
default="output/vis-nas-bench/rank-stability",
help="Folder to save checkpoints and log.",
)
args = parser.parse_args()
to_save_dir = Path(args.save_dir)
for topk in [1, 5, 10, 20]:
visualize_relative_info(to_save_dir, "tss", "more", topk)
visualize_relative_info(to_save_dir, "sss", "less", topk)
print("{:} : complete running this file : {:}".format(time_string(), __file__))

exps/NATS-Bench/draw-table.py

@@ -0,0 +1,191 @@
###############################################################
# NATS-Bench (arxiv.org/pdf/2009.00437.pdf), IEEE TPAMI 2021 #
# The code to draw some results in Table 4 in our paper. #
###############################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2020.06 #
###############################################################
# Usage: python exps/NATS-Bench/draw-table.py #
###############################################################
import os, gc, sys, time, torch, argparse
import numpy as np
from typing import List, Text, Dict, Any
from shutil import copyfile
from collections import defaultdict, OrderedDict
from copy import deepcopy
from pathlib import Path
import matplotlib
import seaborn as sns
matplotlib.use("agg")
import matplotlib.pyplot as plt
import matplotlib.ticker as ticker
from xautodl.config_utils import dict2config, load_config
from xautodl.log_utils import time_string
from nats_bench import create
def fetch_data(root_dir="./output/search", search_space="tss", dataset=None):
ss_dir = "{:}-{:}".format(root_dir, search_space)
alg2name, alg2path = OrderedDict(), OrderedDict()
alg2name["REA"] = "R-EA-SS3"
alg2name["REINFORCE"] = "REINFORCE-0.01"
alg2name["RANDOM"] = "RANDOM"
alg2name["BOHB"] = "BOHB"
for alg, name in alg2name.items():
alg2path[alg] = os.path.join(ss_dir, dataset, name, "results.pth")
assert os.path.isfile(alg2path[alg]), "invalid path : {:}".format(alg2path[alg])
alg2data = OrderedDict()
for alg, path in alg2path.items():
data = torch.load(path)
for index, info in data.items():
info["time_w_arch"] = [
(x, y) for x, y in zip(info["all_total_times"], info["all_archs"])
]
for j, arch in enumerate(info["all_archs"]):
assert arch != -1, "invalid arch from {:} {:} {:} ({:}, {:})".format(
alg, search_space, dataset, index, j
)
alg2data[alg] = data
return alg2data
def get_valid_test_acc(api, arch, dataset):
is_size_space = api.search_space_name == "size"
if dataset == "cifar10":
xinfo = api.get_more_info(
arch, dataset=dataset, hp=90 if is_size_space else 200, is_random=False
)
test_acc = xinfo["test-accuracy"]
xinfo = api.get_more_info(
arch,
dataset="cifar10-valid",
hp=90 if is_size_space else 200,
is_random=False,
)
valid_acc = xinfo["valid-accuracy"]
else:
xinfo = api.get_more_info(
arch, dataset=dataset, hp=90 if is_size_space else 200, is_random=False
)
valid_acc = xinfo["valid-accuracy"]
test_acc = xinfo["test-accuracy"]
return (
valid_acc,
test_acc,
"validation = {:.2f}, test = {:.2f}\n".format(valid_acc, test_acc),
)
def show_valid_test(api, arch):
is_size_space = api.search_space_name == "size"
final_str = ""
for dataset in ["cifar10", "cifar100", "ImageNet16-120"]:
valid_acc, test_acc, perf_str = get_valid_test_acc(api, arch, dataset)
final_str += "{:} : {:}\n".format(dataset, perf_str)
return final_str
def find_best_valid(api, dataset):
all_valid_accs, all_test_accs = [], []
for index, arch in enumerate(api):
valid_acc, test_acc, perf_str = get_valid_test_acc(api, index, dataset)
all_valid_accs.append((index, valid_acc))
all_test_accs.append((index, test_acc))
best_valid_index = sorted(all_valid_accs, key=lambda x: -x[1])[0][0]
best_test_index = sorted(all_test_accs, key=lambda x: -x[1])[0][0]
print("-" * 50 + "{:10s}".format(dataset) + "-" * 50)
print(
"Best ({:}) architecture on validation: {:}".format(
best_valid_index, api[best_valid_index]
)
)
print(
"Best ({:}) architecture on test: {:}".format(
best_test_index, api[best_test_index]
)
)
_, _, perf_str = get_valid_test_acc(api, best_valid_index, dataset)
print("using validation ::: {:}".format(perf_str))
_, _, perf_str = get_valid_test_acc(api, best_test_index, dataset)
print("using test ::: {:}".format(perf_str))
def interpolate_fn(xpair1, xpair2, x):
    # linearly interpolate the y-value at x between the points (x1, y1) and (x2, y2)
    (x1, y1) = xpair1
    (x2, y2) = xpair2
    return (x2 - x) / (x2 - x1) * y1 + (x - x1) / (x2 - x1) * y2
def query_performance(api, info, dataset, ticket):
    info = deepcopy(info)
    results, is_size_space = [], api.search_space_name == "size"
    time_w_arch = sorted(info["time_w_arch"], key=lambda x: abs(x[0] - ticket))
    time_a, arch_a = time_w_arch[0]
    time_b, arch_b = time_w_arch[1]
    v_acc_a, t_acc_a, _ = get_valid_test_acc(api, arch_a, dataset)
    v_acc_b, t_acc_b, _ = get_valid_test_acc(api, arch_b, dataset)
    v_acc = interpolate_fn((time_a, v_acc_a), (time_b, v_acc_b), ticket)
    t_acc = interpolate_fn((time_a, t_acc_a), (time_b, t_acc_b), ticket)
    # if True:
    #     interpolate = (time_b-ticket) / (time_b-time_a) * accuracy_a + (ticket-time_a) / (time_b-time_a) * accuracy_b
    #     results.append(interpolate)
    # return sum(results) / len(results)
    return v_acc, t_acc
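# A worked example of the interpolation above (hypothetical numbers): if the
# two recorded (time, arch) pairs closest to ticket=100 are (90, arch_a) with
# accuracy 80.0 and (110, arch_b) with accuracy 82.0, query_performance
# returns the value at t=100: (110-100)/20*80.0 + (100-90)/20*82.0 = 81.0.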
def show_multi_trial(search_space):
api = create(None, search_space, fast_mode=True, verbose=False)
def show(dataset):
print("show {:} on {:} done.".format(dataset, search_space))
xdataset, max_time = dataset.split("-T")
alg2data = fetch_data(search_space=search_space, dataset=dataset)
for idx, (alg, data) in enumerate(alg2data.items()):
valid_accs, test_accs = [], []
for _, x in data.items():
v_acc, t_acc = query_performance(api, x, xdataset, float(max_time))
valid_accs.append(v_acc)
test_accs.append(t_acc)
valid_str = "{:.2f}$\pm${:.2f}".format(
np.mean(valid_accs), np.std(valid_accs)
)
test_str = "{:.2f}$\pm${:.2f}".format(np.mean(test_accs), np.std(test_accs))
print(
"{:} plot alg : {:10s} | validation = {:} | test = {:}".format(
time_string(), alg, valid_str, test_str
)
)
if search_space == "tss":
datasets = ["cifar10-T20000", "cifar100-T40000", "ImageNet16-120-T120000"]
elif search_space == "sss":
datasets = ["cifar10-T20000", "cifar100-T40000", "ImageNet16-120-T60000"]
else:
raise ValueError("Unknown search space: {:}".format(search_space))
for dataset in datasets:
show(dataset)
print("{:} complete show multi-trial results.\n".format(time_string()))
if __name__ == "__main__":
show_multi_trial("tss")
show_multi_trial("sss")
api_tss = create(None, "tss", fast_mode=False, verbose=False)
resnet = "|nor_conv_3x3~0|+|none~0|nor_conv_3x3~1|+|skip_connect~0|none~1|skip_connect~2|"
resnet_index = api_tss.query_index_by_arch(resnet)
print(show_valid_test(api_tss, resnet_index))
for dataset in ["cifar10", "cifar100", "ImageNet16-120"]:
find_best_valid(api_tss, dataset)
largest = "64:64:64:64:64"
largest_index = api_sss.query_index_by_arch(largest)
print(show_valid_test(api_sss, largest_index))
for dataset in ["cifar10", "cifar100", "ImageNet16-120"]:
find_best_valid(api_sss, dataset)

View File

@@ -0,0 +1,486 @@
##############################################################################
# NATS-Bench: Benchmarking NAS Algorithms for Architecture Topology and Size #
##############################################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2020.07 #
##############################################################################
# This file is used to train (all) architecture candidates in the size search #
# space in NATS-Bench (sss) with different hyper-parameters.                  #
# When using mode=new, it automatically detects whether the checkpoint of a  #
# trial exists; if so, it skips this trial. When using mode=cover, it ignores #
# the (possibly) existing checkpoint, runs each trial, and saves the result. #
# (NOTE): the topology for all candidates in sss is fixed as: ######################
# |nor_conv_3x3~0|+|nor_conv_3x3~0|nor_conv_3x3~1|+|skip_connect~0|nor_conv_3x3~1|nor_conv_3x3~2| #
###################################################################################################
# Please use the script of scripts/NATS-Bench/train-shapes.sh to run. #
##############################################################################
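# Each candidate in the size search space is encoded as a channel string with
# one entry per layer, e.g., "8:16:32:48:64" when num_layers=5; the candidates
# are enumerated by traverse_net() below.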
import os, sys, time, torch, argparse
from typing import List, Text, Dict, Any
from PIL import ImageFile
ImageFile.LOAD_TRUNCATED_IMAGES = True
from copy import deepcopy
from pathlib import Path
from xautodl.config_utils import dict2config, load_config
from xautodl.procedures import bench_evaluate_for_seed
from xautodl.procedures import get_machine_info
from xautodl.datasets import get_datasets
from xautodl.log_utils import Logger, AverageMeter, time_string, convert_secs2time
from xautodl.utils import split_str2indexes
def evaluate_all_datasets(
channels: Text,
datasets: List[Text],
xpaths: List[Text],
splits: List[Text],
config_path: Text,
seed: int,
workers: int,
logger,
):
machine_info = get_machine_info()
all_infos = {"info": machine_info}
all_dataset_keys = []
    # iterate over all the datasets
for dataset, xpath, split in zip(datasets, xpaths, splits):
# the train and valid data
train_data, valid_data, xshape, class_num = get_datasets(dataset, xpath, -1)
# load the configuration
if dataset == "cifar10" or dataset == "cifar100":
split_info = load_config(
"configs/nas-benchmark/cifar-split.txt", None, None
)
elif dataset.startswith("ImageNet16"):
split_info = load_config(
"configs/nas-benchmark/{:}-split.txt".format(dataset), None, None
)
else:
raise ValueError("invalid dataset : {:}".format(dataset))
config = load_config(
config_path, dict(class_num=class_num, xshape=xshape), logger
)
        # check whether to use the split validation set
if bool(split):
assert dataset == "cifar10"
ValLoaders = {
"ori-test": torch.utils.data.DataLoader(
valid_data,
batch_size=config.batch_size,
shuffle=False,
num_workers=workers,
pin_memory=True,
)
}
assert len(train_data) == len(split_info.train) + len(
split_info.valid
), "invalid length : {:} vs {:} + {:}".format(
len(train_data), len(split_info.train), len(split_info.valid)
)
train_data_v2 = deepcopy(train_data)
train_data_v2.transform = valid_data.transform
valid_data = train_data_v2
# data loader
train_loader = torch.utils.data.DataLoader(
train_data,
batch_size=config.batch_size,
sampler=torch.utils.data.sampler.SubsetRandomSampler(split_info.train),
num_workers=workers,
pin_memory=True,
)
valid_loader = torch.utils.data.DataLoader(
valid_data,
batch_size=config.batch_size,
sampler=torch.utils.data.sampler.SubsetRandomSampler(split_info.valid),
num_workers=workers,
pin_memory=True,
)
ValLoaders["x-valid"] = valid_loader
else:
# data loader
train_loader = torch.utils.data.DataLoader(
train_data,
batch_size=config.batch_size,
shuffle=True,
num_workers=workers,
pin_memory=True,
)
valid_loader = torch.utils.data.DataLoader(
valid_data,
batch_size=config.batch_size,
shuffle=False,
num_workers=workers,
pin_memory=True,
)
if dataset == "cifar10":
ValLoaders = {"ori-test": valid_loader}
elif dataset == "cifar100":
cifar100_splits = load_config(
"configs/nas-benchmark/cifar100-test-split.txt", None, None
)
ValLoaders = {
"ori-test": valid_loader,
"x-valid": torch.utils.data.DataLoader(
valid_data,
batch_size=config.batch_size,
sampler=torch.utils.data.sampler.SubsetRandomSampler(
cifar100_splits.xvalid
),
num_workers=workers,
pin_memory=True,
),
"x-test": torch.utils.data.DataLoader(
valid_data,
batch_size=config.batch_size,
sampler=torch.utils.data.sampler.SubsetRandomSampler(
cifar100_splits.xtest
),
num_workers=workers,
pin_memory=True,
),
}
elif dataset == "ImageNet16-120":
imagenet16_splits = load_config(
"configs/nas-benchmark/imagenet-16-120-test-split.txt", None, None
)
ValLoaders = {
"ori-test": valid_loader,
"x-valid": torch.utils.data.DataLoader(
valid_data,
batch_size=config.batch_size,
sampler=torch.utils.data.sampler.SubsetRandomSampler(
imagenet16_splits.xvalid
),
num_workers=workers,
pin_memory=True,
),
"x-test": torch.utils.data.DataLoader(
valid_data,
batch_size=config.batch_size,
sampler=torch.utils.data.sampler.SubsetRandomSampler(
imagenet16_splits.xtest
),
num_workers=workers,
pin_memory=True,
),
}
else:
raise ValueError("invalid dataset : {:}".format(dataset))
dataset_key = "{:}".format(dataset)
if bool(split):
dataset_key = dataset_key + "-valid"
logger.log(
"Evaluate ||||||| {:10s} ||||||| Train-Num={:}, Valid-Num={:}, Train-Loader-Num={:}, Valid-Loader-Num={:}, batch size={:}".format(
dataset_key,
len(train_data),
len(valid_data),
len(train_loader),
len(valid_loader),
config.batch_size,
)
)
logger.log(
"Evaluate ||||||| {:10s} ||||||| Config={:}".format(dataset_key, config)
)
for key, value in ValLoaders.items():
logger.log(
"Evaluate ---->>>> {:10s} with {:} batchs".format(key, len(value))
)
# arch-index= 9930, arch=|nor_conv_3x3~0|+|nor_conv_3x3~0|nor_conv_3x3~1|+|skip_connect~0|nor_conv_3x3~1|nor_conv_3x3~2|
# this genotype is the architecture with the highest accuracy on CIFAR-100 validation set
genotype = "|nor_conv_3x3~0|+|nor_conv_3x3~0|nor_conv_3x3~1|+|skip_connect~0|nor_conv_3x3~1|nor_conv_3x3~2|"
arch_config = dict2config(
dict(
name="infer.shape.tiny",
channels=channels,
genotype=genotype,
num_classes=class_num,
),
None,
)
results = bench_evaluate_for_seed(
arch_config, config, train_loader, ValLoaders, seed, logger
)
all_infos[dataset_key] = results
all_dataset_keys.append(dataset_key)
all_infos["all_dataset_keys"] = all_dataset_keys
return all_infos
def main(
save_dir: Path,
workers: int,
datasets: List[Text],
xpaths: List[Text],
splits: List[int],
seeds: List[int],
nets: List[str],
opt_config: Dict[Text, Any],
to_evaluate_indexes: tuple,
cover_mode: bool,
):
log_dir = save_dir / "logs"
log_dir.mkdir(parents=True, exist_ok=True)
logger = Logger(str(log_dir), os.getpid(), False)
logger.log("xargs : seeds = {:}".format(seeds))
logger.log("xargs : cover_mode = {:}".format(cover_mode))
logger.log("-" * 100)
    logger.log(
        "Start evaluating range : {:06d} - {:06d} ({:} in total) / {:06d} with cover-mode={:}".format(
            min(to_evaluate_indexes),
            max(to_evaluate_indexes),
            len(to_evaluate_indexes),
            len(nets),
            cover_mode,
        )
    )
for i, (dataset, xpath, split) in enumerate(zip(datasets, xpaths, splits)):
logger.log(
"--->>> Evaluate {:}/{:} : dataset={:9s}, path={:}, split={:}".format(
i, len(datasets), dataset, xpath, split
)
)
logger.log("--->>> optimization config : {:}".format(opt_config))
start_time, epoch_time = time.time(), AverageMeter()
for i, index in enumerate(to_evaluate_indexes):
channelstr = nets[index]
logger.log(
"\n{:} evaluate {:06d}/{:06d} ({:06d}/{:06d})-th arch [seeds={:}] {:}".format(
time_string(),
i,
len(to_evaluate_indexes),
index,
len(nets),
seeds,
"-" * 15,
)
)
logger.log("{:} {:} {:}".format("-" * 15, channelstr, "-" * 15))
# test this arch on different datasets with different seeds
has_continue = False
for seed in seeds:
to_save_name = save_dir / "arch-{:06d}-seed-{:04d}.pth".format(index, seed)
if to_save_name.exists():
if cover_mode:
logger.log(
"Find existing file : {:}, remove it before evaluation".format(
to_save_name
)
)
os.remove(str(to_save_name))
else:
logger.log(
"Find existing file : {:}, skip this evaluation".format(
to_save_name
)
)
has_continue = True
continue
results = evaluate_all_datasets(
channelstr, datasets, xpaths, splits, opt_config, seed, workers, logger
)
torch.save(results, to_save_name)
logger.log(
"\n{:} evaluate {:06d}/{:06d} ({:06d}/{:06d})-th arch [seeds={:}] ===>>> {:}".format(
time_string(),
i,
len(to_evaluate_indexes),
index,
len(nets),
seeds,
to_save_name,
)
)
# measure elapsed time
if not has_continue:
epoch_time.update(time.time() - start_time)
start_time = time.time()
need_time = "Time Left: {:}".format(
convert_secs2time(epoch_time.avg * (len(to_evaluate_indexes) - i - 1), True)
)
logger.log(
"This arch costs : {:}".format(convert_secs2time(epoch_time.val, True))
)
logger.log("{:}".format("*" * 100))
logger.log(
"{:} {:74s} {:}".format(
"*" * 10,
"{:06d}/{:06d} ({:06d}/{:06d})-th done, left {:}".format(
i, len(to_evaluate_indexes), index, len(nets), need_time
),
"*" * 10,
)
)
logger.log("{:}".format("*" * 100))
logger.close()
def traverse_net(candidates: List[int], N: int):
nets = [""]
for i in range(N):
new_nets = []
for net in nets:
for C in candidates:
new_nets.append(str(C) if net == "" else "{:}:{:}".format(net, C))
nets = new_nets
return nets
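# A minimal sketch of what traverse_net enumerates (hypothetical arguments):
#   traverse_net([8, 16], 2) -> ["8:8", "8:16", "16:8", "16:16"]
# i.e., len(candidates) ** N channel strings, which is why --check_N defaults
# to 8 ** 5 = 32768 below.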
def filter_indexes(xlist, mode, save_dir, seeds):
all_indexes = []
for index in xlist:
if mode == "cover":
all_indexes.append(index)
else:
for seed in seeds:
temp_path = save_dir / "arch-{:06d}-seed-{:04d}.pth".format(index, seed)
if not temp_path.exists():
all_indexes.append(index)
break
print(
"{:} [FILTER-INDEXES] : there are {:}/{:} architectures in total".format(
time_string(), len(all_indexes), len(xlist)
)
)
SLURM_PROCID, SLURM_NTASKS = "SLURM_PROCID", "SLURM_NTASKS"
if SLURM_PROCID in os.environ and SLURM_NTASKS in os.environ: # run on the slurm
proc_id, ntasks = int(os.environ[SLURM_PROCID]), int(os.environ[SLURM_NTASKS])
assert 0 <= proc_id < ntasks, "invalid proc_id {:} vs ntasks {:}".format(
proc_id, ntasks
)
scales = [int(float(i) / ntasks * len(all_indexes)) for i in range(ntasks)] + [
len(all_indexes)
]
per_job = []
for i in range(ntasks):
xs, xe = min(max(scales[i], 0), len(all_indexes) - 1), min(
max(scales[i + 1] - 1, 0), len(all_indexes) - 1
)
per_job.append((xs, xe))
for i, srange in enumerate(per_job):
print(" -->> {:2d}/{:02d} : {:}".format(i, ntasks, srange))
current_range = per_job[proc_id]
all_indexes = [
all_indexes[i] for i in range(current_range[0], current_range[1] + 1)
]
# set the device id
device = proc_id % torch.cuda.device_count()
torch.cuda.set_device(device)
print(" set the device id = {:}".format(device))
print(
"{:} [FILTER-INDEXES] : after filtering there are {:} architectures in total".format(
time_string(), len(all_indexes)
)
)
return all_indexes
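# A worked example of the Slurm sharding above (hypothetical numbers): with
# ntasks=4 and 100 remaining indexes, scales = [0, 25, 50, 75, 100] and
# per_job = [(0, 24), (25, 49), (50, 74), (75, 99)], so each task evaluates a
# disjoint, contiguous slice of the filtered indexes.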
if __name__ == "__main__":
parser = argparse.ArgumentParser(
description="NATS-Bench (size search space)",
formatter_class=argparse.ArgumentDefaultsHelpFormatter,
)
parser.add_argument(
"--mode",
type=str,
required=True,
choices=["new", "cover"],
help="The script mode.",
)
parser.add_argument(
"--save_dir",
type=str,
default="output/NATS-Bench-size",
help="Folder to save checkpoints and log.",
)
parser.add_argument(
"--candidateC",
type=int,
nargs="+",
default=[8, 16, 24, 32, 40, 48, 56, 64],
help=".",
)
parser.add_argument(
"--num_layers", type=int, default=5, help="The number of layers in a network."
)
parser.add_argument("--check_N", type=int, default=32768, help="For safety.")
# use for train the model
parser.add_argument(
"--workers",
type=int,
default=8,
help="The number of data loading workers (default: 2)",
)
parser.add_argument(
"--srange", type=str, required=True, help="The range of models to be evaluated"
)
parser.add_argument("--datasets", type=str, nargs="+", help="The applied datasets.")
parser.add_argument(
"--xpaths", type=str, nargs="+", help="The root path for this dataset."
)
parser.add_argument(
"--splits", type=int, nargs="+", help="The root path for this dataset."
)
parser.add_argument(
"--hyper",
type=str,
default="12",
choices=["01", "12", "90"],
help="The tag for hyper-parameters.",
)
parser.add_argument(
"--seeds", type=int, nargs="+", help="The range of models to be evaluated"
)
args = parser.parse_args()
nets = traverse_net(args.candidateC, args.num_layers)
if len(nets) != args.check_N:
raise ValueError(
"Pre-num-check failed : {:} vs {:}".format(len(nets), args.check_N)
)
opt_config = "./configs/nas-benchmark/hyper-opts/{:}E.config".format(args.hyper)
if not os.path.isfile(opt_config):
raise ValueError("{:} is not a file.".format(opt_config))
save_dir = Path(args.save_dir) / "raw-data-{:}".format(args.hyper)
save_dir.mkdir(parents=True, exist_ok=True)
to_evaluate_indexes = split_str2indexes(args.srange, args.check_N, 5)
if not len(args.seeds):
raise ValueError("invalid length of seeds args: {:}".format(args.seeds))
if not (len(args.datasets) == len(args.xpaths) == len(args.splits)):
raise ValueError(
"invalid infos : {:} vs {:} vs {:}".format(
len(args.datasets), len(args.xpaths), len(args.splits)
)
)
if args.workers <= 0:
raise ValueError("invalid number of workers : {:}".format(args.workers))
target_indexes = filter_indexes(
to_evaluate_indexes, args.mode, save_dir, args.seeds
)
assert torch.cuda.is_available(), "CUDA is not available."
torch.backends.cudnn.enabled = True
torch.backends.cudnn.deterministic = True
# torch.set_num_threads(args.workers)
main(
save_dir,
args.workers,
args.datasets,
args.xpaths,
args.splits,
tuple(args.seeds),
nets,
opt_config,
target_indexes,
args.mode == "cover",
)

View File

@@ -0,0 +1,696 @@
##############################################################################
# NATS-Bench: Benchmarking NAS Algorithms for Architecture Topology and Size #
##############################################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2020.07 #
##############################################################################
# This file is used to train (all) architecture candidates in the topology  #
# search space in NATS-Bench (tss) with different hyper-parameters.          #
# When using mode=new, it automatically detects whether the checkpoint of a #
# trial exists; if so, it skips this trial. When using mode=cover, it       #
# ignores the (possibly) existing checkpoint, runs each trial, and saves.   #
##############################################################################
# Please use the script of scripts/NATS-Bench/train-topology.sh to run. #
# bash scripts/NATS-Bench/train-topology.sh 00000-15624 12 777 #
# bash scripts/NATS-Bench/train-topology.sh 00000-15624 200 '777 888 999' #
# #
################ #
# [Deprecated Function: Generate the meta information] #
# python ./exps/NATS-Bench/main-tss.py --mode meta #
##############################################################################
import os, sys, time, torch, random, argparse
from typing import List, Text, Dict, Any
from PIL import ImageFile
ImageFile.LOAD_TRUNCATED_IMAGES = True
from copy import deepcopy
from pathlib import Path
from xautodl.config_utils import dict2config, load_config
from xautodl.procedures import bench_evaluate_for_seed
from xautodl.procedures import get_machine_info
from xautodl.datasets import get_datasets
from xautodl.log_utils import Logger, AverageMeter, time_string, convert_secs2time
from xautodl.models import CellStructure, CellArchitectures, get_search_spaces
from xautodl.utils import split_str2indexes
def evaluate_all_datasets(
arch: Text,
datasets: List[Text],
xpaths: List[Text],
splits: List[Text],
config_path: Text,
seed: int,
raw_arch_config,
workers,
logger,
):
machine_info, raw_arch_config = get_machine_info(), deepcopy(raw_arch_config)
all_infos = {"info": machine_info}
all_dataset_keys = []
    # iterate over all the datasets
for dataset, xpath, split in zip(datasets, xpaths, splits):
# train valid data
train_data, valid_data, xshape, class_num = get_datasets(dataset, xpath, -1)
# load the configuration
if dataset == "cifar10" or dataset == "cifar100":
split_info = load_config(
"configs/nas-benchmark/cifar-split.txt", None, None
)
elif dataset.startswith("ImageNet16"):
split_info = load_config(
"configs/nas-benchmark/{:}-split.txt".format(dataset), None, None
)
else:
raise ValueError("invalid dataset : {:}".format(dataset))
config = load_config(
config_path, dict(class_num=class_num, xshape=xshape), logger
)
        # check whether to use the split validation set
if bool(split):
assert dataset == "cifar10"
ValLoaders = {
"ori-test": torch.utils.data.DataLoader(
valid_data,
batch_size=config.batch_size,
shuffle=False,
num_workers=workers,
pin_memory=True,
)
}
assert len(train_data) == len(split_info.train) + len(
split_info.valid
), "invalid length : {:} vs {:} + {:}".format(
len(train_data), len(split_info.train), len(split_info.valid)
)
train_data_v2 = deepcopy(train_data)
train_data_v2.transform = valid_data.transform
valid_data = train_data_v2
# data loader
train_loader = torch.utils.data.DataLoader(
train_data,
batch_size=config.batch_size,
sampler=torch.utils.data.sampler.SubsetRandomSampler(split_info.train),
num_workers=workers,
pin_memory=True,
)
valid_loader = torch.utils.data.DataLoader(
valid_data,
batch_size=config.batch_size,
sampler=torch.utils.data.sampler.SubsetRandomSampler(split_info.valid),
num_workers=workers,
pin_memory=True,
)
ValLoaders["x-valid"] = valid_loader
else:
# data loader
train_loader = torch.utils.data.DataLoader(
train_data,
batch_size=config.batch_size,
shuffle=True,
num_workers=workers,
pin_memory=True,
)
valid_loader = torch.utils.data.DataLoader(
valid_data,
batch_size=config.batch_size,
shuffle=False,
num_workers=workers,
pin_memory=True,
)
if dataset == "cifar10":
ValLoaders = {"ori-test": valid_loader}
elif dataset == "cifar100":
cifar100_splits = load_config(
"configs/nas-benchmark/cifar100-test-split.txt", None, None
)
ValLoaders = {
"ori-test": valid_loader,
"x-valid": torch.utils.data.DataLoader(
valid_data,
batch_size=config.batch_size,
sampler=torch.utils.data.sampler.SubsetRandomSampler(
cifar100_splits.xvalid
),
num_workers=workers,
pin_memory=True,
),
"x-test": torch.utils.data.DataLoader(
valid_data,
batch_size=config.batch_size,
sampler=torch.utils.data.sampler.SubsetRandomSampler(
cifar100_splits.xtest
),
num_workers=workers,
pin_memory=True,
),
}
elif dataset == "ImageNet16-120":
imagenet16_splits = load_config(
"configs/nas-benchmark/imagenet-16-120-test-split.txt", None, None
)
ValLoaders = {
"ori-test": valid_loader,
"x-valid": torch.utils.data.DataLoader(
valid_data,
batch_size=config.batch_size,
sampler=torch.utils.data.sampler.SubsetRandomSampler(
imagenet16_splits.xvalid
),
num_workers=workers,
pin_memory=True,
),
"x-test": torch.utils.data.DataLoader(
valid_data,
batch_size=config.batch_size,
sampler=torch.utils.data.sampler.SubsetRandomSampler(
imagenet16_splits.xtest
),
num_workers=workers,
pin_memory=True,
),
}
else:
raise ValueError("invalid dataset : {:}".format(dataset))
dataset_key = "{:}".format(dataset)
if bool(split):
dataset_key = dataset_key + "-valid"
logger.log(
"Evaluate ||||||| {:10s} ||||||| Train-Num={:}, Valid-Num={:}, Train-Loader-Num={:}, Valid-Loader-Num={:}, batch size={:}".format(
dataset_key,
len(train_data),
len(valid_data),
len(train_loader),
len(valid_loader),
config.batch_size,
)
)
logger.log(
"Evaluate ||||||| {:10s} ||||||| Config={:}".format(dataset_key, config)
)
for key, value in ValLoaders.items():
logger.log(
"Evaluate ---->>>> {:10s} with {:} batchs".format(key, len(value))
)
arch_config = dict2config(
dict(
name="infer.tiny",
C=raw_arch_config["channel"],
N=raw_arch_config["num_cells"],
genotype=arch,
num_classes=config.class_num,
),
None,
)
results = bench_evaluate_for_seed(
arch_config, config, train_loader, ValLoaders, seed, logger
)
all_infos[dataset_key] = results
all_dataset_keys.append(dataset_key)
all_infos["all_dataset_keys"] = all_dataset_keys
return all_infos
def main(
save_dir: Path,
workers: int,
datasets: List[Text],
xpaths: List[Text],
splits: List[int],
seeds: List[int],
nets: List[str],
opt_config: Dict[Text, Any],
to_evaluate_indexes: tuple,
cover_mode: bool,
arch_config: Dict[Text, Any],
):
log_dir = save_dir / "logs"
log_dir.mkdir(parents=True, exist_ok=True)
logger = Logger(str(log_dir), os.getpid(), False)
logger.log("xargs : seeds = {:}".format(seeds))
logger.log("xargs : cover_mode = {:}".format(cover_mode))
logger.log("-" * 100)
    logger.log(
        "Start evaluating range : {:06d} - {:06d} ({:} in total) / {:06d} with cover-mode={:}".format(
            min(to_evaluate_indexes),
            max(to_evaluate_indexes),
            len(to_evaluate_indexes),
            len(nets),
            cover_mode,
        )
    )
for i, (dataset, xpath, split) in enumerate(zip(datasets, xpaths, splits)):
logger.log(
"--->>> Evaluate {:}/{:} : dataset={:9s}, path={:}, split={:}".format(
i, len(datasets), dataset, xpath, split
)
)
logger.log("--->>> optimization config : {:}".format(opt_config))
start_time, epoch_time = time.time(), AverageMeter()
for i, index in enumerate(to_evaluate_indexes):
arch = nets[index]
logger.log(
"\n{:} evaluate {:06d}/{:06d} ({:06d}/{:06d})-th arch [seeds={:}] {:}".format(
time_string(),
i,
len(to_evaluate_indexes),
index,
len(nets),
seeds,
"-" * 15,
)
)
logger.log("{:} {:} {:}".format("-" * 15, arch, "-" * 15))
# test this arch on different datasets with different seeds
has_continue = False
for seed in seeds:
to_save_name = save_dir / "arch-{:06d}-seed-{:04d}.pth".format(index, seed)
if to_save_name.exists():
if cover_mode:
logger.log(
"Find existing file : {:}, remove it before evaluation".format(
to_save_name
)
)
os.remove(str(to_save_name))
else:
logger.log(
"Find existing file : {:}, skip this evaluation".format(
to_save_name
)
)
has_continue = True
continue
results = evaluate_all_datasets(
CellStructure.str2structure(arch),
datasets,
xpaths,
splits,
opt_config,
seed,
arch_config,
workers,
logger,
)
torch.save(results, to_save_name)
logger.log(
"\n{:} evaluate {:06d}/{:06d} ({:06d}/{:06d})-th arch [seeds={:}] ===>>> {:}".format(
time_string(),
i,
len(to_evaluate_indexes),
index,
len(nets),
seeds,
to_save_name,
)
)
# measure elapsed time
if not has_continue:
epoch_time.update(time.time() - start_time)
start_time = time.time()
need_time = "Time Left: {:}".format(
convert_secs2time(epoch_time.avg * (len(to_evaluate_indexes) - i - 1), True)
)
logger.log(
"This arch costs : {:}".format(convert_secs2time(epoch_time.val, True))
)
logger.log("{:}".format("*" * 100))
logger.log(
"{:} {:74s} {:}".format(
"*" * 10,
"{:06d}/{:06d} ({:06d}/{:06d})-th done, left {:}".format(
i, len(to_evaluate_indexes), index, len(nets), need_time
),
"*" * 10,
)
)
logger.log("{:}".format("*" * 100))
logger.close()
def train_single_model(
save_dir, workers, datasets, xpaths, splits, use_less, seeds, model_str, arch_config
):
assert torch.cuda.is_available(), "CUDA is not available."
torch.backends.cudnn.enabled = True
torch.backends.cudnn.deterministic = True
# torch.backends.cudnn.benchmark = True
# torch.set_num_threads(workers)
save_dir = (
Path(save_dir)
/ "specifics"
/ "{:}-{:}-{:}-{:}".format(
"LESS" if use_less else "FULL",
model_str,
arch_config["channel"],
arch_config["num_cells"],
)
)
logger = Logger(str(save_dir), 0, False)
if model_str in CellArchitectures:
arch = CellArchitectures[model_str]
logger.log(
"The model string is found in pre-defined architecture dict : {:}".format(
model_str
)
)
else:
try:
arch = CellStructure.str2structure(model_str)
        except Exception:
raise ValueError(
"Invalid model string : {:}. It can not be found or parsed.".format(
model_str
)
)
assert arch.check_valid_op(
get_search_spaces("cell", "full")
), "{:} has the invalid op.".format(arch)
logger.log("Start train-evaluate {:}".format(arch.tostr()))
logger.log("arch_config : {:}".format(arch_config))
start_time, seed_time = time.time(), AverageMeter()
for _is, seed in enumerate(seeds):
logger.log(
"\nThe {:02d}/{:02d}-th seed is {:} ----------------------<.>----------------------".format(
_is, len(seeds), seed
)
)
to_save_name = save_dir / "seed-{:04d}.pth".format(seed)
if to_save_name.exists():
logger.log(
"Find the existing file {:}, directly load!".format(to_save_name)
)
checkpoint = torch.load(to_save_name)
else:
logger.log(
"Does not find the existing file {:}, train and evaluate!".format(
to_save_name
)
)
            # NOTE: evaluate_all_datasets expects a config path in this position;
            # map the legacy use_less flag onto a config file. These two paths
            # follow the NAS-Bench-201 convention and are an assumption here.
            xconfig_path = (
                "./configs/nas-benchmark/LESS.config"
                if use_less
                else "./configs/nas-benchmark/CIFAR.config"
            )
            checkpoint = evaluate_all_datasets(
                arch,
                datasets,
                xpaths,
                splits,
                xconfig_path,
                seed,
                arch_config,
                workers,
                logger,
            )
torch.save(checkpoint, to_save_name)
# log information
logger.log("{:}".format(checkpoint["info"]))
all_dataset_keys = checkpoint["all_dataset_keys"]
for dataset_key in all_dataset_keys:
logger.log(
"\n{:} dataset : {:} {:}".format("-" * 15, dataset_key, "-" * 15)
)
dataset_info = checkpoint[dataset_key]
# logger.log('Network ==>\n{:}'.format( dataset_info['net_string'] ))
logger.log(
"Flops = {:} MB, Params = {:} MB".format(
dataset_info["flop"], dataset_info["param"]
)
)
logger.log("config : {:}".format(dataset_info["config"]))
logger.log(
"Training State (finish) = {:}".format(dataset_info["finish-train"])
)
last_epoch = dataset_info["total_epoch"] - 1
train_acc1es, train_acc5es = (
dataset_info["train_acc1es"],
dataset_info["train_acc5es"],
)
valid_acc1es, valid_acc5es = (
dataset_info["valid_acc1es"],
dataset_info["valid_acc5es"],
)
logger.log(
"Last Info : Train = Acc@1 {:.2f}% Acc@5 {:.2f}% Error@1 {:.2f}%, Test = Acc@1 {:.2f}% Acc@5 {:.2f}% Error@1 {:.2f}%".format(
train_acc1es[last_epoch],
train_acc5es[last_epoch],
100 - train_acc1es[last_epoch],
valid_acc1es[last_epoch],
valid_acc5es[last_epoch],
100 - valid_acc1es[last_epoch],
)
)
# measure elapsed time
seed_time.update(time.time() - start_time)
start_time = time.time()
need_time = "Time Left: {:}".format(
convert_secs2time(seed_time.avg * (len(seeds) - _is - 1), True)
)
logger.log(
"\n<<<***>>> The {:02d}/{:02d}-th seed is {:} <finish> other procedures need {:}".format(
_is, len(seeds), seed, need_time
)
)
logger.close()
def generate_meta_info(save_dir, max_node, divide=40):
aa_nas_bench_ss = get_search_spaces("cell", "nas-bench-201")
archs = CellStructure.gen_all(aa_nas_bench_ss, max_node, False)
print(
"There are {:} archs vs {:}.".format(
len(archs), len(aa_nas_bench_ss) ** ((max_node - 1) * max_node / 2)
)
)
random.seed(88) # please do not change this line for reproducibility
random.shuffle(archs)
# to test fixed-random shuffle
# print ('arch [0] : {:}\n---->>>> {:}'.format( archs[0], archs[0].tostr() ))
# print ('arch [9] : {:}\n---->>>> {:}'.format( archs[9], archs[9].tostr() ))
assert (
archs[0].tostr()
== "|avg_pool_3x3~0|+|nor_conv_1x1~0|skip_connect~1|+|nor_conv_1x1~0|skip_connect~1|skip_connect~2|"
), "please check the 0-th architecture : {:}".format(archs[0])
assert (
archs[9].tostr()
== "|avg_pool_3x3~0|+|none~0|none~1|+|skip_connect~0|none~1|nor_conv_3x3~2|"
), "please check the 9-th architecture : {:}".format(archs[9])
assert (
archs[123].tostr()
== "|avg_pool_3x3~0|+|avg_pool_3x3~0|nor_conv_1x1~1|+|none~0|avg_pool_3x3~1|nor_conv_3x3~2|"
), "please check the 123-th architecture : {:}".format(archs[123])
total_arch = len(archs)
num = 50000
indexes_5W = list(range(num))
random.seed(1021)
random.shuffle(indexes_5W)
train_split = sorted(list(set(indexes_5W[: num // 2])))
valid_split = sorted(list(set(indexes_5W[num // 2 :])))
assert len(train_split) + len(valid_split) == num
assert (
train_split[0] == 0
and train_split[10] == 26
and train_split[111] == 203
and valid_split[0] == 1
and valid_split[10] == 18
and valid_split[111] == 242
), "{:} {:} {:} - {:} {:} {:}".format(
train_split[0],
train_split[10],
train_split[111],
valid_split[0],
valid_split[10],
valid_split[111],
)
splits = {num: {"train": train_split, "valid": valid_split}}
info = {
"archs": [x.tostr() for x in archs],
"total": total_arch,
"max_node": max_node,
"splits": splits,
}
save_dir = Path(save_dir)
save_dir.mkdir(parents=True, exist_ok=True)
save_name = save_dir / "meta-node-{:}.pth".format(max_node)
assert not save_name.exists(), "{:} already exist".format(save_name)
torch.save(info, save_name)
print("save the meta file into {:}".format(save_name))
def traverse_net(max_node):
aa_nas_bench_ss = get_search_spaces("cell", "nats-bench")
archs = CellStructure.gen_all(aa_nas_bench_ss, max_node, False)
print(
"There are {:} archs vs {:}.".format(
len(archs), len(aa_nas_bench_ss) ** ((max_node - 1) * max_node / 2)
)
)
random.seed(88) # please do not change this line for reproducibility
random.shuffle(archs)
assert (
archs[0].tostr()
== "|avg_pool_3x3~0|+|nor_conv_1x1~0|skip_connect~1|+|nor_conv_1x1~0|skip_connect~1|skip_connect~2|"
), "please check the 0-th architecture : {:}".format(archs[0])
assert (
archs[9].tostr()
== "|avg_pool_3x3~0|+|none~0|none~1|+|skip_connect~0|none~1|nor_conv_3x3~2|"
), "please check the 9-th architecture : {:}".format(archs[9])
assert (
archs[123].tostr()
== "|avg_pool_3x3~0|+|avg_pool_3x3~0|nor_conv_1x1~1|+|none~0|avg_pool_3x3~1|nor_conv_3x3~2|"
), "please check the 123-th architecture : {:}".format(archs[123])
return [x.tostr() for x in archs]
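# With max_node=4 there are (4-1)*4/2 = 6 edges and the nats-bench cell
# search space has 5 candidate operations, hence 5 ** 6 = 15625 topologies,
# matching the default --check_N below.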
def filter_indexes(xlist, mode, save_dir, seeds):
all_indexes = []
for index in xlist:
if mode == "cover":
all_indexes.append(index)
else:
for seed in seeds:
temp_path = save_dir / "arch-{:06d}-seed-{:04d}.pth".format(index, seed)
if not temp_path.exists():
all_indexes.append(index)
break
print(
"{:} [FILTER-INDEXES] : there are {:}/{:} architectures in total".format(
time_string(), len(all_indexes), len(xlist)
)
)
return all_indexes
if __name__ == "__main__":
# mode_choices = ['meta', 'new', 'cover'] + ['specific-{:}'.format(_) for _ in CellArchitectures.keys()]
parser = argparse.ArgumentParser(
description="NATS-Bench (topology search space)",
formatter_class=argparse.ArgumentDefaultsHelpFormatter,
)
parser.add_argument("--mode", type=str, required=True, help="The script mode.")
parser.add_argument(
"--save_dir",
type=str,
default="output/NATS-Bench-topology",
help="Folder to save checkpoints and log.",
)
parser.add_argument(
"--max_node",
type=int,
default=4,
help="The maximum node in a cell (please do not change it).",
)
# use for train the model
parser.add_argument(
"--workers",
type=int,
default=8,
help="number of data loading workers (default: 2)",
)
parser.add_argument(
"--srange", type=str, required=True, help="The range of models to be evaluated"
)
parser.add_argument("--datasets", type=str, nargs="+", help="The applied datasets.")
parser.add_argument(
"--xpaths", type=str, nargs="+", help="The root path for this dataset."
)
parser.add_argument(
"--splits", type=int, nargs="+", help="The root path for this dataset."
)
parser.add_argument(
"--hyper",
type=str,
default="12",
choices=["01", "12", "200"],
help="The tag for hyper-parameters.",
)
parser.add_argument(
"--seeds", type=int, nargs="+", help="The range of models to be evaluated"
)
parser.add_argument(
"--channel", type=int, default=16, help="The number of channels."
)
parser.add_argument(
"--num_cells", type=int, default=5, help="The number of cells in one stage."
)
parser.add_argument("--check_N", type=int, default=15625, help="For safety.")
args = parser.parse_args()
assert args.mode in ["meta", "new", "cover"] or args.mode.startswith(
"specific-"
), "invalid mode : {:}".format(args.mode)
if args.mode == "meta":
generate_meta_info(args.save_dir, args.max_node)
elif args.mode.startswith("specific"):
assert len(args.mode.split("-")) == 2, "invalid mode : {:}".format(args.mode)
model_str = args.mode.split("-")[1]
train_single_model(
args.save_dir,
args.workers,
args.datasets,
args.xpaths,
args.splits,
args.use_less > 0,
tuple(args.seeds),
model_str,
{"channel": args.channel, "num_cells": args.num_cells},
)
else:
nets = traverse_net(args.max_node)
if len(nets) != args.check_N:
raise ValueError(
"Pre-num-check failed : {:} vs {:}".format(len(nets), args.check_N)
)
opt_config = "./configs/nas-benchmark/hyper-opts/{:}E.config".format(args.hyper)
if not os.path.isfile(opt_config):
raise ValueError("{:} is not a file.".format(opt_config))
save_dir = Path(args.save_dir) / "raw-data-{:}".format(args.hyper)
save_dir.mkdir(parents=True, exist_ok=True)
to_evaluate_indexes = split_str2indexes(args.srange, args.check_N, 5)
if not len(args.seeds):
raise ValueError("invalid length of seeds args: {:}".format(args.seeds))
if not (len(args.datasets) == len(args.xpaths) == len(args.splits)):
raise ValueError(
"invalid infos : {:} vs {:} vs {:}".format(
len(args.datasets), len(args.xpaths), len(args.splits)
)
)
if args.workers < 0:
raise ValueError("invalid number of workers : {:}".format(args.workers))
target_indexes = filter_indexes(
to_evaluate_indexes, args.mode, save_dir, args.seeds
)
assert torch.cuda.is_available(), "CUDA is not available."
torch.backends.cudnn.enabled = True
torch.backends.cudnn.deterministic = True
# torch.set_num_threads(args.workers if args.workers > 0 else 1)
main(
save_dir,
args.workers,
args.datasets,
args.xpaths,
args.splits,
tuple(args.seeds),
nets,
opt_config,
target_indexes,
args.mode == "cover",
{
"name": "infer.tiny",
"channel": args.channel,
"num_cells": args.num_cells,
},
)

View File

@@ -0,0 +1,59 @@
##############################################################################
# NATS-Bench: Benchmarking NAS algorithms for Architecture Topology and Size #
##############################################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2020.07 #
##############################################################################
# python ./exps/NATS-Bench/show-dataset.py #
##############################################################################
import os, sys, time, torch, random, argparse
from typing import List, Text, Dict, Any
from PIL import ImageFile
ImageFile.LOAD_TRUNCATED_IMAGES = True
from copy import deepcopy
from xautodl.config_utils import dict2config, load_config
from xautodl.datasets import get_datasets
from nats_bench import create
def show_imagenet_16_120(dataset_dir=None):
if dataset_dir is None:
torch_home_dir = (
os.environ["TORCH_HOME"]
if "TORCH_HOME" in os.environ
else os.path.join(os.environ["HOME"], ".torch")
)
dataset_dir = os.path.join(torch_home_dir, "cifar.python", "ImageNet16")
train_data, valid_data, xshape, class_num = get_datasets(
"ImageNet16-120", dataset_dir, -1
)
split_info = load_config(
"configs/nas-benchmark/ImageNet16-120-split.txt", None, None
)
print("=" * 10 + " ImageNet-16-120 " + "=" * 10)
print("Training Data: {:}".format(train_data))
print("Evaluation Data: {:}".format(valid_data))
print("Hold-out training: {:} images.".format(len(split_info.train)))
print("Hold-out valid : {:} images.".format(len(split_info.valid)))
if __name__ == "__main__":
# show_imagenet_16_120()
api_nats_tss = create(None, "tss", fast_mode=True, verbose=True)
valid_acc_12e = []
test_acc_12e = []
test_acc_200e = []
for index in range(10000):
info = api_nats_tss.get_more_info(index, "ImageNet16-120", hp="12")
valid_acc_12e.append(
info["valid-accuracy"]
) # the validation accuracy after training the model by 12 epochs
test_acc_12e.append(
info["test-accuracy"]
) # the test accuracy after training the model by 12 epochs
info = api_nats_tss.get_more_info(index, "ImageNet16-120", hp="200")
test_acc_200e.append(
info["test-accuracy"]
) # the test accuracy after training the model by 200 epochs (which I reported in the paper)
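    # A minimal follow-up sketch (not in the original script): summarize how
    # well the cheap 12-epoch validation accuracy ranks architectures w.r.t.
    # the 200-epoch test accuracy; assumes scipy is installed.
    from scipy import stats

    tau, _ = stats.kendalltau(valid_acc_12e, test_acc_200e)
    print("Kendall tau (12-epoch valid vs 200-epoch test): {:.4f}".format(tau))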

View File

@@ -0,0 +1,389 @@
##############################################################################
# NATS-Bench: Benchmarking NAS Algorithms for Architecture Topology and Size #
##############################################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2020.08 #
##############################################################################
# This file is used to re-organize all checkpoints (created by main-sss.py) #
# into a single benchmark file. Besides, for each architecture, we merge the #
# information of all its trials into a single file.                          #
# #
# Usage: #
# python exps/NATS-Bench/sss-collect.py #
##############################################################################
import os, re, sys, time, shutil, argparse, collections
import torch
from tqdm import tqdm
from pathlib import Path
from collections import defaultdict, OrderedDict
from typing import Dict, Any, Text, List
from xautodl.log_utils import AverageMeter, time_string, convert_secs2time
from xautodl.config_utils import dict2config
from xautodl.models import CellStructure, get_cell_based_tiny_net
from xautodl.procedures import (
bench_pure_evaluate as pure_evaluate,
get_nas_bench_loaders,
)
from xautodl.utils import get_md5_file
from nats_bench import pickle_save, pickle_load, ArchResults, ResultsCount
NATS_SSS_BASE_NAME = "NATS-sss-v1_0" # 2020.08.28
def account_one_arch(
arch_index: int, arch_str: Text, checkpoints: List[Text], datasets: List[Text]
) -> ArchResults:
information = ArchResults(arch_index, arch_str)
for checkpoint_path in checkpoints:
try:
checkpoint = torch.load(checkpoint_path, map_location="cpu")
        except Exception:
raise ValueError(
"This checkpoint failed to be loaded : {:}".format(checkpoint_path)
)
used_seed = checkpoint_path.name.split("-")[-1].split(".")[0]
ok_dataset = 0
for dataset in datasets:
if dataset not in checkpoint:
print(
"Can not find {:} in arch-{:} from {:}".format(
dataset, arch_index, checkpoint_path
)
)
continue
else:
ok_dataset += 1
results = checkpoint[dataset]
assert results[
"finish-train"
], "This {:} arch seed={:} does not finish train on {:} ::: {:}".format(
arch_index, used_seed, dataset, checkpoint_path
)
arch_config = {
"name": "infer.shape.tiny",
"channels": arch_str,
"arch_str": arch_str,
"genotype": results["arch_config"]["genotype"],
"class_num": results["arch_config"]["num_classes"],
}
xresult = ResultsCount(
dataset,
results["net_state_dict"],
results["train_acc1es"],
results["train_losses"],
results["param"],
results["flop"],
arch_config,
used_seed,
results["total_epoch"],
None,
)
xresult.update_train_info(
results["train_acc1es"],
results["train_acc5es"],
results["train_losses"],
results["train_times"],
)
xresult.update_eval(
results["valid_acc1es"], results["valid_losses"], results["valid_times"]
)
information.update(dataset, int(used_seed), xresult)
if ok_dataset < len(datasets):
raise ValueError(
"{:} does find enought data : {:} vs {:}".format(
checkpoint_path, ok_dataset, len(datasets)
)
)
return information
def correct_time_related_info(hp2info: Dict[Text, ArchResults]):
    # calibrate the latency using the hp=01 runs, since all trials are trained on the same machine.
x1 = hp2info["01"].get_metrics("cifar10-valid", "x-valid")["all_time"] / 98
x2 = hp2info["01"].get_metrics("cifar10-valid", "ori-test")["all_time"] / 40
cifar010_latency = (x1 + x2) / 2
for hp, arch_info in hp2info.items():
arch_info.reset_latency("cifar10-valid", None, cifar010_latency)
arch_info.reset_latency("cifar10", None, cifar010_latency)
# hp2info['01'].get_latency('cifar10')
x1 = hp2info["01"].get_metrics("cifar100", "ori-test")["all_time"] / 40
x2 = hp2info["01"].get_metrics("cifar100", "x-test")["all_time"] / 20
x3 = hp2info["01"].get_metrics("cifar100", "x-valid")["all_time"] / 20
cifar100_latency = (x1 + x2 + x3) / 3
for hp, arch_info in hp2info.items():
arch_info.reset_latency("cifar100", None, cifar100_latency)
x1 = hp2info["01"].get_metrics("ImageNet16-120", "ori-test")["all_time"] / 24
x2 = hp2info["01"].get_metrics("ImageNet16-120", "x-test")["all_time"] / 12
x3 = hp2info["01"].get_metrics("ImageNet16-120", "x-valid")["all_time"] / 12
image_latency = (x1 + x2 + x3) / 3
for hp, arch_info in hp2info.items():
arch_info.reset_latency("ImageNet16-120", None, image_latency)
# CIFAR10 VALID
train_per_epoch_time = list(
hp2info["01"].query("cifar10-valid", 777).train_times.values()
)
train_per_epoch_time = sum(train_per_epoch_time) / len(train_per_epoch_time)
eval_ori_test_time, eval_x_valid_time = [], []
for key, value in hp2info["01"].query("cifar10-valid", 777).eval_times.items():
if key.startswith("ori-test@"):
eval_ori_test_time.append(value)
elif key.startswith("x-valid@"):
eval_x_valid_time.append(value)
else:
raise ValueError("-- {:} --".format(key))
eval_ori_test_time = sum(eval_ori_test_time) / len(eval_ori_test_time)
eval_x_valid_time = sum(eval_x_valid_time) / len(eval_x_valid_time)
for hp, arch_info in hp2info.items():
arch_info.reset_pseudo_train_times("cifar10-valid", None, train_per_epoch_time)
arch_info.reset_pseudo_eval_times(
"cifar10-valid", None, "x-valid", eval_x_valid_time
)
arch_info.reset_pseudo_eval_times(
"cifar10-valid", None, "ori-test", eval_ori_test_time
)
# CIFAR10
train_per_epoch_time = list(
hp2info["01"].query("cifar10", 777).train_times.values()
)
train_per_epoch_time = sum(train_per_epoch_time) / len(train_per_epoch_time)
eval_ori_test_time = []
for key, value in hp2info["01"].query("cifar10", 777).eval_times.items():
if key.startswith("ori-test@"):
eval_ori_test_time.append(value)
else:
raise ValueError("-- {:} --".format(key))
eval_ori_test_time = sum(eval_ori_test_time) / len(eval_ori_test_time)
for hp, arch_info in hp2info.items():
arch_info.reset_pseudo_train_times("cifar10", None, train_per_epoch_time)
arch_info.reset_pseudo_eval_times(
"cifar10", None, "ori-test", eval_ori_test_time
)
# CIFAR100
train_per_epoch_time = list(
hp2info["01"].query("cifar100", 777).train_times.values()
)
train_per_epoch_time = sum(train_per_epoch_time) / len(train_per_epoch_time)
eval_ori_test_time, eval_x_valid_time, eval_x_test_time = [], [], []
for key, value in hp2info["01"].query("cifar100", 777).eval_times.items():
if key.startswith("ori-test@"):
eval_ori_test_time.append(value)
elif key.startswith("x-valid@"):
eval_x_valid_time.append(value)
elif key.startswith("x-test@"):
eval_x_test_time.append(value)
else:
raise ValueError("-- {:} --".format(key))
eval_ori_test_time = sum(eval_ori_test_time) / len(eval_ori_test_time)
eval_x_valid_time = sum(eval_x_valid_time) / len(eval_x_valid_time)
eval_x_test_time = sum(eval_x_test_time) / len(eval_x_test_time)
for hp, arch_info in hp2info.items():
arch_info.reset_pseudo_train_times("cifar100", None, train_per_epoch_time)
arch_info.reset_pseudo_eval_times(
"cifar100", None, "x-valid", eval_x_valid_time
)
arch_info.reset_pseudo_eval_times("cifar100", None, "x-test", eval_x_test_time)
arch_info.reset_pseudo_eval_times(
"cifar100", None, "ori-test", eval_ori_test_time
)
# ImageNet16-120
train_per_epoch_time = list(
hp2info["01"].query("ImageNet16-120", 777).train_times.values()
)
train_per_epoch_time = sum(train_per_epoch_time) / len(train_per_epoch_time)
eval_ori_test_time, eval_x_valid_time, eval_x_test_time = [], [], []
for key, value in hp2info["01"].query("ImageNet16-120", 777).eval_times.items():
if key.startswith("ori-test@"):
eval_ori_test_time.append(value)
elif key.startswith("x-valid@"):
eval_x_valid_time.append(value)
elif key.startswith("x-test@"):
eval_x_test_time.append(value)
else:
raise ValueError("-- {:} --".format(key))
eval_ori_test_time = sum(eval_ori_test_time) / len(eval_ori_test_time)
eval_x_valid_time = sum(eval_x_valid_time) / len(eval_x_valid_time)
eval_x_test_time = sum(eval_x_test_time) / len(eval_x_test_time)
for hp, arch_info in hp2info.items():
arch_info.reset_pseudo_train_times("ImageNet16-120", None, train_per_epoch_time)
arch_info.reset_pseudo_eval_times(
"ImageNet16-120", None, "x-valid", eval_x_valid_time
)
arch_info.reset_pseudo_eval_times(
"ImageNet16-120", None, "x-test", eval_x_test_time
)
arch_info.reset_pseudo_eval_times(
"ImageNet16-120", None, "ori-test", eval_ori_test_time
)
return hp2info
def simplify(save_dir, save_name, nets, total):
hps, seeds = ["01", "12", "90"], set()
for hp in hps:
sub_save_dir = save_dir / "raw-data-{:}".format(hp)
ckps = sorted(list(sub_save_dir.glob("arch-*-seed-*.pth")))
seed2names = defaultdict(list)
for ckp in ckps:
parts = re.split("-|\.", ckp.name)
seed2names[parts[3]].append(ckp.name)
print("DIR : {:}".format(sub_save_dir))
nums = []
for seed, xlist in seed2names.items():
seeds.add(seed)
nums.append(len(xlist))
print(" [seed={:}] there are {:} checkpoints.".format(seed, len(xlist)))
assert (
len(nets) == total == max(nums)
), "there are some missed files : {:} vs {:}".format(max(nums), total)
print("{:} start simplify the checkpoint.".format(time_string()))
datasets = ("cifar10-valid", "cifar10", "cifar100", "ImageNet16-120")
# Create the directory to save the processed data
# full_save_dir contains all benchmark files with trained weights.
# simplify_save_dir contains all benchmark files without trained weights.
full_save_dir = save_dir / (save_name + "-FULL")
simple_save_dir = save_dir / (save_name + "-SIMPLIFY")
full_save_dir.mkdir(parents=True, exist_ok=True)
simple_save_dir.mkdir(parents=True, exist_ok=True)
# all data in memory
arch2infos, evaluated_indexes = dict(), set()
end_time, arch_time = time.time(), AverageMeter()
for index in tqdm(range(total)):
arch_str = nets[index]
hp2info = OrderedDict()
full_save_path = full_save_dir / "{:06d}.pickle".format(index)
simple_save_path = simple_save_dir / "{:06d}.pickle".format(index)
for hp in hps:
sub_save_dir = save_dir / "raw-data-{:}".format(hp)
ckps = [
sub_save_dir / "arch-{:06d}-seed-{:}.pth".format(index, seed)
for seed in seeds
]
ckps = [x for x in ckps if x.exists()]
if len(ckps) == 0:
raise ValueError("Invalid data : index={:}, hp={:}".format(index, hp))
arch_info = account_one_arch(index, arch_str, ckps, datasets)
hp2info[hp] = arch_info
hp2info = correct_time_related_info(hp2info)
evaluated_indexes.add(index)
hp2info["01"].clear_params() # to save some spaces...
to_save_data = OrderedDict(
{
"01": hp2info["01"].state_dict(),
"12": hp2info["12"].state_dict(),
"90": hp2info["90"].state_dict(),
}
)
pickle_save(to_save_data, str(full_save_path))
for hp in hps:
hp2info[hp].clear_params()
to_save_data = OrderedDict(
{
"01": hp2info["01"].state_dict(),
"12": hp2info["12"].state_dict(),
"90": hp2info["90"].state_dict(),
}
)
pickle_save(to_save_data, str(simple_save_path))
arch2infos[index] = to_save_data
# measure elapsed time
arch_time.update(time.time() - end_time)
end_time = time.time()
need_time = "{:}".format(
convert_secs2time(arch_time.avg * (total - index - 1), True)
)
# print('{:} {:06d}/{:06d} : still need {:}'.format(time_string(), index, total, need_time))
print("{:} {:} done.".format(time_string(), save_name))
final_infos = {
"meta_archs": nets,
"total_archs": total,
"arch2infos": arch2infos,
"evaluated_indexes": evaluated_indexes,
}
save_file_name = save_dir / "{:}.pickle".format(save_name)
pickle_save(final_infos, str(save_file_name))
# move the benchmark file to a new path
hd5sum = get_md5_file(str(save_file_name) + ".pbz2")
hd5_file_name = save_dir / "{:}-{:}.pickle.pbz2".format(NATS_SSS_BASE_NAME, hd5sum)
shutil.move(str(save_file_name) + ".pbz2", hd5_file_name)
print(
"Save {:} / {:} architecture results into {:} -> {:}.".format(
len(evaluated_indexes), total, save_file_name, hd5_file_name
)
)
# move the directory to a new path
hd5_full_save_dir = save_dir / "{:}-{:}-full".format(NATS_SSS_BASE_NAME, hd5sum)
hd5_simple_save_dir = save_dir / "{:}-{:}-simple".format(NATS_SSS_BASE_NAME, hd5sum)
shutil.move(full_save_dir, hd5_full_save_dir)
shutil.move(simple_save_dir, hd5_simple_save_dir)
# save the meta information for simple and full
final_infos["arch2infos"] = None
final_infos["evaluated_indexes"] = set()
pickle_save(final_infos, str(hd5_full_save_dir / "meta.pickle"))
pickle_save(final_infos, str(hd5_simple_save_dir / "meta.pickle"))
def traverse_net(candidates: List[int], N: int):
nets = [""]
for i in range(N):
new_nets = []
for net in nets:
for C in candidates:
new_nets.append(str(C) if net == "" else "{:}:{:}".format(net, C))
nets = new_nets
return nets
if __name__ == "__main__":
parser = argparse.ArgumentParser(
description="NATS-Bench (size search space)",
formatter_class=argparse.ArgumentDefaultsHelpFormatter,
)
parser.add_argument(
"--base_save_dir",
type=str,
default="./output/NATS-Bench-size",
help="The base-name of folder to save checkpoints and log.",
)
parser.add_argument(
"--candidateC",
type=int,
nargs="+",
default=[8, 16, 24, 32, 40, 48, 56, 64],
help=".",
)
parser.add_argument(
"--num_layers", type=int, default=5, help="The number of layers in a network."
)
parser.add_argument("--check_N", type=int, default=32768, help="For safety.")
parser.add_argument(
"--save_name", type=str, default="process", help="The save directory."
)
args = parser.parse_args()
nets = traverse_net(args.candidateC, args.num_layers)
if len(nets) != args.check_N:
raise ValueError(
"Pre-num-check failed : {:} vs {:}".format(len(nets), args.check_N)
)
save_dir = Path(args.base_save_dir)
simplify(save_dir, args.save_name, nets, args.check_N)

View File

@@ -0,0 +1,103 @@
##############################################################################
# NATS-Bench: Benchmarking NAS Algorithms for Architecture Topology and Size #
##############################################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2020.08 #
##############################################################################
# Usage: python exps/NATS-Bench/sss-file-manager.py --mode check #
##############################################################################
import os, sys, time, torch, argparse
from typing import List, Text, Dict, Any
from shutil import copyfile
from collections import defaultdict
from copy import deepcopy
from pathlib import Path
from xautodl.config_utils import dict2config, load_config
from xautodl.procedures import bench_evaluate_for_seed
from xautodl.procedures import get_machine_info
from xautodl.datasets import get_datasets
from xautodl.log_utils import Logger, AverageMeter, time_string, convert_secs2time
def obtain_valid_ckp(save_dir: Text, total: int):
possible_seeds = [777, 888, 999]
seed2ckps = defaultdict(list)
miss2ckps = defaultdict(list)
for i in range(total):
for seed in possible_seeds:
path = os.path.join(save_dir, "arch-{:06d}-seed-{:04d}.pth".format(i, seed))
if os.path.exists(path):
seed2ckps[seed].append(i)
else:
miss2ckps[seed].append(i)
for seed, xlist in seed2ckps.items():
print(
"[{:}] [seed={:}] has {:5d}/{:5d} | miss {:5d}/{:5d}".format(
save_dir, seed, len(xlist), total, total - len(xlist), total
)
)
return dict(seed2ckps), dict(miss2ckps)
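# For example (hypothetical output): obtain_valid_ckp("output/NATS-Bench-size/raw-data-90", 32768)
# returns seed2ckps mapping each seed (777/888/999) to the indexes whose
# checkpoint exists, and miss2ckps to the missing ones; "check" mode saves
# both dicts so that "copy" mode can later fill in the gaps.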
def copy_data(source_dir, target_dir, meta_path):
target_dir = Path(target_dir)
target_dir.mkdir(parents=True, exist_ok=True)
miss2ckps = torch.load(meta_path)["miss2ckps"]
s2t = {}
for seed, xlist in miss2ckps.items():
for i in xlist:
file_name = "arch-{:06d}-seed-{:04d}.pth".format(i, seed)
source_path = os.path.join(source_dir, file_name)
target_path = os.path.join(target_dir, file_name)
if os.path.exists(source_path):
s2t[source_path] = target_path
print(
"Map from {:} to {:}, find {:} missed ckps.".format(
source_dir, target_dir, len(s2t)
)
)
for s, t in s2t.items():
copyfile(s, t)
if __name__ == "__main__":
parser = argparse.ArgumentParser(
description="NATS-Bench (size search space) file manager.",
formatter_class=argparse.ArgumentDefaultsHelpFormatter,
)
parser.add_argument(
"--mode",
type=str,
required=True,
choices=["check", "copy"],
help="The script mode.",
)
parser.add_argument(
"--save_dir",
type=str,
default="output/NATS-Bench-size",
help="Folder to save checkpoints and log.",
)
parser.add_argument("--check_N", type=int, default=32768, help="For safety.")
# use for train the model
args = parser.parse_args()
possible_configs = ["01", "12", "90"]
if args.mode == "check":
for config in possible_configs:
cur_save_dir = "{:}/raw-data-{:}".format(args.save_dir, config)
seed2ckps, miss2ckps = obtain_valid_ckp(cur_save_dir, args.check_N)
torch.save(
dict(seed2ckps=seed2ckps, miss2ckps=miss2ckps),
"{:}/meta-{:}.pth".format(args.save_dir, config),
)
elif args.mode == "copy":
for config in possible_configs:
cur_save_dir = "{:}/raw-data-{:}".format(args.save_dir, config)
cur_copy_dir = "{:}/copy-{:}".format(args.save_dir, config)
cur_meta_path = "{:}/meta-{:}.pth".format(args.save_dir, config)
if os.path.exists(cur_meta_path):
copy_data(cur_save_dir, cur_copy_dir, cur_meta_path)
else:
print("Do not find : {:}".format(cur_meta_path))
else:
raise ValueError("invalid mode : {:}".format(args.mode))

View File

@@ -0,0 +1,111 @@
##############################################################################
# NATS-Bench: Benchmarking NAS Algorithms for Architecture Topology and Size #
##############################################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2020.08 #
##############################################################################
# Usage: python exps/NATS-Bench/test-nats-api.py #
##############################################################################
import os, gc, sys, time, torch, argparse
import numpy as np
from typing import List, Text, Dict, Any
from shutil import copyfile
from collections import defaultdict
from copy import deepcopy
from pathlib import Path
import matplotlib
import seaborn as sns
matplotlib.use("agg")
import matplotlib.pyplot as plt
import matplotlib.ticker as ticker
from xautodl.config_utils import dict2config, load_config
from xautodl.log_utils import time_string
from xautodl.models import get_cell_based_tiny_net, CellStructure
from nats_bench import create
def test_api(api, sss_or_tss=True):
print("{:} start testing the api : {:}".format(time_string(), api))
api.clear_params(12)
api.reload(index=12)
    # Query the information of the 1113-th architecture
info_strs = api.query_info_str_by_arch(1113)
print(info_strs)
info = api.query_by_index(113)
print("{:}\n".format(info))
info = api.query_by_index(113, "cifar100")
print("{:}\n".format(info))
info = api.query_meta_info_by_index(115, "90" if sss_or_tss else "200")
print("{:}\n".format(info))
for dataset in ["cifar10", "cifar100", "ImageNet16-120"]:
for xset in ["train", "test", "valid"]:
best_index, highest_accuracy = api.find_best(dataset, xset)
print("")
params = api.get_net_param(12, "cifar10", None)
# Obtain the config and create the network
config = api.get_net_config(12, "cifar10")
print("{:}\n".format(config))
network = get_cell_based_tiny_net(config)
network.load_state_dict(next(iter(params.values())))
# Obtain the cost information
info = api.get_cost_info(12, "cifar10")
print("{:}\n".format(info))
info = api.get_latency(12, "cifar10")
print("{:}\n".format(info))
for index in [13, 15, 19, 200]:
info = api.get_latency(index, "cifar10")
# Count the number of architectures
info = api.statistics("cifar100", "12")
print("{:} statistics results : {:}\n".format(time_string(), info))
# Show the information of the 123-th architecture
api.show(123)
# Obtain both cost and performance information
info = api.get_more_info(1234, "cifar10")
print("{:}\n".format(info))
print("{:} finish testing the api : {:}".format(time_string(), api))
if not sss_or_tss:
arch_str = "|nor_conv_3x3~0|+|nor_conv_3x3~0|avg_pool_3x3~1|+|skip_connect~0|nor_conv_3x3~1|skip_connect~2|"
matrix = api.str2matrix(arch_str)
print("Compute the adjacency matrix of {:}".format(arch_str))
print(matrix)
info = api.simulate_train_eval(123, "cifar10")
print("simulate_train_eval : {:}\n\n".format(info))

if __name__ == "__main__":
# api201 = create('./output/NATS-Bench-topology/process-FULL', 'topology', fast_mode=True, verbose=True)
for fast_mode in [True, False]:
for verbose in [True, False]:
            api_nats_tss = create(None, "tss", fast_mode=fast_mode, verbose=verbose)
print(
"{:} create with fast_mode={:} and verbose={:}".format(
time_string(), fast_mode, verbose
)
)
test_api(api_nats_tss, False)
del api_nats_tss
gc.collect()
for fast_mode in [True, False]:
for verbose in [True, False]:
print(
"{:} create with fast_mode={:} and verbose={:}".format(
time_string(), fast_mode, verbose
)
)
            api_nats_sss = create(None, "size", fast_mode=fast_mode, verbose=verbose)
print("{:} --->>> {:}".format(time_string(), api_nats_sss))
test_api(api_nats_sss, True)
del api_nats_sss
gc.collect()

View File

@@ -0,0 +1,179 @@
##############################################################################
# NATS-Bench: Benchmarking NAS Algorithms for Architecture Topology and Size #
##############################################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2020.08 #
##############################################################################
# This file is used to re-organize all checkpoints (created by main-tss.py)  #
# into a single benchmark file. Besides, for each architecture, we merge     #
# the information of all its trials into a single file.                      #
# #
# Usage: #
# python exps/NATS-Bench/tss-collect-patcher.py #
##############################################################################
import os, re, sys, time, shutil, random, argparse, collections
import numpy as np
from copy import deepcopy
import torch
from tqdm import tqdm
from pathlib import Path
from collections import defaultdict, OrderedDict
from typing import Dict, Any, Text, List
from xautodl.log_utils import AverageMeter, time_string, convert_secs2time
from xautodl.config_utils import load_config, dict2config
from xautodl.datasets import get_datasets
from xautodl.models import CellStructure, get_cell_based_tiny_net, get_search_spaces
from xautodl.procedures import (
bench_pure_evaluate as pure_evaluate,
get_nas_bench_loaders,
)
from xautodl.utils import get_md5_file
from nats_bench import pickle_save, pickle_load, ArchResults, ResultsCount
from nas_201_api import NASBench201API

NATS_TSS_BASE_NAME = "NATS-tss-v1_0"  # 2020.08.28

def simplify(save_dir, save_name, nets, total, sup_config):
hps, seeds = ["12", "200"], set()
for hp in hps:
sub_save_dir = save_dir / "raw-data-{:}".format(hp)
ckps = sorted(list(sub_save_dir.glob("arch-*-seed-*.pth")))
seed2names = defaultdict(list)
for ckp in ckps:
            parts = re.split(r"-|\.", ckp.name)
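            # checkpoint names look like "arch-000123-seed-0777.pth", so parts[3] is the seed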
seed2names[parts[3]].append(ckp.name)
print("DIR : {:}".format(sub_save_dir))
nums = []
for seed, xlist in seed2names.items():
seeds.add(seed)
nums.append(len(xlist))
print(" [seed={:}] there are {:} checkpoints.".format(seed, len(xlist)))
assert (
len(nets) == total == max(nums)
), "there are some missed files : {:} vs {:}".format(max(nums), total)
print("{:} start simplify the checkpoint.".format(time_string()))
datasets = ("cifar10-valid", "cifar10", "cifar100", "ImageNet16-120")
# Create the directory to save the processed data
# full_save_dir contains all benchmark files with trained weights.
# simplify_save_dir contains all benchmark files without trained weights.
full_save_dir = save_dir / (save_name + "-FULL")
simple_save_dir = save_dir / (save_name + "-SIMPLIFY")
full_save_dir.mkdir(parents=True, exist_ok=True)
simple_save_dir.mkdir(parents=True, exist_ok=True)
# all data in memory
arch2infos, evaluated_indexes = dict(), set()
end_time, arch_time = time.time(), AverageMeter()
    # gather the per-architecture pickles written by tss-collect.py back into memory
for index in tqdm(range(total)):
arch_str = nets[index]
hp2info = OrderedDict()
simple_save_path = simple_save_dir / "{:06d}.pickle".format(index)
arch2infos[index] = pickle_load(simple_save_path)
evaluated_indexes.add(index)
# measure elapsed time
arch_time.update(time.time() - end_time)
end_time = time.time()
need_time = "{:}".format(
convert_secs2time(arch_time.avg * (total - index - 1), True)
)
# print('{:} {:06d}/{:06d} : still need {:}'.format(time_string(), index, total, need_time))
print("{:} {:} done.".format(time_string(), save_name))
final_infos = {
"meta_archs": nets,
"total_archs": total,
"arch2infos": arch2infos,
"evaluated_indexes": evaluated_indexes,
}
save_file_name = save_dir / "{:}.pickle".format(save_name)
pickle_save(final_infos, str(save_file_name))
# move the benchmark file to a new path
    md5sum = get_md5_file(str(save_file_name) + ".pbz2")
    md5_file_name = save_dir / "{:}-{:}.pickle.pbz2".format(NATS_TSS_BASE_NAME, md5sum)
    shutil.move(str(save_file_name) + ".pbz2", md5_file_name)
    print(
        "Save {:} / {:} architecture results into {:} -> {:}.".format(
            len(evaluated_indexes), total, save_file_name, md5_file_name
        )
    )
    # move the directory to a new path
    md5_full_save_dir = save_dir / "{:}-{:}-full".format(NATS_TSS_BASE_NAME, md5sum)
    md5_simple_save_dir = save_dir / "{:}-{:}-simple".format(NATS_TSS_BASE_NAME, md5sum)
    shutil.move(full_save_dir, md5_full_save_dir)
    shutil.move(simple_save_dir, md5_simple_save_dir)

def traverse_net(max_node):
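    # enumerate every cell in the topology search space: a cell with max_node nodes
    # has max_node*(max_node-1)/2 edges, and each edge takes one candidate operation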
aa_nas_bench_ss = get_search_spaces("cell", "nats-bench")
archs = CellStructure.gen_all(aa_nas_bench_ss, max_node, False)
print(
"There are {:} archs vs {:}.".format(
            len(archs), len(aa_nas_bench_ss) ** ((max_node - 1) * max_node // 2)
)
)
random.seed(88) # please do not change this line for reproducibility
random.shuffle(archs)
assert (
archs[0].tostr()
== "|avg_pool_3x3~0|+|nor_conv_1x1~0|skip_connect~1|+|nor_conv_1x1~0|skip_connect~1|skip_connect~2|"
), "please check the 0-th architecture : {:}".format(archs[0])
assert (
archs[9].tostr()
== "|avg_pool_3x3~0|+|none~0|none~1|+|skip_connect~0|none~1|nor_conv_3x3~2|"
), "please check the 9-th architecture : {:}".format(archs[9])
assert (
archs[123].tostr()
== "|avg_pool_3x3~0|+|avg_pool_3x3~0|nor_conv_1x1~1|+|none~0|avg_pool_3x3~1|nor_conv_3x3~2|"
), "please check the 123-th architecture : {:}".format(archs[123])
return [x.tostr() for x in archs]

if __name__ == "__main__":
parser = argparse.ArgumentParser(
description="NATS-Bench (topology search space)",
formatter_class=argparse.ArgumentDefaultsHelpFormatter,
)
parser.add_argument(
"--base_save_dir",
type=str,
default="./output/NATS-Bench-topology",
help="The base-name of folder to save checkpoints and log.",
)
parser.add_argument(
"--max_node", type=int, default=4, help="The maximum node in a cell."
)
parser.add_argument(
"--channel", type=int, default=16, help="The number of channels."
)
parser.add_argument(
"--num_cells", type=int, default=5, help="The number of cells in one stage."
)
parser.add_argument("--check_N", type=int, default=15625, help="For safety.")
parser.add_argument(
"--save_name", type=str, default="process", help="The save directory."
)
args = parser.parse_args()
nets = traverse_net(args.max_node)
if len(nets) != args.check_N:
raise ValueError(
"Pre-num-check failed : {:} vs {:}".format(len(nets), args.check_N)
)
save_dir = Path(args.base_save_dir)
simplify(
save_dir,
args.save_name,
nets,
args.check_N,
{"name": "infer.tiny", "channel": args.channel, "num_cells": args.num_cells},
)

View File

@@ -0,0 +1,461 @@
##############################################################################
# NATS-Bench: Benchmarking NAS Algorithms for Architecture Topology and Size #
##############################################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2020.08 #
##############################################################################
# This file is used to re-organize all checkpoints (created by main-tss.py)  #
# into a single benchmark file. Besides, for each architecture, we merge     #
# the information of all its trials into a single file.                      #
# #
# Usage: #
# python exps/NATS-Bench/tss-collect.py #
##############################################################################
import os, re, sys, time, shutil, random, argparse, collections
import numpy as np
from copy import deepcopy
import torch
from tqdm import tqdm
from pathlib import Path
from collections import defaultdict, OrderedDict
from typing import Dict, Any, Text, List
from xautodl.log_utils import AverageMeter, time_string, convert_secs2time
from xautodl.config_utils import load_config, dict2config
from xautodl.datasets import get_datasets
from xautodl.models import CellStructure, get_cell_based_tiny_net, get_search_spaces
from xautodl.procedures import (
bench_pure_evaluate as pure_evaluate,
get_nas_bench_loaders,
)
from xautodl.utils import get_md5_file
from nats_bench import pickle_save, pickle_load, ArchResults, ResultsCount
from nas_201_api import NASBench201API
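
# the original NAS-Bench-201 benchmark file is loaded once here; it is used in
# correct_time_related_info to calibrate latency and timing information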
api = NASBench201API(
"{:}/.torch/NAS-Bench-201-v1_0-e61699.pth".format(os.environ["HOME"])
)

NATS_TSS_BASE_NAME = "NATS-tss-v1_0"  # 2020.08.28

def create_result_count(
used_seed: int,
dataset: Text,
arch_config: Dict[Text, Any],
results: Dict[Text, Any],
dataloader_dict: Dict[Text, Any],
) -> ResultsCount:
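    # package the results of one (architecture, dataset, seed) trial into a ResultsCount record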
xresult = ResultsCount(
dataset,
results["net_state_dict"],
results["train_acc1es"],
results["train_losses"],
results["param"],
results["flop"],
arch_config,
used_seed,
results["total_epoch"],
None,
)
net_config = dict2config(
{
"name": "infer.tiny",
"C": arch_config["channel"],
"N": arch_config["num_cells"],
"genotype": CellStructure.str2structure(arch_config["arch_str"]),
"num_classes": arch_config["class_num"],
},
None,
)
if "train_times" in results: # new version
xresult.update_train_info(
results["train_acc1es"],
results["train_acc5es"],
results["train_losses"],
results["train_times"],
)
xresult.update_eval(
results["valid_acc1es"], results["valid_losses"], results["valid_times"]
)
else:
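        # old-format checkpoints lack recorded evaluation statistics, so rebuild
        # the network here and measure them directly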
network = get_cell_based_tiny_net(net_config)
network.load_state_dict(xresult.get_net_param())
if dataset == "cifar10-valid":
xresult.update_OLD_eval(
"x-valid", results["valid_acc1es"], results["valid_losses"]
)
loss, top1, top5, latencies = pure_evaluate(
dataloader_dict["{:}@{:}".format("cifar10", "test")], network.cuda()
)
xresult.update_OLD_eval(
"ori-test",
{results["total_epoch"] - 1: top1},
{results["total_epoch"] - 1: loss},
)
xresult.update_latency(latencies)
elif dataset == "cifar10":
xresult.update_OLD_eval(
"ori-test", results["valid_acc1es"], results["valid_losses"]
)
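        # the accuracies above were recorded during training; this pass only measures latency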
loss, top1, top5, latencies = pure_evaluate(
dataloader_dict["{:}@{:}".format(dataset, "test")], network.cuda()
)
xresult.update_latency(latencies)
elif dataset == "cifar100" or dataset == "ImageNet16-120":
xresult.update_OLD_eval(
"ori-test", results["valid_acc1es"], results["valid_losses"]
)
loss, top1, top5, latencies = pure_evaluate(
dataloader_dict["{:}@{:}".format(dataset, "valid")], network.cuda()
)
xresult.update_OLD_eval(
"x-valid",
{results["total_epoch"] - 1: top1},
{results["total_epoch"] - 1: loss},
)
loss, top1, top5, latencies = pure_evaluate(
dataloader_dict["{:}@{:}".format(dataset, "test")], network.cuda()
)
xresult.update_OLD_eval(
"x-test",
{results["total_epoch"] - 1: top1},
{results["total_epoch"] - 1: loss},
)
xresult.update_latency(latencies)
else:
raise ValueError("invalid dataset name : {:}".format(dataset))
return xresult

def account_one_arch(arch_index, arch_str, checkpoints, datasets, dataloader_dict):
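    # merge every (seed, dataset) trial found in the checkpoints into one ArchResults record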
information = ArchResults(arch_index, arch_str)
for checkpoint_path in checkpoints:
checkpoint = torch.load(checkpoint_path, map_location="cpu")
used_seed = checkpoint_path.name.split("-")[-1].split(".")[0]
ok_dataset = 0
for dataset in datasets:
if dataset not in checkpoint:
print(
"Can not find {:} in arch-{:} from {:}".format(
dataset, arch_index, checkpoint_path
)
)
continue
else:
ok_dataset += 1
results = checkpoint[dataset]
assert results[
"finish-train"
], "This {:} arch seed={:} does not finish train on {:} ::: {:}".format(
arch_index, used_seed, dataset, checkpoint_path
)
arch_config = {
"channel": results["channel"],
"num_cells": results["num_cells"],
"arch_str": arch_str,
"class_num": results["config"]["class_num"],
}
xresult = create_result_count(
used_seed, dataset, arch_config, results, dataloader_dict
)
information.update(dataset, int(used_seed), xresult)
if ok_dataset == 0:
raise ValueError("{:} does not find any data".format(checkpoint_path))
return information

def correct_time_related_info(arch_index: int, arch_infos: Dict[Text, ArchResults]):
# calibrate the latency based on NAS-Bench-201-v1_0-e61699.pth
    cifar10_latency = (
        api.get_latency(arch_index, "cifar10-valid", hp="200")
        + api.get_latency(arch_index, "cifar10", hp="200")
    ) / 2
    cifar100_latency = api.get_latency(arch_index, "cifar100", hp="200")
    image_latency = api.get_latency(arch_index, "ImageNet16-120", hp="200")
    for hp, arch_info in arch_infos.items():
        arch_info.reset_latency("cifar10-valid", None, cifar10_latency)
        arch_info.reset_latency("cifar10", None, cifar10_latency)
        arch_info.reset_latency("cifar100", None, cifar100_latency)
        arch_info.reset_latency("ImageNet16-120", None, image_latency)
train_per_epoch_time = list(
arch_infos["12"].query("cifar10-valid", 777).train_times.values()
)
train_per_epoch_time = sum(train_per_epoch_time) / len(train_per_epoch_time)
eval_ori_test_time, eval_x_valid_time = [], []
for key, value in arch_infos["12"].query("cifar10-valid", 777).eval_times.items():
if key.startswith("ori-test@"):
eval_ori_test_time.append(value)
elif key.startswith("x-valid@"):
eval_x_valid_time.append(value)
else:
raise ValueError("-- {:} --".format(key))
eval_ori_test_time, eval_x_valid_time = float(np.mean(eval_ori_test_time)), float(
np.mean(eval_x_valid_time)
)
nums = {
"ImageNet16-120-train": 151700,
"ImageNet16-120-valid": 3000,
"ImageNet16-120-test": 6000,
"cifar10-valid-train": 25000,
"cifar10-valid-valid": 25000,
"cifar10-train": 50000,
"cifar10-test": 10000,
"cifar100-train": 50000,
"cifar100-test": 10000,
"cifar100-valid": 5000,
}
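    # average per-image evaluation cost, measured on cifar10-valid's valid and test splits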
eval_per_sample = (eval_ori_test_time + eval_x_valid_time) / (
nums["cifar10-valid-valid"] + nums["cifar10-test"]
)
for hp, arch_info in arch_infos.items():
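        # rescale the per-epoch training time measured on cifar10-valid (25k train
        # images) to each dataset's train-set size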
arch_info.reset_pseudo_train_times(
"cifar10-valid",
None,
train_per_epoch_time
/ nums["cifar10-valid-train"]
* nums["cifar10-valid-train"],
)
arch_info.reset_pseudo_train_times(
"cifar10",
None,
train_per_epoch_time / nums["cifar10-valid-train"] * nums["cifar10-train"],
)
arch_info.reset_pseudo_train_times(
"cifar100",
None,
train_per_epoch_time / nums["cifar10-valid-train"] * nums["cifar100-train"],
)
arch_info.reset_pseudo_train_times(
"ImageNet16-120",
None,
train_per_epoch_time
/ nums["cifar10-valid-train"]
* nums["ImageNet16-120-train"],
)
arch_info.reset_pseudo_eval_times(
"cifar10-valid",
None,
"x-valid",
eval_per_sample * nums["cifar10-valid-valid"],
)
arch_info.reset_pseudo_eval_times(
"cifar10-valid", None, "ori-test", eval_per_sample * nums["cifar10-test"]
)
arch_info.reset_pseudo_eval_times(
"cifar10", None, "ori-test", eval_per_sample * nums["cifar10-test"]
)
arch_info.reset_pseudo_eval_times(
"cifar100", None, "x-valid", eval_per_sample * nums["cifar100-valid"]
)
arch_info.reset_pseudo_eval_times(
"cifar100", None, "x-test", eval_per_sample * nums["cifar100-valid"]
)
arch_info.reset_pseudo_eval_times(
"cifar100", None, "ori-test", eval_per_sample * nums["cifar100-test"]
)
arch_info.reset_pseudo_eval_times(
"ImageNet16-120",
None,
"x-valid",
eval_per_sample * nums["ImageNet16-120-valid"],
)
arch_info.reset_pseudo_eval_times(
"ImageNet16-120",
None,
"x-test",
eval_per_sample * nums["ImageNet16-120-valid"],
)
arch_info.reset_pseudo_eval_times(
"ImageNet16-120",
None,
"ori-test",
eval_per_sample * nums["ImageNet16-120-test"],
)
return arch_infos

def simplify(save_dir, save_name, nets, total, sup_config):
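    # the dataloaders are only needed when old-format checkpoints must be re-evaluated below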
dataloader_dict = get_nas_bench_loaders(6)
hps, seeds = ["12", "200"], set()
for hp in hps:
sub_save_dir = save_dir / "raw-data-{:}".format(hp)
ckps = sorted(list(sub_save_dir.glob("arch-*-seed-*.pth")))
seed2names = defaultdict(list)
for ckp in ckps:
            parts = re.split(r"-|\.", ckp.name)
seed2names[parts[3]].append(ckp.name)
print("DIR : {:}".format(sub_save_dir))
nums = []
for seed, xlist in seed2names.items():
seeds.add(seed)
nums.append(len(xlist))
print(" [seed={:}] there are {:} checkpoints.".format(seed, len(xlist)))
assert (
len(nets) == total == max(nums)
), "there are some missed files : {:} vs {:}".format(max(nums), total)
print("{:} start simplify the checkpoint.".format(time_string()))
datasets = ("cifar10-valid", "cifar10", "cifar100", "ImageNet16-120")
# Create the directory to save the processed data
# full_save_dir contains all benchmark files with trained weights.
# simplify_save_dir contains all benchmark files without trained weights.
full_save_dir = save_dir / (save_name + "-FULL")
simple_save_dir = save_dir / (save_name + "-SIMPLIFY")
full_save_dir.mkdir(parents=True, exist_ok=True)
simple_save_dir.mkdir(parents=True, exist_ok=True)
# all data in memory
arch2infos, evaluated_indexes = dict(), set()
end_time, arch_time = time.time(), AverageMeter()
# save the meta information
temp_final_infos = {
"meta_archs": nets,
"total_archs": total,
"arch2infos": None,
"evaluated_indexes": set(),
}
pickle_save(temp_final_infos, str(full_save_dir / "meta.pickle"))
pickle_save(temp_final_infos, str(simple_save_dir / "meta.pickle"))
for index in tqdm(range(total)):
arch_str = nets[index]
hp2info = OrderedDict()
full_save_path = full_save_dir / "{:06d}.pickle".format(index)
simple_save_path = simple_save_dir / "{:06d}.pickle".format(index)
for hp in hps:
sub_save_dir = save_dir / "raw-data-{:}".format(hp)
ckps = [
sub_save_dir / "arch-{:06d}-seed-{:}.pth".format(index, seed)
for seed in seeds
]
ckps = [x for x in ckps if x.exists()]
if len(ckps) == 0:
raise ValueError("Invalid data : index={:}, hp={:}".format(index, hp))
arch_info = account_one_arch(
index, arch_str, ckps, datasets, dataloader_dict
)
hp2info[hp] = arch_info
hp2info = correct_time_related_info(index, hp2info)
evaluated_indexes.add(index)
to_save_data = OrderedDict(
{"12": hp2info["12"].state_dict(), "200": hp2info["200"].state_dict()}
)
pickle_save(to_save_data, str(full_save_path))
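        # strip the trained weights so the SIMPLIFY copy stays small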
for hp in hps:
hp2info[hp].clear_params()
to_save_data = OrderedDict(
{"12": hp2info["12"].state_dict(), "200": hp2info["200"].state_dict()}
)
pickle_save(to_save_data, str(simple_save_path))
arch2infos[index] = to_save_data
# measure elapsed time
arch_time.update(time.time() - end_time)
end_time = time.time()
need_time = "{:}".format(
convert_secs2time(arch_time.avg * (total - index - 1), True)
)
# print('{:} {:06d}/{:06d} : still need {:}'.format(time_string(), index, total, need_time))
print("{:} {:} done.".format(time_string(), save_name))
final_infos = {
"meta_archs": nets,
"total_archs": total,
"arch2infos": arch2infos,
"evaluated_indexes": evaluated_indexes,
}
save_file_name = save_dir / "{:}.pickle".format(save_name)
pickle_save(final_infos, str(save_file_name))
# move the benchmark file to a new path
    md5sum = get_md5_file(str(save_file_name) + ".pbz2")
    md5_file_name = save_dir / "{:}-{:}.pickle.pbz2".format(NATS_TSS_BASE_NAME, md5sum)
    shutil.move(str(save_file_name) + ".pbz2", md5_file_name)
    print(
        "Save {:} / {:} architecture results into {:} -> {:}.".format(
            len(evaluated_indexes), total, save_file_name, md5_file_name
        )
    )
    # move the directory to a new path
    md5_full_save_dir = save_dir / "{:}-{:}-full".format(NATS_TSS_BASE_NAME, md5sum)
    md5_simple_save_dir = save_dir / "{:}-{:}-simple".format(NATS_TSS_BASE_NAME, md5sum)
    shutil.move(full_save_dir, md5_full_save_dir)
    shutil.move(simple_save_dir, md5_simple_save_dir)
# save the meta information for simple and full
# final_infos['arch2infos'] = None
# final_infos['evaluated_indexes'] = set()

def traverse_net(max_node):
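    # enumerate every cell in the topology search space: a cell with max_node nodes
    # has max_node*(max_node-1)/2 edges, and each edge takes one candidate operation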
aa_nas_bench_ss = get_search_spaces("cell", "nats-bench")
archs = CellStructure.gen_all(aa_nas_bench_ss, max_node, False)
print(
"There are {:} archs vs {:}.".format(
            len(archs), len(aa_nas_bench_ss) ** ((max_node - 1) * max_node // 2)
)
)
random.seed(88) # please do not change this line for reproducibility
random.shuffle(archs)
assert (
archs[0].tostr()
== "|avg_pool_3x3~0|+|nor_conv_1x1~0|skip_connect~1|+|nor_conv_1x1~0|skip_connect~1|skip_connect~2|"
), "please check the 0-th architecture : {:}".format(archs[0])
assert (
archs[9].tostr()
== "|avg_pool_3x3~0|+|none~0|none~1|+|skip_connect~0|none~1|nor_conv_3x3~2|"
), "please check the 9-th architecture : {:}".format(archs[9])
assert (
archs[123].tostr()
== "|avg_pool_3x3~0|+|avg_pool_3x3~0|nor_conv_1x1~1|+|none~0|avg_pool_3x3~1|nor_conv_3x3~2|"
), "please check the 123-th architecture : {:}".format(archs[123])
return [x.tostr() for x in archs]

if __name__ == "__main__":
parser = argparse.ArgumentParser(
description="NATS-Bench (topology search space)",
formatter_class=argparse.ArgumentDefaultsHelpFormatter,
)
parser.add_argument(
"--base_save_dir",
type=str,
default="./output/NATS-Bench-topology",
help="The base-name of folder to save checkpoints and log.",
)
parser.add_argument(
"--max_node", type=int, default=4, help="The maximum node in a cell."
)
parser.add_argument(
"--channel", type=int, default=16, help="The number of channels."
)
parser.add_argument(
"--num_cells", type=int, default=5, help="The number of cells in one stage."
)
parser.add_argument("--check_N", type=int, default=15625, help="For safety.")
parser.add_argument(
"--save_name", type=str, default="process", help="The save directory."
)
args = parser.parse_args()
nets = traverse_net(args.max_node)
if len(nets) != args.check_N:
raise ValueError(
"Pre-num-check failed : {:} vs {:}".format(len(nets), args.check_N)
)
save_dir = Path(args.base_save_dir)
simplify(
save_dir,
args.save_name,
nets,
args.check_N,
{"name": "infer.tiny", "channel": args.channel, "num_cells": args.num_cells},
)

View File

@@ -0,0 +1,105 @@
##############################################################################
# NATS-Bench: Benchmarking NAS Algorithms for Architecture Topology and Size #
##############################################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2020.08 #
##############################################################################
# Usage: python exps/NATS-Bench/tss-file-manager.py --mode check #
##############################################################################
import os, sys, time, torch, argparse
from typing import List, Text, Dict, Any
from shutil import copyfile
from collections import defaultdict
from copy import deepcopy
from pathlib import Path
from xautodl.config_utils import dict2config, load_config
from xautodl.procedures import bench_evaluate_for_seed
from xautodl.procedures import get_machine_info
from xautodl.datasets import get_datasets
from xautodl.log_utils import Logger, AverageMeter, time_string, convert_secs2time

def obtain_valid_ckp(save_dir: Text, total: int, possible_seeds: List[int]):
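    # scan save_dir for arch-*-seed-*.pth files: seed2ckps records which indexes
    # exist for each seed, and miss2ckps records which ones are missing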
seed2ckps = defaultdict(list)
miss2ckps = defaultdict(list)
for i in range(total):
for seed in possible_seeds:
path = os.path.join(save_dir, "arch-{:06d}-seed-{:04d}.pth".format(i, seed))
if os.path.exists(path):
seed2ckps[seed].append(i)
else:
miss2ckps[seed].append(i)
for seed, xlist in seed2ckps.items():
print(
"[{:}] [seed={:}] has {:5d}/{:5d} | miss {:5d}/{:5d}".format(
save_dir, seed, len(xlist), total, total - len(xlist), total
)
)
return dict(seed2ckps), dict(miss2ckps)

def copy_data(source_dir, target_dir, meta_path):
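    # copy checkpoints that were recorded as missing in meta_path but exist in
    # source_dir over to target_dir (e.g., when merging results across machines)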
target_dir = Path(target_dir)
target_dir.mkdir(parents=True, exist_ok=True)
miss2ckps = torch.load(meta_path)["miss2ckps"]
s2t = {}
for seed, xlist in miss2ckps.items():
for i in xlist:
file_name = "arch-{:06d}-seed-{:04d}.pth".format(i, seed)
source_path = os.path.join(source_dir, file_name)
target_path = os.path.join(target_dir, file_name)
if os.path.exists(source_path):
s2t[source_path] = target_path
print(
"Map from {:} to {:}, find {:} missed ckps.".format(
source_dir, target_dir, len(s2t)
)
)
for s, t in s2t.items():
copyfile(s, t)

if __name__ == "__main__":
parser = argparse.ArgumentParser(
description="NATS-Bench (topology search space) file manager.",
formatter_class=argparse.ArgumentDefaultsHelpFormatter,
)
parser.add_argument(
"--mode",
type=str,
required=True,
choices=["check", "copy"],
help="The script mode.",
)
parser.add_argument(
"--save_dir",
type=str,
default="output/NATS-Bench-topology",
help="Folder to save checkpoints and log.",
)
parser.add_argument("--check_N", type=int, default=15625, help="For safety.")
    # used when training the models
args = parser.parse_args()
possible_configs = ["12", "200"]
    possible_seed_lists = [[111, 777], [777, 888, 999]]
    if args.mode == "check":
        for config, possible_seeds in zip(possible_configs, possible_seed_lists):
cur_save_dir = "{:}/raw-data-{:}".format(args.save_dir, config)
seed2ckps, miss2ckps = obtain_valid_ckp(
cur_save_dir, args.check_N, possible_seeds
)
torch.save(
dict(seed2ckps=seed2ckps, miss2ckps=miss2ckps),
"{:}/meta-{:}.pth".format(args.save_dir, config),
)
elif args.mode == "copy":
for config in possible_configs:
cur_save_dir = "{:}/raw-data-{:}".format(args.save_dir, config)
cur_copy_dir = "{:}/copy-{:}".format(args.save_dir, config)
cur_meta_path = "{:}/meta-{:}.pth".format(args.save_dir, config)
if os.path.exists(cur_meta_path):
copy_data(cur_save_dir, cur_copy_dir, cur_meta_path)
else:
print("Do not find : {:}".format(cur_meta_path))
else:
raise ValueError("invalid mode : {:}".format(args.mode))