Reformat via black

D-X-Y
2021-03-17 09:25:58 +00:00
parent a9093e41e1
commit f98edea22a
59 changed files with 12289 additions and 8918 deletions
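
For readers skimming the hunks below: the changes are mechanical. black normalizes string quotes to double quotes, splits one-line `if` bodies and semicolon-joined statements onto separate lines, and explodes long call argument lists one-per-line. A minimal sketch of reproducing one such rewrite, assuming black >= 20.8; line_length=120 is an inference from where this diff wraps calls, not a recorded setting:

import black

# A one-line "if" body in the pre-commit style.
src = "if str(lib_dir) not in sys.path: sys.path.insert(0, str(lib_dir))\n"

print(black.format_str(src, mode=black.Mode(line_length=120)))
# if str(lib_dir) not in sys.path:
#     sys.path.insert(0, str(lib_dir))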


@@ -12,39 +12,43 @@ from pathlib import Path
from collections import OrderedDict
import matplotlib
import seaborn as sns
matplotlib.use('agg')
matplotlib.use("agg")
import matplotlib.pyplot as plt
lib_dir = (Path(__file__).parent / '..' / '..' / 'lib').resolve()
if str(lib_dir) not in sys.path: sys.path.insert(0, str(lib_dir))
lib_dir = (Path(__file__).parent / ".." / ".." / "lib").resolve()
if str(lib_dir) not in sys.path:
sys.path.insert(0, str(lib_dir))
from nas_201_api import NASBench201API
from log_utils import time_string
from models import get_cell_based_tiny_net
from utils import weight_watcher
if __name__ == '__main__':
parser = argparse.ArgumentParser("Analysis of NAS-Bench-201")
parser.add_argument('--api_path' , type=str, default=None, help='The path to the NAS-Bench-201 benchmark file and weight dir.')
parser.add_argument('--archive_path', type=str, default=None, help='The path to the NAS-Bench-201 weight dir.')
args = parser.parse_args()
if __name__ == "__main__":
parser = argparse.ArgumentParser("Analysis of NAS-Bench-201")
parser.add_argument(
"--api_path", type=str, default=None, help="The path to the NAS-Bench-201 benchmark file and weight dir."
)
parser.add_argument("--archive_path", type=str, default=None, help="The path to the NAS-Bench-201 weight dir.")
args = parser.parse_args()
meta_file = Path(args.api_path)
weight_dir = Path(args.archive_path)
assert meta_file.exists(), 'invalid path for api : {:}'.format(meta_file)
assert weight_dir.exists() and weight_dir.is_dir(), 'invalid path for weight dir : {:}'.format(weight_dir)
meta_file = Path(args.api_path)
weight_dir = Path(args.archive_path)
assert meta_file.exists(), "invalid path for api : {:}".format(meta_file)
assert weight_dir.exists() and weight_dir.is_dir(), "invalid path for weight dir : {:}".format(weight_dir)
api = NASBench201API(meta_file, verbose=True)
api = NASBench201API(meta_file, verbose=True)
arch_index = 3 # query the 3rd architecture
api.reload(weight_dir, arch_index) # reload the data of the 3rd architecture from the archive dir
arch_index = 3  # query the 3rd architecture
api.reload(weight_dir, arch_index)  # reload the data of the 3rd architecture from the archive dir
data = "cifar10" # query the info from CIFAR-10
config = api.get_net_config(arch_index, data)
net = get_cell_based_tiny_net(config)
meta_info = api.query_meta_info_by_index(arch_index, hp="200") # all info about this architecture
params = meta_info.get_net_param(data, 888)
data = 'cifar10' # query the info from CIFAR-10
config = api.get_net_config(arch_index, data)
net = get_cell_based_tiny_net(config)
meta_info = api.query_meta_info_by_index(arch_index, hp='200') # all info about this architecture
params = meta_info.get_net_param(data, 888)
net.load_state_dict(params)
_, summary = weight_watcher.analyze(net, alphas=False)
print('The summary of {:}-th architecture:\n{:}'.format(arch_index, summary))
net.load_state_dict(params)
_, summary = weight_watcher.analyze(net, alphas=False)
print("The summary of {:}-th architecture:\n{:}".format(arch_index, summary))


@@ -2,23 +2,27 @@ import sys, time, random, argparse
from copy import deepcopy
import torchvision.models as models
from pathlib import Path
lib_dir = (Path(__file__).parent / '..' / '..' / 'lib').resolve()
if str(lib_dir) not in sys.path: sys.path.insert(0, str(lib_dir))
lib_dir = (Path(__file__).parent / ".." / ".." / "lib").resolve()
if str(lib_dir) not in sys.path:
sys.path.insert(0, str(lib_dir))
from utils import get_model_infos
#from models.ImageNet_MobileNetV2 import MobileNetV2
# from models.ImageNet_MobileNetV2 import MobileNetV2
from torchvision.models.mobilenet import MobileNetV2
def main(width_mult):
# model = MobileNetV2(1001, width_mult, 32, 1280, 'InvertedResidual', 0.2)
model = MobileNetV2(width_mult=width_mult)
print(model)
flops, params = get_model_infos(model, (2, 3, 224, 224))
print('FLOPs : {:}'.format(flops))
print('Params : {:}'.format(params))
print('-'*50)
# model = MobileNetV2(1001, width_mult, 32, 1280, 'InvertedResidual', 0.2)
model = MobileNetV2(width_mult=width_mult)
print(model)
flops, params = get_model_infos(model, (2, 3, 224, 224))
print("FLOPs : {:}".format(flops))
print("Params : {:}".format(params))
print("-" * 50)
if __name__ == '__main__':
main(1.0)
main(1.4)
if __name__ == "__main__":
main(1.0)
main(1.4)
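
A quick cross-check of the printed numbers that needs no repo utilities (parameter count only; FLOPs require a profiler such as get_model_infos above). MobileNetV2 at width_mult=1.0 should land near 3.5M parameters:

import torchvision.models as models

model = models.mobilenet_v2(width_mult=1.0)
params_m = sum(p.numel() for p in model.parameters()) / 1e6
print("Params : {:.2f}M".format(params_m))  # roughly 3.50M at width_mult=1.0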


@@ -5,110 +5,148 @@ from copy import deepcopy
import torch
import numpy as np
from collections import OrderedDict
lib_dir = (Path(__file__).parent / '..' / '..' / 'lib').resolve()
if str(lib_dir) not in sys.path: sys.path.insert(0, str(lib_dir))
lib_dir = (Path(__file__).parent / ".." / ".." / "lib").resolve()
if str(lib_dir) not in sys.path:
sys.path.insert(0, str(lib_dir))
from nas_201_api import NASBench201API as API
def test_nas_api():
from nas_201_api import ArchResults
xdata = torch.load('/home/dxy/FOR-RELEASE/NAS-Projects/output/NAS-BENCH-201-4/simplifies/architectures/000157-FULL.pth')
for key in ['full', 'less']:
print ('\n------------------------- {:} -------------------------'.format(key))
archRes = ArchResults.create_from_state_dict(xdata[key])
print(archRes)
print(archRes.arch_idx_str())
print(archRes.get_dataset_names())
print(archRes.get_comput_costs('cifar10-valid'))
# get the metrics
print(archRes.get_metrics('cifar10-valid', 'x-valid', None, False))
print(archRes.get_metrics('cifar10-valid', 'x-valid', None, True))
print(archRes.query('cifar10-valid', 777))
from nas_201_api import ArchResults
xdata = torch.load(
"/home/dxy/FOR-RELEASE/NAS-Projects/output/NAS-BENCH-201-4/simplifies/architectures/000157-FULL.pth"
)
for key in ["full", "less"]:
print("\n------------------------- {:} -------------------------".format(key))
archRes = ArchResults.create_from_state_dict(xdata[key])
print(archRes)
print(archRes.arch_idx_str())
print(archRes.get_dataset_names())
print(archRes.get_comput_costs("cifar10-valid"))
# get the metrics
print(archRes.get_metrics("cifar10-valid", "x-valid", None, False))
print(archRes.get_metrics("cifar10-valid", "x-valid", None, True))
print(archRes.query("cifar10-valid", 777))
OPS = ['skip-connect', 'conv-1x1', 'conv-3x3', 'pool-3x3']
COLORS = ['chartreuse' , 'cyan' , 'navyblue', 'chocolate1']
OPS = ["skip-connect", "conv-1x1", "conv-3x3", "pool-3x3"]
COLORS = ["chartreuse", "cyan", "navyblue", "chocolate1"]
def plot(filename):
from graphviz import Digraph
g = Digraph(
format='png',
edge_attr=dict(fontsize='20', fontname="times"),
node_attr=dict(style='filled', shape='rect', align='center', fontsize='20', height='0.5', width='0.5', penwidth='2', fontname="times"),
engine='dot')
g.body.extend(['rankdir=LR'])
from graphviz import Digraph
steps = 5
for i in range(0, steps):
if i == 0:
g.node(str(i), fillcolor='darkseagreen2')
elif i+1 == steps:
g.node(str(i), fillcolor='palegoldenrod')
else: g.node(str(i), fillcolor='lightblue')
g = Digraph(
format="png",
edge_attr=dict(fontsize="20", fontname="times"),
node_attr=dict(
style="filled",
shape="rect",
align="center",
fontsize="20",
height="0.5",
width="0.5",
penwidth="2",
fontname="times",
),
engine="dot",
)
g.body.extend(["rankdir=LR"])
for i in range(1, steps):
for xin in range(i):
op_i = random.randint(0, len(OPS)-1)
#g.edge(str(xin), str(i), label=OPS[op_i], fillcolor=COLORS[op_i])
g.edge(str(xin), str(i), label=OPS[op_i], color=COLORS[op_i], fillcolor=COLORS[op_i])
#import pdb; pdb.set_trace()
g.render(filename, cleanup=True, view=False)
steps = 5
for i in range(0, steps):
if i == 0:
g.node(str(i), fillcolor="darkseagreen2")
elif i + 1 == steps:
g.node(str(i), fillcolor="palegoldenrod")
else:
g.node(str(i), fillcolor="lightblue")
for i in range(1, steps):
for xin in range(i):
op_i = random.randint(0, len(OPS) - 1)
# g.edge(str(xin), str(i), label=OPS[op_i], fillcolor=COLORS[op_i])
g.edge(str(xin), str(i), label=OPS[op_i], color=COLORS[op_i], fillcolor=COLORS[op_i])
# import pdb; pdb.set_trace()
g.render(filename, cleanup=True, view=False)
def test_auto_grad():
class Net(torch.nn.Module):
def __init__(self, iS):
super(Net, self).__init__()
self.layer = torch.nn.Linear(iS, 1)
def forward(self, inputs):
outputs = self.layer(inputs)
outputs = torch.exp(outputs)
return outputs.mean()
net = Net(10)
inputs = torch.rand(256, 10)
loss = net(inputs)
first_order_grads = torch.autograd.grad(loss, net.parameters(), retain_graph=True, create_graph=True)
first_order_grads = torch.cat([x.view(-1) for x in first_order_grads])
second_order_grads = []
for grads in first_order_grads:
s_grads = torch.autograd.grad(grads, net.parameters())
second_order_grads.append( s_grads )
class Net(torch.nn.Module):
def __init__(self, iS):
super(Net, self).__init__()
self.layer = torch.nn.Linear(iS, 1)
def forward(self, inputs):
outputs = self.layer(inputs)
outputs = torch.exp(outputs)
return outputs.mean()
net = Net(10)
inputs = torch.rand(256, 10)
loss = net(inputs)
first_order_grads = torch.autograd.grad(loss, net.parameters(), retain_graph=True, create_graph=True)
first_order_grads = torch.cat([x.view(-1) for x in first_order_grads])
second_order_grads = []
for grads in first_order_grads:
s_grads = torch.autograd.grad(grads, net.parameters())
second_order_grads.append(s_grads)
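
test_auto_grad materializes second-order gradients entry by entry: one extra backward pass per element of the flattened gradient. When only curvature along a direction is needed, a single Hessian-vector product is far cheaper; a minimal sketch on the same toy setup (the probe direction v is arbitrary):

import torch

net = torch.nn.Linear(10, 1)
inputs = torch.rand(256, 10)
loss = torch.exp(net(inputs)).mean()

params = list(net.parameters())
grads = torch.autograd.grad(loss, params, create_graph=True)
flat_grad = torch.cat([g.view(-1) for g in grads])

v = torch.rand_like(flat_grad)  # arbitrary probe direction
# One additional backward pass yields H @ v, versus P passes for the full Hessian.
hv = torch.autograd.grad(flat_grad @ v, params)
print([h.shape for h in hv])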
def test_one_shot_model(ckpath, use_train):
from models import get_cell_based_tiny_net, get_search_spaces
from datasets import get_datasets, SearchDataset
from config_utils import load_config, dict2config
from utils.nas_utils import evaluate_one_shot
use_train = int(use_train) > 0
#ckpath = 'output/search-cell-nas-bench-201/DARTS-V1-cifar10/checkpoint/seed-11416-basic.pth'
#ckpath = 'output/search-cell-nas-bench-201/DARTS-V1-cifar10/checkpoint/seed-28640-basic.pth'
print ('ckpath : {:}'.format(ckpath))
ckp = torch.load(ckpath)
xargs = ckp['args']
train_data, valid_data, xshape, class_num = get_datasets(xargs.dataset, xargs.data_path, -1)
#config = load_config(xargs.config_path, {'class_num': class_num, 'xshape': xshape}, None)
config = load_config('./configs/nas-benchmark/algos/DARTS.config', {'class_num': class_num, 'xshape': xshape}, None)
if xargs.dataset == 'cifar10':
cifar_split = load_config('configs/nas-benchmark/cifar-split.txt', None, None)
xvalid_data = deepcopy(train_data)
xvalid_data.transform = valid_data.transform
valid_loader= torch.utils.data.DataLoader(xvalid_data, batch_size=2048, sampler=torch.utils.data.sampler.SubsetRandomSampler(cifar_split.valid), num_workers=12, pin_memory=True)
else: raise ValueError('invalid dataset : {:}'.format(xargs.dataset))
search_space = get_search_spaces('cell', xargs.search_space_name)
model_config = dict2config({'name': 'SETN', 'C': xargs.channel, 'N': xargs.num_cells,
'max_nodes': xargs.max_nodes, 'num_classes': class_num,
'space' : search_space,
'affine' : False, 'track_running_stats': True}, None)
search_model = get_cell_based_tiny_net(model_config)
search_model.load_state_dict( ckp['search_model'] )
search_model = search_model.cuda()
api = API('/home/dxy/.torch/NAS-Bench-201-v1_0-e61699.pth')
archs, probs, accuracies = evaluate_one_shot(search_model, valid_loader, api, use_train)
from models import get_cell_based_tiny_net, get_search_spaces
from datasets import get_datasets, SearchDataset
from config_utils import load_config, dict2config
from utils.nas_utils import evaluate_one_shot
use_train = int(use_train) > 0
# ckpath = 'output/search-cell-nas-bench-201/DARTS-V1-cifar10/checkpoint/seed-11416-basic.pth'
# ckpath = 'output/search-cell-nas-bench-201/DARTS-V1-cifar10/checkpoint/seed-28640-basic.pth'
print("ckpath : {:}".format(ckpath))
ckp = torch.load(ckpath)
xargs = ckp["args"]
train_data, valid_data, xshape, class_num = get_datasets(xargs.dataset, xargs.data_path, -1)
# config = load_config(xargs.config_path, {'class_num': class_num, 'xshape': xshape}, None)
config = load_config("./configs/nas-benchmark/algos/DARTS.config", {"class_num": class_num, "xshape": xshape}, None)
if xargs.dataset == "cifar10":
cifar_split = load_config("configs/nas-benchmark/cifar-split.txt", None, None)
xvalid_data = deepcopy(train_data)
xvalid_data.transform = valid_data.transform
valid_loader = torch.utils.data.DataLoader(
xvalid_data,
batch_size=2048,
sampler=torch.utils.data.sampler.SubsetRandomSampler(cifar_split.valid),
num_workers=12,
pin_memory=True,
)
else:
raise ValueError("invalid dataset : {:}".format(xargs.dataset))
search_space = get_search_spaces("cell", xargs.search_space_name)
model_config = dict2config(
{
"name": "SETN",
"C": xargs.channel,
"N": xargs.num_cells,
"max_nodes": xargs.max_nodes,
"num_classes": class_num,
"space": search_space,
"affine": False,
"track_running_stats": True,
},
None,
)
search_model = get_cell_based_tiny_net(model_config)
search_model.load_state_dict(ckp["search_model"])
search_model = search_model.cuda()
api = API("/home/dxy/.torch/NAS-Bench-201-v1_0-e61699.pth")
archs, probs, accuracies = evaluate_one_shot(search_model, valid_loader, api, use_train)
if __name__ == '__main__':
#test_nas_api()
#for i in range(200): plot('{:04d}'.format(i))
#test_auto_grad()
test_one_shot_model(sys.argv[1], sys.argv[2])
if __name__ == "__main__":
# test_nas_api()
# for i in range(200): plot('{:04d}'.format(i))
# test_auto_grad()
test_one_shot_model(sys.argv[1], sys.argv[2])


@@ -4,24 +4,28 @@
# python exps/experimental/test-resnest.py
#####################################################
import sys, time, torch, random, argparse
from PIL import ImageFile
from PIL import ImageFile
ImageFile.LOAD_TRUNCATED_IMAGES = True
from copy import deepcopy
from copy import deepcopy
from pathlib import Path
lib_dir = (Path(__file__).parent / '..' / '..' / 'lib').resolve()
if str(lib_dir) not in sys.path: sys.path.insert(0, str(lib_dir))
from utils import get_model_infos
lib_dir = (Path(__file__).parent / ".." / ".." / "lib").resolve()
if str(lib_dir) not in sys.path:
sys.path.insert(0, str(lib_dir))
from utils import get_model_infos
torch.hub.list('zhanghang1989/ResNeSt', force_reload=True)
torch.hub.list("zhanghang1989/ResNeSt", force_reload=True)
for model_name, xshape in [('resnest50', (1,3,224,224)),
('resnest101', (1,3,256,256)),
('resnest200', (1,3,320,320)),
('resnest269', (1,3,416,416))]:
# net = torch.hub.load('zhanghang1989/ResNeSt', model_name, pretrained=True)
net = torch.hub.load('zhanghang1989/ResNeSt', model_name, pretrained=False)
print('Model : {:}, input shape : {:}'.format(model_name, xshape))
flops, param = get_model_infos(net, xshape)
print('flops : {:.3f}M'.format(flops))
print('params : {:.3f}M'.format(param))
for model_name, xshape in [
("resnest50", (1, 3, 224, 224)),
("resnest101", (1, 3, 256, 256)),
("resnest200", (1, 3, 320, 320)),
("resnest269", (1, 3, 416, 416)),
]:
# net = torch.hub.load('zhanghang1989/ResNeSt', model_name, pretrained=True)
net = torch.hub.load("zhanghang1989/ResNeSt", model_name, pretrained=False)
print("Model : {:}, input shape : {:}".format(model_name, xshape))
flops, param = get_model_infos(net, xshape)
print("flops : {:.3f}M".format(flops))
print("params : {:.3f}M".format(param))


@@ -15,10 +15,13 @@ from pathlib import Path
from collections import OrderedDict
import matplotlib
import seaborn as sns
matplotlib.use('agg')
matplotlib.use("agg")
import matplotlib.pyplot as plt
lib_dir = (Path(__file__).parent / '..' / '..' / 'lib').resolve()
if str(lib_dir) not in sys.path: sys.path.insert(0, str(lib_dir))
lib_dir = (Path(__file__).parent / ".." / ".." / "lib").resolve()
if str(lib_dir) not in sys.path:
sys.path.insert(0, str(lib_dir))
from log_utils import time_string
from nats_bench import create
from models import get_cell_based_tiny_net
@@ -38,111 +41,125 @@ def tostr(accdict, norms):
return ' '.join(xstr)
"""
def evaluate(api, weight_dir, data: str):
print('\nEvaluate dataset={:}'.format(data))
process = psutil.Process(os.getpid())
norms, accuracies = [], []
ok, total = 0, 5000
for idx in range(total):
arch_index = api.random()
api.reload(weight_dir, arch_index)
# compute the weight watcher results
config = api.get_net_config(arch_index, data)
net = get_cell_based_tiny_net(config)
meta_info = api.query_meta_info_by_index(arch_index, hp='200' if api.search_space_name == 'topology' else '90')
params = meta_info.get_net_param(data, 888 if api.search_space_name == 'topology' else 777)
with torch.no_grad():
net.load_state_dict(params)
_, summary = weight_watcher.analyze(net, alphas=False)
if 'lognorm' not in summary:
print("\nEvaluate dataset={:}".format(data))
process = psutil.Process(os.getpid())
norms, accuracies = [], []
ok, total = 0, 5000
for idx in range(total):
arch_index = api.random()
api.reload(weight_dir, arch_index)
# compute the weight watcher results
config = api.get_net_config(arch_index, data)
net = get_cell_based_tiny_net(config)
meta_info = api.query_meta_info_by_index(arch_index, hp="200" if api.search_space_name == "topology" else "90")
params = meta_info.get_net_param(data, 888 if api.search_space_name == "topology" else 777)
with torch.no_grad():
net.load_state_dict(params)
_, summary = weight_watcher.analyze(net, alphas=False)
if "lognorm" not in summary:
api.clear_params(arch_index, None)
del net
continue
cur_norm = -summary["lognorm"]
api.clear_params(arch_index, None)
del net ; continue
cur_norm = -summary['lognorm']
api.clear_params(arch_index, None)
if math.isnan(cur_norm):
del net, meta_info
continue
else:
ok += 1
norms.append(cur_norm)
# query the accuracy
info = meta_info.get_metrics(data, 'ori-test', iepoch=None, is_random=888 if api.search_space_name == 'topology' else 777)
accuracies.append(info['accuracy'])
del net, meta_info
# print the information
if idx % 20 == 0:
gc.collect()
print('{:} {:04d}_{:04d}/{:04d} ({:.2f} MB memory)'.format(time_string(), ok, idx, total, process.memory_info().rss / 1e6))
return norms, accuracies
if math.isnan(cur_norm):
del net, meta_info
continue
else:
ok += 1
norms.append(cur_norm)
# query the accuracy
info = meta_info.get_metrics(
data, "ori-test", iepoch=None, is_random=888 if api.search_space_name == "topology" else 777
)
accuracies.append(info["accuracy"])
del net, meta_info
# print the information
if idx % 20 == 0:
gc.collect()
print(
"{:} {:04d}_{:04d}/{:04d} ({:.2f} MB memory)".format(
time_string(), ok, idx, total, process.memory_info().rss / 1e6
)
)
return norms, accuracies
def main(search_space, meta_file: str, weight_dir, save_dir, xdata):
save_dir.mkdir(parents=True, exist_ok=True)
api = create(meta_file, search_space, verbose=False)
datasets = ['cifar10-valid', 'cifar10', 'cifar100', 'ImageNet16-120']
print(time_string() + ' ' + '='*50)
for data in datasets:
hps = api.avaliable_hps
for hp in hps:
nums = api.statistics(data, hp=hp)
total = sum([k*v for k, v in nums.items()])
print('Using {:3s} epochs, trained on {:20s} : {:} trials in total ({:}).'.format(hp, data, total, nums))
print(time_string() + ' ' + '='*50)
save_dir.mkdir(parents=True, exist_ok=True)
api = create(meta_file, search_space, verbose=False)
datasets = ["cifar10-valid", "cifar10", "cifar100", "ImageNet16-120"]
print(time_string() + " " + "=" * 50)
for data in datasets:
hps = api.avaliable_hps
for hp in hps:
nums = api.statistics(data, hp=hp)
total = sum([k * v for k, v in nums.items()])
print("Using {:3s} epochs, trained on {:20s} : {:} trials in total ({:}).".format(hp, data, total, nums))
print(time_string() + " " + "=" * 50)
norms, accuracies = evaluate(api, weight_dir, xdata)
norms, accuracies = evaluate(api, weight_dir, xdata)
indexes = list(range(len(norms)))
norm_indexes = sorted(indexes, key=lambda i: norms[i])
accy_indexes = sorted(indexes, key=lambda i: accuracies[i])
labels = []
for index in norm_indexes:
labels.append(accy_indexes.index(index))
indexes = list(range(len(norms)))
norm_indexes = sorted(indexes, key=lambda i: norms[i])
accy_indexes = sorted(indexes, key=lambda i: accuracies[i])
labels = []
for index in norm_indexes:
labels.append(accy_indexes.index(index))
dpi, width, height = 200, 1400, 800
figsize = width / float(dpi), height / float(dpi)
LabelSize, LegendFontsize = 18, 12
resnet_scale, resnet_alpha = 120, 0.5
dpi, width, height = 200, 1400, 800
figsize = width / float(dpi), height / float(dpi)
LabelSize, LegendFontsize = 18, 12
resnet_scale, resnet_alpha = 120, 0.5
fig = plt.figure(figsize=figsize)
ax = fig.add_subplot(111)
plt.xlim(min(indexes), max(indexes))
plt.ylim(min(indexes), max(indexes))
# plt.ylabel('y').set_rotation(30)
plt.yticks(np.arange(min(indexes), max(indexes), max(indexes)//3), fontsize=LegendFontsize, rotation='vertical')
plt.xticks(np.arange(min(indexes), max(indexes), max(indexes)//5), fontsize=LegendFontsize)
ax.scatter(indexes, labels , marker='*', s=0.5, c='tab:red' , alpha=0.8)
ax.scatter(indexes, indexes, marker='o', s=0.5, c='tab:blue' , alpha=0.8)
ax.scatter([-1], [-1], marker='o', s=100, c='tab:blue' , label='Test accuracy')
ax.scatter([-1], [-1], marker='*', s=100, c='tab:red' , label='Weight watcher')
plt.grid(zorder=0)
ax.set_axisbelow(True)
plt.legend(loc=0, fontsize=LegendFontsize)
ax.set_xlabel('architecture ranking sorted by the test accuracy ', fontsize=LabelSize)
ax.set_ylabel('architecture ranking computed by weight watcher', fontsize=LabelSize)
save_path = (save_dir / '{:}-{:}-test-ww.pdf'.format(search_space, xdata)).resolve()
fig.savefig(save_path, dpi=dpi, bbox_inches='tight', format='pdf')
save_path = (save_dir / '{:}-{:}-test-ww.png'.format(search_space, xdata)).resolve()
fig.savefig(save_path, dpi=dpi, bbox_inches='tight', format='png')
print ('{:} save into {:}'.format(time_string(), save_path))
print('{:} finish this test.'.format(time_string()))
fig = plt.figure(figsize=figsize)
ax = fig.add_subplot(111)
plt.xlim(min(indexes), max(indexes))
plt.ylim(min(indexes), max(indexes))
# plt.ylabel('y').set_rotation(30)
plt.yticks(np.arange(min(indexes), max(indexes), max(indexes) // 3), fontsize=LegendFontsize, rotation="vertical")
plt.xticks(np.arange(min(indexes), max(indexes), max(indexes) // 5), fontsize=LegendFontsize)
ax.scatter(indexes, labels, marker="*", s=0.5, c="tab:red", alpha=0.8)
ax.scatter(indexes, indexes, marker="o", s=0.5, c="tab:blue", alpha=0.8)
ax.scatter([-1], [-1], marker="o", s=100, c="tab:blue", label="Test accuracy")
ax.scatter([-1], [-1], marker="*", s=100, c="tab:red", label="Weight watcher")
plt.grid(zorder=0)
ax.set_axisbelow(True)
plt.legend(loc=0, fontsize=LegendFontsize)
ax.set_xlabel("architecture ranking sorted by the test accuracy ", fontsize=LabelSize)
ax.set_ylabel("architecture ranking computed by weight watcher", fontsize=LabelSize)
save_path = (save_dir / "{:}-{:}-test-ww.pdf".format(search_space, xdata)).resolve()
fig.savefig(save_path, dpi=dpi, bbox_inches="tight", format="pdf")
save_path = (save_dir / "{:}-{:}-test-ww.png".format(search_space, xdata)).resolve()
fig.savefig(save_path, dpi=dpi, bbox_inches="tight", format="png")
print("{:} save into {:}".format(time_string(), save_path))
print("{:} finish this test.".format(time_string()))
if __name__ == '__main__':
parser = argparse.ArgumentParser("Analysis of NAS-Bench-201")
parser.add_argument('--save_dir', type=str, default='./output/vis-nas-bench/', help='The base-name of folder to save checkpoints and log.')
parser.add_argument('--search_space', type=str, default=None, choices=['tss', 'sss'], help='The search space.')
parser.add_argument('--base_path', type=str, default=None, help='The path to the NAS-Bench-201 benchmark file and weight dir.')
parser.add_argument('--dataset' , type=str, default=None, help='The dataset to evaluate.')
args = parser.parse_args()
if __name__ == "__main__":
parser = argparse.ArgumentParser("Analysis of NAS-Bench-201")
parser.add_argument(
"--save_dir",
type=str,
default="./output/vis-nas-bench/",
help="The base-name of folder to save checkpoints and log.",
)
parser.add_argument("--search_space", type=str, default=None, choices=["tss", "sss"], help="The search space.")
parser.add_argument(
"--base_path", type=str, default=None, help="The path to the NAS-Bench-201 benchmark file and weight dir."
)
parser.add_argument("--dataset", type=str, default=None, help=".")
args = parser.parse_args()
save_dir = Path(args.save_dir)
save_dir.mkdir(parents=True, exist_ok=True)
meta_file = Path(args.base_path + '.pth')
weight_dir = Path(args.base_path + '-full')
assert meta_file.exists(), 'invalid path for api : {:}'.format(meta_file)
assert weight_dir.exists() and weight_dir.is_dir(), 'invalid path for weight dir : {:}'.format(weight_dir)
main(args.search_space, str(meta_file), weight_dir, save_dir, args.dataset)
save_dir = Path(args.save_dir)
save_dir.mkdir(parents=True, exist_ok=True)
meta_file = Path(args.base_path + ".pth")
weight_dir = Path(args.base_path + "-full")
assert meta_file.exists(), "invalid path for api : {:}".format(meta_file)
assert weight_dir.exists() and weight_dir.is_dir(), "invalid path for weight dir : {:}".format(weight_dir)
main(args.search_space, str(meta_file), weight_dir, save_dir, args.dataset)
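
The figure compares the two rankings point by point; a scalar summary of the same agreement is a rank correlation. A short sketch with SciPy, over hypothetical norms/accuracies values shaped like those returned by evaluate:

from scipy import stats

norms = [0.31, 0.28, 0.44, 0.12]       # hypothetical -lognorm values
accuracies = [91.2, 90.8, 93.1, 88.4]  # hypothetical test accuracies (%)

tau, p_value = stats.kendalltau(norms, accuracies)
rho, _ = stats.spearmanr(norms, accuracies)
print("Kendall tau = {:.3f} (p={:.3f}), Spearman rho = {:.3f}".format(tau, p_value, rho))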


@@ -2,31 +2,33 @@ import sys, time, random, argparse
from copy import deepcopy
import torchvision.models as models
from pathlib import Path
lib_dir = (Path(__file__).parent / '..' / '..' / 'lib').resolve()
if str(lib_dir) not in sys.path: sys.path.insert(0, str(lib_dir))
lib_dir = (Path(__file__).parent / ".." / ".." / "lib").resolve()
if str(lib_dir) not in sys.path:
sys.path.insert(0, str(lib_dir))
from utils import weight_watcher
def main():
# model = models.vgg19_bn(pretrained=True)
# _, summary = weight_watcher.analyze(model, alphas=False)
# for key, value in summary.items():
# print('{:10s} : {:}'.format(key, value))
# model = models.vgg19_bn(pretrained=True)
# _, summary = weight_watcher.analyze(model, alphas=False)
# for key, value in summary.items():
# print('{:10s} : {:}'.format(key, value))
_, summary = weight_watcher.analyze(models.vgg13(pretrained=True), alphas=False)
print('vgg-13 : {:}'.format(summary['lognorm']))
_, summary = weight_watcher.analyze(models.vgg13_bn(pretrained=True), alphas=False)
print('vgg-13-BN : {:}'.format(summary['lognorm']))
_, summary = weight_watcher.analyze(models.vgg16(pretrained=True), alphas=False)
print('vgg-16 : {:}'.format(summary['lognorm']))
_, summary = weight_watcher.analyze(models.vgg16_bn(pretrained=True), alphas=False)
print('vgg-16-BN : {:}'.format(summary['lognorm']))
_, summary = weight_watcher.analyze(models.vgg19(pretrained=True), alphas=False)
print('vgg-19 : {:}'.format(summary['lognorm']))
_, summary = weight_watcher.analyze(models.vgg19_bn(pretrained=True), alphas=False)
print('vgg-19-BN : {:}'.format(summary['lognorm']))
_, summary = weight_watcher.analyze(models.vgg13(pretrained=True), alphas=False)
print("vgg-13 : {:}".format(summary["lognorm"]))
_, summary = weight_watcher.analyze(models.vgg13_bn(pretrained=True), alphas=False)
print("vgg-13-BN : {:}".format(summary["lognorm"]))
_, summary = weight_watcher.analyze(models.vgg16(pretrained=True), alphas=False)
print("vgg-16 : {:}".format(summary["lognorm"]))
_, summary = weight_watcher.analyze(models.vgg16_bn(pretrained=True), alphas=False)
print("vgg-16-BN : {:}".format(summary["lognorm"]))
_, summary = weight_watcher.analyze(models.vgg19(pretrained=True), alphas=False)
print("vgg-19 : {:}".format(summary["lognorm"]))
_, summary = weight_watcher.analyze(models.vgg19_bn(pretrained=True), alphas=False)
print("vgg-19-BN : {:}".format(summary["lognorm"]))
if __name__ == '__main__':
main()
if __name__ == "__main__":
main()
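
The vendored utils.weight_watcher mirrors the public weightwatcher package; a hedged equivalent of the calls above against the PyPI distribution (method and summary-key names per weightwatcher ~0.4 are an assumption):

import torchvision.models as models
import weightwatcher as ww

model = models.vgg13(pretrained=True)
watcher = ww.WeightWatcher(model=model)
details = watcher.analyze()             # per-layer DataFrame
summary = watcher.get_summary(details)  # aggregate metrics
print("vgg-13 log_norm : {:}".format(summary["log_norm"]))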


@@ -11,122 +11,133 @@ import numpy as np
from typing import List, Text, Dict, Any
from shutil import copyfile
from collections import defaultdict, OrderedDict
from copy import deepcopy
from copy import deepcopy
from pathlib import Path
import matplotlib
import seaborn as sns
matplotlib.use('agg')
matplotlib.use("agg")
import matplotlib.pyplot as plt
import matplotlib.ticker as ticker
lib_dir = (Path(__file__).parent / '..' / '..' / 'lib').resolve()
if str(lib_dir) not in sys.path: sys.path.insert(0, str(lib_dir))
lib_dir = (Path(__file__).parent / ".." / ".." / "lib").resolve()
if str(lib_dir) not in sys.path:
sys.path.insert(0, str(lib_dir))
from config_utils import dict2config, load_config
from nats_bench import create
from log_utils import time_string
def fetch_data(root_dir='./output/search', search_space='tss', dataset=None):
ss_dir = '{:}-{:}'.format(root_dir, search_space)
alg2name, alg2path = OrderedDict(), OrderedDict()
alg2name['REA'] = 'R-EA-SS3'
alg2name['REINFORCE'] = 'REINFORCE-0.01'
alg2name['RANDOM'] = 'RANDOM'
alg2name['BOHB'] = 'BOHB'
for alg, name in alg2name.items():
alg2path[alg] = os.path.join(ss_dir, dataset, name, 'results.pth')
assert os.path.isfile(alg2path[alg]), 'invalid path : {:}'.format(alg2path[alg])
alg2data = OrderedDict()
for alg, path in alg2path.items():
data = torch.load(path)
for index, info in data.items():
info['time_w_arch'] = [(x, y) for x, y in zip(info['all_total_times'], info['all_archs'])]
for j, arch in enumerate(info['all_archs']):
assert arch != -1, 'invalid arch from {:} {:} {:} ({:}, {:})'.format(alg, search_space, dataset, index, j)
alg2data[alg] = data
return alg2data
def fetch_data(root_dir="./output/search", search_space="tss", dataset=None):
ss_dir = "{:}-{:}".format(root_dir, search_space)
alg2name, alg2path = OrderedDict(), OrderedDict()
alg2name["REA"] = "R-EA-SS3"
alg2name["REINFORCE"] = "REINFORCE-0.01"
alg2name["RANDOM"] = "RANDOM"
alg2name["BOHB"] = "BOHB"
for alg, name in alg2name.items():
alg2path[alg] = os.path.join(ss_dir, dataset, name, "results.pth")
assert os.path.isfile(alg2path[alg]), "invalid path : {:}".format(alg2path[alg])
alg2data = OrderedDict()
for alg, path in alg2path.items():
data = torch.load(path)
for index, info in data.items():
info["time_w_arch"] = [(x, y) for x, y in zip(info["all_total_times"], info["all_archs"])]
for j, arch in enumerate(info["all_archs"]):
assert arch != -1, "invalid arch from {:} {:} {:} ({:}, {:})".format(
alg, search_space, dataset, index, j
)
alg2data[alg] = data
return alg2data
def query_performance(api, data, dataset, ticket):
results, is_size_space = [], api.search_space_name == 'size'
for i, info in data.items():
time_w_arch = sorted(info['time_w_arch'], key=lambda x: abs(x[0]-ticket))
time_a, arch_a = time_w_arch[0]
time_b, arch_b = time_w_arch[1]
info_a = api.get_more_info(arch_a, dataset=dataset, hp=90 if is_size_space else 200, is_random=False)
info_b = api.get_more_info(arch_b, dataset=dataset, hp=90 if is_size_space else 200, is_random=False)
accuracy_a, accuracy_b = info_a['test-accuracy'], info_b['test-accuracy']
interplate = (time_b-ticket) / (time_b-time_a) * accuracy_a + (ticket-time_a) / (time_b-time_a) * accuracy_b
results.append(interplate)
return sum(results) / len(results)
results, is_size_space = [], api.search_space_name == "size"
for i, info in data.items():
time_w_arch = sorted(info["time_w_arch"], key=lambda x: abs(x[0] - ticket))
time_a, arch_a = time_w_arch[0]
time_b, arch_b = time_w_arch[1]
info_a = api.get_more_info(arch_a, dataset=dataset, hp=90 if is_size_space else 200, is_random=False)
info_b = api.get_more_info(arch_b, dataset=dataset, hp=90 if is_size_space else 200, is_random=False)
accuracy_a, accuracy_b = info_a["test-accuracy"], info_b["test-accuracy"]
interplate = (time_b - ticket) / (time_b - time_a) * accuracy_a + (ticket - time_a) / (
time_b - time_a
) * accuracy_b
results.append(interplate)
return sum(results) / len(results)
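
query_performance estimates each run's accuracy at wall-clock budget ticket by linear interpolation between the two recorded (time, architecture) samples closest in time, then averages across runs. The two-point weights in isolation (hypothetical numbers):

def interpolate_accuracy(time_a, acc_a, time_b, acc_b, ticket):
    """Linearly interpolate accuracy at wall-clock time `ticket`."""
    w_a = (time_b - ticket) / (time_b - time_a)
    w_b = (ticket - time_a) / (time_b - time_a)
    return w_a * acc_a + w_b * acc_b

# ticket=150s sits one third of the way from (100s, 90.0%) to (250s, 93.0%).
print(interpolate_accuracy(100.0, 90.0, 250.0, 93.0, 150.0))  # 91.0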
y_min_s = {('cifar10', 'tss'): 90,
('cifar10', 'sss'): 92,
('cifar100', 'tss'): 65,
('cifar100', 'sss'): 65,
('ImageNet16-120', 'tss'): 36,
('ImageNet16-120', 'sss'): 40}
y_min_s = {
("cifar10", "tss"): 90,
("cifar10", "sss"): 92,
("cifar100", "tss"): 65,
("cifar100", "sss"): 65,
("ImageNet16-120", "tss"): 36,
("ImageNet16-120", "sss"): 40,
}
y_max_s = {('cifar10', 'tss'): 94.5,
('cifar10', 'sss'): 93.3,
('cifar100', 'tss'): 72,
('cifar100', 'sss'): 70,
('ImageNet16-120', 'tss'): 44,
('ImageNet16-120', 'sss'): 46}
y_max_s = {
("cifar10", "tss"): 94.5,
("cifar10", "sss"): 93.3,
("cifar100", "tss"): 72,
("cifar100", "sss"): 70,
("ImageNet16-120", "tss"): 44,
("ImageNet16-120", "sss"): 46,
}
name2label = {"cifar10": "CIFAR-10", "cifar100": "CIFAR-100", "ImageNet16-120": "ImageNet-16-120"}
name2label = {'cifar10': 'CIFAR-10',
'cifar100': 'CIFAR-100',
'ImageNet16-120': 'ImageNet-16-120'}
def visualize_curve(api, vis_save_dir, search_space, max_time):
vis_save_dir = vis_save_dir.resolve()
vis_save_dir.mkdir(parents=True, exist_ok=True)
vis_save_dir = vis_save_dir.resolve()
vis_save_dir.mkdir(parents=True, exist_ok=True)
dpi, width, height = 250, 5200, 1400
figsize = width / float(dpi), height / float(dpi)
LabelSize, LegendFontsize = 16, 16
dpi, width, height = 250, 5200, 1400
figsize = width / float(dpi), height / float(dpi)
LabelSize, LegendFontsize = 16, 16
def sub_plot_fn(ax, dataset):
alg2data = fetch_data(search_space=search_space, dataset=dataset)
alg2accuracies = OrderedDict()
total_tickets = 150
time_tickets = [float(i) / total_tickets * max_time for i in range(total_tickets)]
colors = ['b', 'g', 'c', 'm', 'y']
ax.set_xlim(0, 200)
ax.set_ylim(y_min_s[(dataset, search_space)], y_max_s[(dataset, search_space)])
for idx, (alg, data) in enumerate(alg2data.items()):
print('plot alg : {:}'.format(alg))
accuracies = []
for ticket in time_tickets:
accuracy = query_performance(api, data, dataset, ticket)
accuracies.append(accuracy)
alg2accuracies[alg] = accuracies
ax.plot([x/100 for x in time_tickets], accuracies, c=colors[idx], label='{:}'.format(alg))
ax.set_xlabel('Estimated wall-clock time (1e2 seconds)', fontsize=LabelSize)
ax.set_ylabel('Test accuracy on {:}'.format(name2label[dataset]), fontsize=LabelSize)
ax.set_title('Searching results on {:}'.format(name2label[dataset]), fontsize=LabelSize+4)
ax.legend(loc=4, fontsize=LegendFontsize)
def sub_plot_fn(ax, dataset):
alg2data = fetch_data(search_space=search_space, dataset=dataset)
alg2accuracies = OrderedDict()
total_tickets = 150
time_tickets = [float(i) / total_tickets * max_time for i in range(total_tickets)]
colors = ["b", "g", "c", "m", "y"]
ax.set_xlim(0, 200)
ax.set_ylim(y_min_s[(dataset, search_space)], y_max_s[(dataset, search_space)])
for idx, (alg, data) in enumerate(alg2data.items()):
print("plot alg : {:}".format(alg))
accuracies = []
for ticket in time_tickets:
accuracy = query_performance(api, data, dataset, ticket)
accuracies.append(accuracy)
alg2accuracies[alg] = accuracies
ax.plot([x / 100 for x in time_tickets], accuracies, c=colors[idx], label="{:}".format(alg))
ax.set_xlabel("Estimated wall-clock time (1e2 seconds)", fontsize=LabelSize)
ax.set_ylabel("Test accuracy on {:}".format(name2label[dataset]), fontsize=LabelSize)
ax.set_title("Searching results on {:}".format(name2label[dataset]), fontsize=LabelSize + 4)
ax.legend(loc=4, fontsize=LegendFontsize)
fig, axs = plt.subplots(1, 3, figsize=figsize)
datasets = ['cifar10', 'cifar100', 'ImageNet16-120']
for dataset, ax in zip(datasets, axs):
sub_plot_fn(ax, dataset)
print('sub-plot {:} on {:} done.'.format(dataset, search_space))
save_path = (vis_save_dir / '{:}-curve.png'.format(search_space)).resolve()
fig.savefig(save_path, dpi=dpi, bbox_inches='tight', format='png')
print ('{:} save into {:}'.format(time_string(), save_path))
plt.close('all')
fig, axs = plt.subplots(1, 3, figsize=figsize)
datasets = ["cifar10", "cifar100", "ImageNet16-120"]
for dataset, ax in zip(datasets, axs):
sub_plot_fn(ax, dataset)
print("sub-plot {:} on {:} done.".format(dataset, search_space))
save_path = (vis_save_dir / "{:}-curve.png".format(search_space)).resolve()
fig.savefig(save_path, dpi=dpi, bbox_inches="tight", format="png")
print("{:} save into {:}".format(time_string(), save_path))
plt.close("all")
if __name__ == '__main__':
parser = argparse.ArgumentParser(description='NAS-Bench-X', formatter_class=argparse.ArgumentDefaultsHelpFormatter)
parser.add_argument('--save_dir', type=str, default='output/vis-nas-bench/nas-algos', help='Folder to save checkpoints and log.')
parser.add_argument('--search_space', type=str, choices=['tss', 'sss'], help='Choose the search space.')
parser.add_argument('--max_time', type=float, default=20000, help='The maximum time budget.')
args = parser.parse_args()
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="NAS-Bench-X", formatter_class=argparse.ArgumentDefaultsHelpFormatter)
parser.add_argument(
"--save_dir", type=str, default="output/vis-nas-bench/nas-algos", help="Folder to save checkpoints and log."
)
parser.add_argument("--search_space", type=str, choices=["tss", "sss"], help="Choose the search space.")
parser.add_argument("--max_time", type=float, default=20000, help="The maximum time budget.")
args = parser.parse_args()
save_dir = Path(args.save_dir)
save_dir = Path(args.save_dir)
api = create(None, args.search_space, verbose=False)
visualize_curve(api, save_dir, args.search_space, args.max_time)
api = create(None, args.search_space, verbose=False)
visualize_curve(api, save_dir, args.search_space, args.max_time)


@@ -11,132 +11,143 @@ import numpy as np
from typing import List, Text, Dict, Any
from shutil import copyfile
from collections import defaultdict, OrderedDict
from copy import deepcopy
from copy import deepcopy
from pathlib import Path
import matplotlib
import seaborn as sns
matplotlib.use('agg')
matplotlib.use("agg")
import matplotlib.pyplot as plt
import matplotlib.ticker as ticker
lib_dir = (Path(__file__).parent / '..' / '..' / 'lib').resolve()
if str(lib_dir) not in sys.path: sys.path.insert(0, str(lib_dir))
lib_dir = (Path(__file__).parent / ".." / ".." / "lib").resolve()
if str(lib_dir) not in sys.path:
sys.path.insert(0, str(lib_dir))
from config_utils import dict2config, load_config
from nats_bench import create
from log_utils import time_string
# def fetch_data(root_dir='./output/search', search_space='tss', dataset=None, suffix='-WARMNone'):
def fetch_data(root_dir='./output/search', search_space='tss', dataset=None, suffix='-WARM0.3'):
ss_dir = '{:}-{:}'.format(root_dir, search_space)
alg2name, alg2path = OrderedDict(), OrderedDict()
seeds = [777, 888, 999]
print('\n[fetch data] from {:} on {:}'.format(search_space, dataset))
if search_space == 'tss':
alg2name['GDAS'] = 'gdas-affine0_BN0-None'
alg2name['RSPS'] = 'random-affine0_BN0-None'
alg2name['DARTS (1st)'] = 'darts-v1-affine0_BN0-None'
alg2name['DARTS (2nd)'] = 'darts-v2-affine0_BN0-None'
alg2name['ENAS'] = 'enas-affine0_BN0-None'
alg2name['SETN'] = 'setn-affine0_BN0-None'
else:
# alg2name['TAS'] = 'tas-affine0_BN0{:}'.format(suffix)
# alg2name['FBNetV2'] = 'fbv2-affine0_BN0{:}'.format(suffix)
# alg2name['TuNAS'] = 'tunas-affine0_BN0{:}'.format(suffix)
alg2name['channel-wise interpolation'] = 'tas-affine0_BN0-AWD0.001{:}'.format(suffix)
alg2name['masking + Gumbel-Softmax'] = 'mask_gumbel-affine0_BN0-AWD0.001{:}'.format(suffix)
alg2name['masking + sampling'] = 'mask_rl-affine0_BN0-AWD0.0{:}'.format(suffix)
for alg, name in alg2name.items():
alg2path[alg] = os.path.join(ss_dir, dataset, name, 'seed-{:}-last-info.pth')
alg2data = OrderedDict()
for alg, path in alg2path.items():
alg2data[alg], ok_num = [], 0
for seed in seeds:
xpath = path.format(seed)
if os.path.isfile(xpath):
ok_num += 1
else:
print('This is an invalid path : {:}'.format(xpath))
continue
data = torch.load(xpath, map_location=torch.device('cpu'))
data = torch.load(data['last_checkpoint'], map_location=torch.device('cpu'))
alg2data[alg].append(data['genotypes'])
print('This algorithm : {:} has {:} valid ckps.'.format(alg, ok_num))
assert ok_num > 0, 'Must have at least 1 valid ckpt.'
return alg2data
def fetch_data(root_dir="./output/search", search_space="tss", dataset=None, suffix="-WARM0.3"):
ss_dir = "{:}-{:}".format(root_dir, search_space)
alg2name, alg2path = OrderedDict(), OrderedDict()
seeds = [777, 888, 999]
print("\n[fetch data] from {:} on {:}".format(search_space, dataset))
if search_space == "tss":
alg2name["GDAS"] = "gdas-affine0_BN0-None"
alg2name["RSPS"] = "random-affine0_BN0-None"
alg2name["DARTS (1st)"] = "darts-v1-affine0_BN0-None"
alg2name["DARTS (2nd)"] = "darts-v2-affine0_BN0-None"
alg2name["ENAS"] = "enas-affine0_BN0-None"
alg2name["SETN"] = "setn-affine0_BN0-None"
else:
# alg2name['TAS'] = 'tas-affine0_BN0{:}'.format(suffix)
# alg2name['FBNetV2'] = 'fbv2-affine0_BN0{:}'.format(suffix)
# alg2name['TuNAS'] = 'tunas-affine0_BN0{:}'.format(suffix)
alg2name["channel-wise interpolation"] = "tas-affine0_BN0-AWD0.001{:}".format(suffix)
alg2name["masking + Gumbel-Softmax"] = "mask_gumbel-affine0_BN0-AWD0.001{:}".format(suffix)
alg2name["masking + sampling"] = "mask_rl-affine0_BN0-AWD0.0{:}".format(suffix)
for alg, name in alg2name.items():
alg2path[alg] = os.path.join(ss_dir, dataset, name, "seed-{:}-last-info.pth")
alg2data = OrderedDict()
for alg, path in alg2path.items():
alg2data[alg], ok_num = [], 0
for seed in seeds:
xpath = path.format(seed)
if os.path.isfile(xpath):
ok_num += 1
else:
print("This is an invalid path : {:}".format(xpath))
continue
data = torch.load(xpath, map_location=torch.device("cpu"))
data = torch.load(data["last_checkpoint"], map_location=torch.device("cpu"))
alg2data[alg].append(data["genotypes"])
print("This algorithm : {:} has {:} valid ckps.".format(alg, ok_num))
assert ok_num > 0, "Must have at least 1 valid ckpt."
return alg2data
y_min_s = {('cifar10', 'tss'): 90,
('cifar10', 'sss'): 92,
('cifar100', 'tss'): 65,
('cifar100', 'sss'): 65,
('ImageNet16-120', 'tss'): 36,
('ImageNet16-120', 'sss'): 40}
y_min_s = {
("cifar10", "tss"): 90,
("cifar10", "sss"): 92,
("cifar100", "tss"): 65,
("cifar100", "sss"): 65,
("ImageNet16-120", "tss"): 36,
("ImageNet16-120", "sss"): 40,
}
y_max_s = {('cifar10', 'tss'): 94.5,
('cifar10', 'sss'): 93.3,
('cifar100', 'tss'): 72,
('cifar100', 'sss'): 70,
('ImageNet16-120', 'tss'): 44,
('ImageNet16-120', 'sss'): 46}
y_max_s = {
("cifar10", "tss"): 94.5,
("cifar10", "sss"): 93.3,
("cifar100", "tss"): 72,
("cifar100", "sss"): 70,
("ImageNet16-120", "tss"): 44,
("ImageNet16-120", "sss"): 46,
}
name2label = {"cifar10": "CIFAR-10", "cifar100": "CIFAR-100", "ImageNet16-120": "ImageNet-16-120"}
name2label = {'cifar10': 'CIFAR-10',
'cifar100': 'CIFAR-100',
'ImageNet16-120': 'ImageNet-16-120'}
def visualize_curve(api, vis_save_dir, search_space):
vis_save_dir = vis_save_dir.resolve()
vis_save_dir.mkdir(parents=True, exist_ok=True)
vis_save_dir = vis_save_dir.resolve()
vis_save_dir.mkdir(parents=True, exist_ok=True)
dpi, width, height = 250, 5200, 1400
figsize = width / float(dpi), height / float(dpi)
LabelSize, LegendFontsize = 16, 16
dpi, width, height = 250, 5200, 1400
figsize = width / float(dpi), height / float(dpi)
LabelSize, LegendFontsize = 16, 16
def sub_plot_fn(ax, dataset):
alg2data = fetch_data(search_space=search_space, dataset=dataset)
alg2accuracies = OrderedDict()
epochs = 100
colors = ['b', 'g', 'c', 'm', 'y', 'r']
ax.set_xlim(0, epochs)
# ax.set_ylim(y_min_s[(dataset, search_space)], y_max_s[(dataset, search_space)])
for idx, (alg, data) in enumerate(alg2data.items()):
print('plot alg : {:}'.format(alg))
xs, accuracies = [], []
for iepoch in range(epochs + 1):
try:
structures, accs = [_[iepoch-1] for _ in data], []
except IndexError:
raise ValueError('This alg {:} on {:} has invalid checkpoints.'.format(alg, dataset))
for structure in structures:
info = api.get_more_info(structure, dataset=dataset, hp=90 if api.search_space_name == 'size' else 200, is_random=False)
accs.append(info['test-accuracy'])
accuracies.append(sum(accs)/len(accs))
xs.append(iepoch)
alg2accuracies[alg] = accuracies
ax.plot(xs, accuracies, c=colors[idx], label='{:}'.format(alg))
ax.set_xlabel('The searching epoch', fontsize=LabelSize)
ax.set_ylabel('Test accuracy on {:}'.format(name2label[dataset]), fontsize=LabelSize)
ax.set_title('Searching results on {:}'.format(name2label[dataset]), fontsize=LabelSize+4)
ax.legend(loc=4, fontsize=LegendFontsize)
def sub_plot_fn(ax, dataset):
alg2data = fetch_data(search_space=search_space, dataset=dataset)
alg2accuracies = OrderedDict()
epochs = 100
colors = ["b", "g", "c", "m", "y", "r"]
ax.set_xlim(0, epochs)
# ax.set_ylim(y_min_s[(dataset, search_space)], y_max_s[(dataset, search_space)])
for idx, (alg, data) in enumerate(alg2data.items()):
print("plot alg : {:}".format(alg))
xs, accuracies = [], []
for iepoch in range(epochs + 1):
try:
structures, accs = [_[iepoch - 1] for _ in data], []
except IndexError:
raise ValueError("This alg {:} on {:} has invalid checkpoints.".format(alg, dataset))
for structure in structures:
info = api.get_more_info(
structure, dataset=dataset, hp=90 if api.search_space_name == "size" else 200, is_random=False
)
accs.append(info["test-accuracy"])
accuracies.append(sum(accs) / len(accs))
xs.append(iepoch)
alg2accuracies[alg] = accuracies
ax.plot(xs, accuracies, c=colors[idx], label="{:}".format(alg))
ax.set_xlabel("The searching epoch", fontsize=LabelSize)
ax.set_ylabel("Test accuracy on {:}".format(name2label[dataset]), fontsize=LabelSize)
ax.set_title("Searching results on {:}".format(name2label[dataset]), fontsize=LabelSize + 4)
ax.legend(loc=4, fontsize=LegendFontsize)
fig, axs = plt.subplots(1, 3, figsize=figsize)
datasets = ['cifar10', 'cifar100', 'ImageNet16-120']
for dataset, ax in zip(datasets, axs):
sub_plot_fn(ax, dataset)
print('sub-plot {:} on {:} done.'.format(dataset, search_space))
save_path = (vis_save_dir / '{:}-ws-curve.png'.format(search_space)).resolve()
fig.savefig(save_path, dpi=dpi, bbox_inches='tight', format='png')
print ('{:} save into {:}'.format(time_string(), save_path))
plt.close('all')
fig, axs = plt.subplots(1, 3, figsize=figsize)
datasets = ["cifar10", "cifar100", "ImageNet16-120"]
for dataset, ax in zip(datasets, axs):
sub_plot_fn(ax, dataset)
print("sub-plot {:} on {:} done.".format(dataset, search_space))
save_path = (vis_save_dir / "{:}-ws-curve.png".format(search_space)).resolve()
fig.savefig(save_path, dpi=dpi, bbox_inches="tight", format="png")
print("{:} save into {:}".format(time_string(), save_path))
plt.close("all")
if __name__ == '__main__':
parser = argparse.ArgumentParser(description='NAS-Bench-X', formatter_class=argparse.ArgumentDefaultsHelpFormatter)
parser.add_argument('--save_dir', type=str, default='output/vis-nas-bench/nas-algos', help='Folder to save checkpoints and log.')
parser.add_argument('--search_space', type=str, default='tss', choices=['tss', 'sss'], help='Choose the search space.')
args = parser.parse_args()
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="NAS-Bench-X", formatter_class=argparse.ArgumentDefaultsHelpFormatter)
parser.add_argument(
"--save_dir", type=str, default="output/vis-nas-bench/nas-algos", help="Folder to save checkpoints and log."
)
parser.add_argument(
"--search_space", type=str, default="tss", choices=["tss", "sss"], help="Choose the search space."
)
args = parser.parse_args()
save_dir = Path(args.save_dir)
save_dir = Path(args.save_dir)
api = create(None, args.search_space, fast_mode=True, verbose=False)
visualize_curve(api, save_dir, args.search_space)
api = create(None, args.search_space, fast_mode=True, verbose=False)
visualize_curve(api, save_dir, args.search_space)
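
For reference, the benchmark query these curves rest on, in isolation (requires nats_bench plus its downloaded benchmark files; architecture index 1234 is arbitrary):

from nats_bench import create

# fast_mode defers loading per-architecture records until they are queried.
api = create(None, "tss", fast_mode=True, verbose=False)
info = api.get_more_info(1234, dataset="cifar10", hp=200, is_random=False)
print(info["test-accuracy"])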


@@ -10,16 +10,18 @@ import numpy as np
from typing import List, Text, Dict, Any
from shutil import copyfile
from collections import defaultdict
from copy import deepcopy
from copy import deepcopy
from pathlib import Path
import matplotlib
import seaborn as sns
matplotlib.use('agg')
matplotlib.use("agg")
import matplotlib.pyplot as plt
import matplotlib.ticker as ticker
lib_dir = (Path(__file__).parent / '..' / '..' / 'lib').resolve()
if str(lib_dir) not in sys.path: sys.path.insert(0, str(lib_dir))
lib_dir = (Path(__file__).parent / ".." / ".." / "lib").resolve()
if str(lib_dir) not in sys.path:
sys.path.insert(0, str(lib_dir))
from config_utils import dict2config, load_config
from log_utils import time_string
from models import get_cell_based_tiny_net
@@ -27,382 +29,577 @@ from nats_bench import create
def visualize_info(api, vis_save_dir, indicator):
vis_save_dir = vis_save_dir.resolve()
# print ('{:} start to visualize {:} information'.format(time_string(), api))
vis_save_dir.mkdir(parents=True, exist_ok=True)
vis_save_dir = vis_save_dir.resolve()
# print ('{:} start to visualize {:} information'.format(time_string(), api))
vis_save_dir.mkdir(parents=True, exist_ok=True)
cifar010_cache_path = vis_save_dir / '{:}-cache-{:}-info.pth'.format('cifar10', indicator)
cifar100_cache_path = vis_save_dir / '{:}-cache-{:}-info.pth'.format('cifar100', indicator)
imagenet_cache_path = vis_save_dir / '{:}-cache-{:}-info.pth'.format('ImageNet16-120', indicator)
cifar010_info = torch.load(cifar010_cache_path)
cifar100_info = torch.load(cifar100_cache_path)
imagenet_info = torch.load(imagenet_cache_path)
indexes = list(range(len(cifar010_info['params'])))
cifar010_cache_path = vis_save_dir / "{:}-cache-{:}-info.pth".format("cifar10", indicator)
cifar100_cache_path = vis_save_dir / "{:}-cache-{:}-info.pth".format("cifar100", indicator)
imagenet_cache_path = vis_save_dir / "{:}-cache-{:}-info.pth".format("ImageNet16-120", indicator)
cifar010_info = torch.load(cifar010_cache_path)
cifar100_info = torch.load(cifar100_cache_path)
imagenet_info = torch.load(imagenet_cache_path)
indexes = list(range(len(cifar010_info["params"])))
print ('{:} start to visualize relative ranking'.format(time_string()))
print("{:} start to visualize relative ranking".format(time_string()))
cifar010_ord_indexes = sorted(indexes, key=lambda i: cifar010_info['test_accs'][i])
cifar100_ord_indexes = sorted(indexes, key=lambda i: cifar100_info['test_accs'][i])
imagenet_ord_indexes = sorted(indexes, key=lambda i: imagenet_info['test_accs'][i])
cifar010_ord_indexes = sorted(indexes, key=lambda i: cifar010_info["test_accs"][i])
cifar100_ord_indexes = sorted(indexes, key=lambda i: cifar100_info["test_accs"][i])
imagenet_ord_indexes = sorted(indexes, key=lambda i: imagenet_info["test_accs"][i])
cifar100_labels, imagenet_labels = [], []
for idx in cifar010_ord_indexes:
cifar100_labels.append( cifar100_ord_indexes.index(idx) )
imagenet_labels.append( imagenet_ord_indexes.index(idx) )
print ('{:} prepare data done.'.format(time_string()))
cifar100_labels, imagenet_labels = [], []
for idx in cifar010_ord_indexes:
cifar100_labels.append(cifar100_ord_indexes.index(idx))
imagenet_labels.append(imagenet_ord_indexes.index(idx))
print("{:} prepare data done.".format(time_string()))
dpi, width, height = 200, 1400, 800
figsize = width / float(dpi), height / float(dpi)
LabelSize, LegendFontsize = 18, 12
resnet_scale, resnet_alpha = 120, 0.5
dpi, width, height = 200, 1400, 800
figsize = width / float(dpi), height / float(dpi)
LabelSize, LegendFontsize = 18, 12
resnet_scale, resnet_alpha = 120, 0.5
fig = plt.figure(figsize=figsize)
ax = fig.add_subplot(111)
plt.xlim(min(indexes), max(indexes))
plt.ylim(min(indexes), max(indexes))
# plt.ylabel('y').set_rotation(30)
plt.yticks(np.arange(min(indexes), max(indexes), max(indexes)//3), fontsize=LegendFontsize, rotation='vertical')
plt.xticks(np.arange(min(indexes), max(indexes), max(indexes)//5), fontsize=LegendFontsize)
ax.scatter(indexes, cifar100_labels, marker='^', s=0.5, c='tab:green', alpha=0.8)
ax.scatter(indexes, imagenet_labels, marker='*', s=0.5, c='tab:red' , alpha=0.8)
ax.scatter(indexes, indexes , marker='o', s=0.5, c='tab:blue' , alpha=0.8)
ax.scatter([-1], [-1], marker='o', s=100, c='tab:blue' , label='CIFAR-10')
ax.scatter([-1], [-1], marker='^', s=100, c='tab:green', label='CIFAR-100')
ax.scatter([-1], [-1], marker='*', s=100, c='tab:red' , label='ImageNet-16-120')
plt.grid(zorder=0)
ax.set_axisbelow(True)
plt.legend(loc=0, fontsize=LegendFontsize)
ax.set_xlabel('architecture ranking in CIFAR-10', fontsize=LabelSize)
ax.set_ylabel('architecture ranking', fontsize=LabelSize)
save_path = (vis_save_dir / '{:}-relative-rank.pdf'.format(indicator)).resolve()
fig.savefig(save_path, dpi=dpi, bbox_inches='tight', format='pdf')
save_path = (vis_save_dir / '{:}-relative-rank.png'.format(indicator)).resolve()
fig.savefig(save_path, dpi=dpi, bbox_inches='tight', format='png')
print ('{:} save into {:}'.format(time_string(), save_path))
fig = plt.figure(figsize=figsize)
ax = fig.add_subplot(111)
plt.xlim(min(indexes), max(indexes))
plt.ylim(min(indexes), max(indexes))
# plt.ylabel('y').set_rotation(30)
plt.yticks(np.arange(min(indexes), max(indexes), max(indexes) // 3), fontsize=LegendFontsize, rotation="vertical")
plt.xticks(np.arange(min(indexes), max(indexes), max(indexes) // 5), fontsize=LegendFontsize)
ax.scatter(indexes, cifar100_labels, marker="^", s=0.5, c="tab:green", alpha=0.8)
ax.scatter(indexes, imagenet_labels, marker="*", s=0.5, c="tab:red", alpha=0.8)
ax.scatter(indexes, indexes, marker="o", s=0.5, c="tab:blue", alpha=0.8)
ax.scatter([-1], [-1], marker="o", s=100, c="tab:blue", label="CIFAR-10")
ax.scatter([-1], [-1], marker="^", s=100, c="tab:green", label="CIFAR-100")
ax.scatter([-1], [-1], marker="*", s=100, c="tab:red", label="ImageNet-16-120")
plt.grid(zorder=0)
ax.set_axisbelow(True)
plt.legend(loc=0, fontsize=LegendFontsize)
ax.set_xlabel("architecture ranking in CIFAR-10", fontsize=LabelSize)
ax.set_ylabel("architecture ranking", fontsize=LabelSize)
save_path = (vis_save_dir / "{:}-relative-rank.pdf".format(indicator)).resolve()
fig.savefig(save_path, dpi=dpi, bbox_inches="tight", format="pdf")
save_path = (vis_save_dir / "{:}-relative-rank.png".format(indicator)).resolve()
fig.savefig(save_path, dpi=dpi, bbox_inches="tight", format="png")
print("{:} save into {:}".format(time_string(), save_path))
def visualize_sss_info(api, dataset, vis_save_dir):
vis_save_dir = vis_save_dir.resolve()
print ('{:} start to visualize {:} information'.format(time_string(), dataset))
vis_save_dir.mkdir(parents=True, exist_ok=True)
cache_file_path = vis_save_dir / '{:}-cache-sss-info.pth'.format(dataset)
if not cache_file_path.exists():
print ('Cannot find cache file : {:}'.format(cache_file_path))
params, flops, train_accs, valid_accs, test_accs = [], [], [], [], []
for index in range(len(api)):
cost_info = api.get_cost_info(index, dataset, hp='90')
params.append(cost_info['params'])
flops.append(cost_info['flops'])
# accuracy
info = api.get_more_info(index, dataset, hp='90', is_random=False)
train_accs.append(info['train-accuracy'])
test_accs.append(info['test-accuracy'])
if dataset == 'cifar10':
info = api.get_more_info(index, 'cifar10-valid', hp='90', is_random=False)
valid_accs.append(info['valid-accuracy'])
else:
valid_accs.append(info['valid-accuracy'])
info = {'params': params, 'flops': flops, 'train_accs': train_accs, 'valid_accs': valid_accs, 'test_accs': test_accs}
torch.save(info, cache_file_path)
else:
print ('Found cache file : {:}'.format(cache_file_path))
info = torch.load(cache_file_path)
params, flops, train_accs, valid_accs, test_accs = info['params'], info['flops'], info['train_accs'], info['valid_accs'], info['test_accs']
print ('{:} collect data done.'.format(time_string()))
vis_save_dir = vis_save_dir.resolve()
print("{:} start to visualize {:} information".format(time_string(), dataset))
vis_save_dir.mkdir(parents=True, exist_ok=True)
cache_file_path = vis_save_dir / "{:}-cache-sss-info.pth".format(dataset)
if not cache_file_path.exists():
print("Do not find cache file : {:}".format(cache_file_path))
params, flops, train_accs, valid_accs, test_accs = [], [], [], [], []
for index in range(len(api)):
cost_info = api.get_cost_info(index, dataset, hp="90")
params.append(cost_info["params"])
flops.append(cost_info["flops"])
# accuracy
info = api.get_more_info(index, dataset, hp="90", is_random=False)
train_accs.append(info["train-accuracy"])
test_accs.append(info["test-accuracy"])
if dataset == "cifar10":
info = api.get_more_info(index, "cifar10-valid", hp="90", is_random=False)
valid_accs.append(info["valid-accuracy"])
else:
valid_accs.append(info["valid-accuracy"])
info = {
"params": params,
"flops": flops,
"train_accs": train_accs,
"valid_accs": valid_accs,
"test_accs": test_accs,
}
torch.save(info, cache_file_path)
else:
print("Find cache file : {:}".format(cache_file_path))
info = torch.load(cache_file_path)
params, flops, train_accs, valid_accs, test_accs = (
info["params"],
info["flops"],
info["train_accs"],
info["valid_accs"],
info["test_accs"],
)
print("{:} collect data done.".format(time_string()))
pyramid = ['8:16:32:48:64', '8:8:16:32:48', '8:8:16:16:32', '8:8:16:16:48', '8:8:16:16:64', '16:16:32:32:64', '32:32:64:64:64']
pyramid_indexes = [api.query_index_by_arch(x) for x in pyramid]
largest_indexes = [api.query_index_by_arch('64:64:64:64:64')]
pyramid = [
"8:16:32:48:64",
"8:8:16:32:48",
"8:8:16:16:32",
"8:8:16:16:48",
"8:8:16:16:64",
"16:16:32:32:64",
"32:32:64:64:64",
]
pyramid_indexes = [api.query_index_by_arch(x) for x in pyramid]
largest_indexes = [api.query_index_by_arch("64:64:64:64:64")]
indexes = list(range(len(params)))
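    # The figure size is specified in pixels and converted to inches via dpi.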
dpi, width, height = 250, 8500, 1300
figsize = width / float(dpi), height / float(dpi)
LabelSize, LegendFontsize = 24, 24
# resnet_scale, resnet_alpha = 120, 0.5
xscale, xalpha = 120, 0.8
fig, axs = plt.subplots(1, 4, figsize=figsize)
# ax1, ax2, ax3, ax4, ax5 = axs
for ax in axs:
for tick in ax.xaxis.get_major_ticks():
tick.label.set_fontsize(LabelSize)
ax.yaxis.set_major_formatter(ticker.FormatStrFormatter("%.0f"))
for tick in ax.yaxis.get_major_ticks():
tick.label.set_fontsize(LabelSize)
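    # Only four panels are used; the architecture-ID panel (ax1) is commented
    # out, which is why the axis names start at ax2.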
ax2, ax3, ax4, ax5 = axs
# ax1.xaxis.set_ticks(np.arange(0, max(indexes), max(indexes)//5))
# ax1.scatter(indexes, test_accs, marker='o', s=0.5, c='tab:blue')
# ax1.set_xlabel('architecture ID', fontsize=LabelSize)
# ax1.set_ylabel('test accuracy (%)', fontsize=LabelSize)
ax2.scatter(params, train_accs, marker="o", s=0.5, c="tab:blue")
ax2.scatter(
[params[x] for x in pyramid_indexes],
[train_accs[x] for x in pyramid_indexes],
marker="*",
s=xscale,
c="tab:orange",
label="Pyramid Structure",
alpha=xalpha,
)
ax2.scatter(
[params[x] for x in largest_indexes],
[train_accs[x] for x in largest_indexes],
marker="x",
s=xscale,
c="tab:green",
label="Largest Candidate",
alpha=xalpha,
)
ax2.set_xlabel("#parameters (MB)", fontsize=LabelSize)
ax2.set_ylabel("train accuracy (%)", fontsize=LabelSize)
ax2.legend(loc=4, fontsize=LegendFontsize)
ax3.scatter(params, test_accs, marker="o", s=0.5, c="tab:blue")
ax3.scatter(
[params[x] for x in pyramid_indexes],
[test_accs[x] for x in pyramid_indexes],
marker="*",
s=xscale,
c="tab:orange",
label="Pyramid Structure",
alpha=xalpha,
)
ax3.scatter(
[params[x] for x in largest_indexes],
[test_accs[x] for x in largest_indexes],
marker="x",
s=xscale,
c="tab:green",
label="Largest Candidate",
alpha=xalpha,
)
ax3.set_xlabel("#parameters (MB)", fontsize=LabelSize)
ax3.set_ylabel("test accuracy (%)", fontsize=LabelSize)
ax3.legend(loc=4, fontsize=LegendFontsize)
ax4.scatter(flops, train_accs, marker="o", s=0.5, c="tab:blue")
ax4.scatter(
[flops[x] for x in pyramid_indexes],
[train_accs[x] for x in pyramid_indexes],
marker="*",
s=xscale,
c="tab:orange",
label="Pyramid Structure",
alpha=xalpha,
)
ax4.scatter(
[flops[x] for x in largest_indexes],
[train_accs[x] for x in largest_indexes],
marker="x",
s=xscale,
c="tab:green",
label="Largest Candidate",
alpha=xalpha,
)
ax4.set_xlabel("#FLOPs (M)", fontsize=LabelSize)
ax4.set_ylabel("train accuracy (%)", fontsize=LabelSize)
ax4.legend(loc=4, fontsize=LegendFontsize)
ax5.scatter(flops, test_accs, marker="o", s=0.5, c="tab:blue")
ax5.scatter(
[flops[x] for x in pyramid_indexes],
[test_accs[x] for x in pyramid_indexes],
marker="*",
s=xscale,
c="tab:orange",
label="Pyramid Structure",
alpha=xalpha,
)
ax5.scatter(
[flops[x] for x in largest_indexes],
[test_accs[x] for x in largest_indexes],
marker="x",
s=xscale,
c="tab:green",
label="Largest Candidate",
alpha=xalpha,
)
ax5.set_xlabel("#FLOPs (M)", fontsize=LabelSize)
ax5.set_ylabel("test accuracy (%)", fontsize=LabelSize)
ax5.legend(loc=4, fontsize=LegendFontsize)
save_path = vis_save_dir / "sss-{:}.png".format(dataset)
fig.savefig(save_path, dpi=dpi, bbox_inches="tight", format="png")
print("{:} save into {:}".format(time_string(), save_path))
plt.close("all")
def visualize_tss_info(api, dataset, vis_save_dir):
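    """Scatter accuracy against model cost for the topology search space (tss).

    Same layout as `visualize_sss_info`, but over cell topologies: cost comes
    from the 12-epoch entry, accuracies from the 200-epoch runs, and the
    highlighted points are the ResNet-like cell and the all-3x3-conv cell.
    """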
vis_save_dir = vis_save_dir.resolve()
print("{:} start to visualize {:} information".format(time_string(), dataset))
vis_save_dir.mkdir(parents=True, exist_ok=True)
cache_file_path = vis_save_dir / "{:}-cache-tss-info.pth".format(dataset)
if not cache_file_path.exists():
print("Do not find cache file : {:}".format(cache_file_path))
params, flops, train_accs, valid_accs, test_accs = [], [], [], [], []
for index in range(len(api)):
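            # Params/FLOPs do not depend on the training schedule, so the
            # cheaper 12-epoch entry is queried here; the accuracies below use
            # the full 200-epoch results.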
cost_info = api.get_cost_info(index, dataset, hp="12")
params.append(cost_info["params"])
flops.append(cost_info["flops"])
# accuracy
info = api.get_more_info(index, dataset, hp="200", is_random=False)
train_accs.append(info["train-accuracy"])
test_accs.append(info["test-accuracy"])
if dataset == "cifar10":
info = api.get_more_info(index, "cifar10-valid", hp="200", is_random=False)
valid_accs.append(info["valid-accuracy"])
else:
valid_accs.append(info["valid-accuracy"])
print("")
info = {
"params": params,
"flops": flops,
"train_accs": train_accs,
"valid_accs": valid_accs,
"test_accs": test_accs,
}
torch.save(info, cache_file_path)
else:
print("Find cache file : {:}".format(cache_file_path))
info = torch.load(cache_file_path)
params, flops, train_accs, valid_accs, test_accs = (
info["params"],
info["flops"],
info["train_accs"],
info["valid_accs"],
info["test_accs"],
)
print("{:} collect data done.".format(time_string()))
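    # NAS-Bench-201 cells are encoded as "|op~input|" tokens grouped per node;
    # the first string is the ResNet-like cell, the second the all-3x3-conv
    # (largest) cell.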
resnet = ["|nor_conv_3x3~0|+|none~0|nor_conv_3x3~1|+|skip_connect~0|none~1|skip_connect~2|"]
resnet_indexes = [api.query_index_by_arch(x) for x in resnet]
largest_indexes = [
api.query_index_by_arch(
"|nor_conv_3x3~0|+|nor_conv_3x3~0|nor_conv_3x3~1|+|nor_conv_3x3~0|nor_conv_3x3~1|nor_conv_3x3~2|"
)
]
indexes = list(range(len(params)))
dpi, width, height = 250, 8500, 1300
figsize = width / float(dpi), height / float(dpi)
LabelSize, LegendFontsize = 24, 24
# resnet_scale, resnet_alpha = 120, 0.5
xscale, xalpha = 120, 0.8
fig, axs = plt.subplots(1, 4, figsize=figsize)
# ax1, ax2, ax3, ax4, ax5 = axs
for ax in axs:
for tick in ax.xaxis.get_major_ticks():
tick.label.set_fontsize(LabelSize)
ax.yaxis.set_major_formatter(ticker.FormatStrFormatter("%.0f"))
for tick in ax.yaxis.get_major_ticks():
tick.label.set_fontsize(LabelSize)
ax2, ax3, ax4, ax5 = axs
# ax1.xaxis.set_ticks(np.arange(0, max(indexes), max(indexes)//5))
# ax1.scatter(indexes, test_accs, marker='o', s=0.5, c='tab:blue')
# ax1.set_xlabel('architecture ID', fontsize=LabelSize)
# ax1.set_ylabel('test accuracy (%)', fontsize=LabelSize)
ax2.scatter(params, train_accs, marker="o", s=0.5, c="tab:blue")
ax2.scatter(
[params[x] for x in resnet_indexes],
[train_accs[x] for x in resnet_indexes],
marker="*",
s=xscale,
c="tab:orange",
label="ResNet",
alpha=xalpha,
)
ax2.scatter(
[params[x] for x in largest_indexes],
[train_accs[x] for x in largest_indexes],
marker="x",
s=xscale,
c="tab:green",
label="Largest Candidate",
alpha=xalpha,
)
ax2.set_xlabel("#parameters (MB)", fontsize=LabelSize)
ax2.set_ylabel("train accuracy (%)", fontsize=LabelSize)
ax2.legend(loc=4, fontsize=LegendFontsize)
ax3.scatter(params, test_accs, marker="o", s=0.5, c="tab:blue")
ax3.scatter(
[params[x] for x in resnet_indexes],
[test_accs[x] for x in resnet_indexes],
marker="*",
s=xscale,
c="tab:orange",
label="ResNet",
alpha=xalpha,
)
ax3.scatter(
[params[x] for x in largest_indexes],
[test_accs[x] for x in largest_indexes],
marker="x",
s=xscale,
c="tab:green",
label="Largest Candidate",
alpha=xalpha,
)
ax3.set_xlabel("#parameters (MB)", fontsize=LabelSize)
ax3.set_ylabel("test accuracy (%)", fontsize=LabelSize)
ax3.legend(loc=4, fontsize=LegendFontsize)
ax4.scatter(flops, train_accs, marker="o", s=0.5, c="tab:blue")
ax4.scatter(
[flops[x] for x in resnet_indexes],
[train_accs[x] for x in resnet_indexes],
marker="*",
s=xscale,
c="tab:orange",
label="ResNet",
alpha=xalpha,
)
ax4.scatter(
[flops[x] for x in largest_indexes],
[train_accs[x] for x in largest_indexes],
marker="x",
s=xscale,
c="tab:green",
label="Largest Candidate",
alpha=xalpha,
)
ax4.set_xlabel("#FLOPs (M)", fontsize=LabelSize)
ax4.set_ylabel("train accuracy (%)", fontsize=LabelSize)
ax4.legend(loc=4, fontsize=LegendFontsize)
ax5.scatter(flops, test_accs, marker="o", s=0.5, c="tab:blue")
ax5.scatter(
[flops[x] for x in resnet_indexes],
[test_accs[x] for x in resnet_indexes],
marker="*",
s=xscale,
c="tab:orange",
label="ResNet",
alpha=xalpha,
)
ax5.scatter(
[flops[x] for x in largest_indexes],
[test_accs[x] for x in largest_indexes],
marker="x",
s=xscale,
c="tab:green",
label="Largest Candidate",
alpha=xalpha,
)
ax5.set_xlabel("#FLOPs (M)", fontsize=LabelSize)
ax5.set_ylabel("test accuracy (%)", fontsize=LabelSize)
ax5.legend(loc=4, fontsize=LegendFontsize)
save_path = vis_save_dir / "tss-{:}.png".format(dataset)
fig.savefig(save_path, dpi=dpi, bbox_inches="tight", format="png")
print("{:} save into {:}".format(time_string(), save_path))
plt.close("all")
def visualize_rank_info(api, vis_save_dir, indicator):
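    """Plot, per dataset, every architecture's validation-accuracy rank
    against its test-accuracy rank, using the caches written by the
    `visualize_{tss,sss}_info` functions above."""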
vis_save_dir = vis_save_dir.resolve()
# print ('{:} start to visualize {:} information'.format(time_string(), api))
vis_save_dir.mkdir(parents=True, exist_ok=True)
cifar010_cache_path = vis_save_dir / "{:}-cache-{:}-info.pth".format("cifar10", indicator)
cifar100_cache_path = vis_save_dir / "{:}-cache-{:}-info.pth".format("cifar100", indicator)
imagenet_cache_path = vis_save_dir / "{:}-cache-{:}-info.pth".format("ImageNet16-120", indicator)
cifar010_info = torch.load(cifar010_cache_path)
cifar100_info = torch.load(cifar100_cache_path)
imagenet_info = torch.load(imagenet_cache_path)
indexes = list(range(len(cifar010_info["params"])))
print("{:} start to visualize relative ranking".format(time_string()))
dpi, width, height = 250, 3800, 1200
figsize = width / float(dpi), height / float(dpi)
LabelSize, LegendFontsize = 14, 14
fig, axs = plt.subplots(1, 3, figsize=figsize)
ax1, ax2, ax3 = axs
def get_labels(info):
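        # For architectures sorted by test accuracy, return each one's rank in
        # the validation-accuracy ordering; perfect rank agreement would give
        # the identity sequence 0, 1, 2, ...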
ord_test_indexes = sorted(indexes, key=lambda i: info["test_accs"][i])
ord_valid_indexes = sorted(indexes, key=lambda i: info["valid_accs"][i])
labels = []
for idx in ord_test_indexes:
labels.append(ord_valid_indexes.index(idx))
return labels
def plot_ax(labels, ax, name):
for tick in ax.xaxis.get_major_ticks():
tick.label.set_fontsize(LabelSize)
for tick in ax.yaxis.get_major_ticks():
tick.label.set_fontsize(LabelSize)
tick.label.set_rotation(90)
ax.set_xlim(min(indexes), max(indexes))
ax.set_ylim(min(indexes), max(indexes))
ax.yaxis.set_ticks(np.arange(min(indexes), max(indexes), max(indexes) // 3))
ax.xaxis.set_ticks(np.arange(min(indexes), max(indexes), max(indexes) // 5))
ax.scatter(indexes, labels, marker="^", s=0.5, c="tab:green", alpha=0.8)
ax.scatter(indexes, indexes, marker="o", s=0.5, c="tab:blue", alpha=0.8)
ax.scatter([-1], [-1], marker="^", s=100, c="tab:green", label="{:} test".format(name))
ax.scatter([-1], [-1], marker="o", s=100, c="tab:blue", label="{:} validation".format(name))
ax.legend(loc=4, fontsize=LegendFontsize)
ax.set_xlabel("ranking on the {:} validation".format(name), fontsize=LabelSize)
ax.set_ylabel("architecture ranking", fontsize=LabelSize)
labels = get_labels(cifar010_info)
plot_ax(labels, ax1, "CIFAR-10")
labels = get_labels(cifar100_info)
plot_ax(labels, ax2, "CIFAR-100")
labels = get_labels(imagenet_info)
plot_ax(labels, ax3, "ImageNet-16-120")
save_path = (vis_save_dir / "{:}-same-relative-rank.pdf".format(indicator)).resolve()
fig.savefig(save_path, dpi=dpi, bbox_inches="tight", format="pdf")
save_path = (vis_save_dir / "{:}-same-relative-rank.png".format(indicator)).resolve()
fig.savefig(save_path, dpi=dpi, bbox_inches="tight", format="png")
print("{:} save into {:}".format(time_string(), save_path))
plt.close("all")
def calculate_correlation(*vectors):
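    """Return the n-by-n matrix of Pearson correlation coefficients between
    all pairs of the given equal-length vectors (via np.corrcoef)."""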
matrix = []
for i, vectori in enumerate(vectors):
x = []
for j, vectorj in enumerate(vectors):
x.append(np.corrcoef(vectori, vectorj)[0, 1])
matrix.append(x)
    return np.array(matrix)


def visualize_all_rank_info(api, vis_save_dir, indicator):
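    """Draw heatmaps of the correlation between validation/test accuracies
    across CIFAR-10, CIFAR-100, and ImageNet16-120: once over all candidates
    and once over candidates with CIFAR-10 test accuracy above 92%."""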
vis_save_dir = vis_save_dir.resolve()
# print ('{:} start to visualize {:} information'.format(time_string(), api))
vis_save_dir.mkdir(parents=True, exist_ok=True)
cifar010_cache_path = vis_save_dir / "{:}-cache-{:}-info.pth".format("cifar10", indicator)
cifar100_cache_path = vis_save_dir / "{:}-cache-{:}-info.pth".format("cifar100", indicator)
imagenet_cache_path = vis_save_dir / "{:}-cache-{:}-info.pth".format("ImageNet16-120", indicator)
cifar010_info = torch.load(cifar010_cache_path)
cifar100_info = torch.load(cifar100_cache_path)
imagenet_info = torch.load(imagenet_cache_path)
indexes = list(range(len(cifar010_info["params"])))
print("{:} start to visualize relative ranking".format(time_string()))
dpi, width, height = 250, 3200, 1400
figsize = width / float(dpi), height / float(dpi)
LabelSize, LegendFontsize = 14, 14
fig, axs = plt.subplots(1, 2, figsize=figsize)
ax1, ax2 = axs
sns_size = 15
CoRelMatrix = calculate_correlation(
cifar010_info["valid_accs"],
cifar010_info["test_accs"],
cifar100_info["valid_accs"],
cifar100_info["test_accs"],
imagenet_info["valid_accs"],
imagenet_info["test_accs"],
)
sns.heatmap(
CoRelMatrix,
annot=True,
annot_kws={"size": sns_size},
fmt=".3f",
linewidths=0.5,
ax=ax1,
xticklabels=["C10-V", "C10-T", "C100-V", "C100-T", "I120-V", "I120-T"],
yticklabels=["C10-V", "C10-T", "C100-V", "C100-T", "I120-V", "I120-T"],
)
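    # The second panel restricts the same correlation analysis to
    # architectures whose CIFAR-10 test accuracy exceeds acc_bar (92%).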
selected_indexes, acc_bar = [], 92
for i, acc in enumerate(cifar010_info["test_accs"]):
if acc > acc_bar:
selected_indexes.append(i)
cifar010_valid_accs = np.array(cifar010_info["valid_accs"])[selected_indexes]
cifar010_test_accs = np.array(cifar010_info["test_accs"])[selected_indexes]
cifar100_valid_accs = np.array(cifar100_info["valid_accs"])[selected_indexes]
cifar100_test_accs = np.array(cifar100_info["test_accs"])[selected_indexes]
imagenet_valid_accs = np.array(imagenet_info["valid_accs"])[selected_indexes]
imagenet_test_accs = np.array(imagenet_info["test_accs"])[selected_indexes]
CoRelMatrix = calculate_correlation(
cifar010_valid_accs,
cifar010_test_accs,
cifar100_valid_accs,
cifar100_test_accs,
imagenet_valid_accs,
imagenet_test_accs,
)
sns.heatmap(
CoRelMatrix,
annot=True,
annot_kws={"size": sns_size},
fmt=".3f",
linewidths=0.5,
ax=ax2,
xticklabels=["C10-V", "C10-T", "C100-V", "C100-T", "I120-V", "I120-T"],
yticklabels=["C10-V", "C10-T", "C100-V", "C100-T", "I120-V", "I120-T"],
)
ax1.set_title("Correlation coefficient over ALL candidates")
ax2.set_title("Correlation coefficient over candidates with accuracy > {:}%".format(acc_bar))
save_path = (vis_save_dir / "{:}-all-relative-rank.png".format(indicator)).resolve()
fig.savefig(save_path, dpi=dpi, bbox_inches="tight", format="png")
print("{:} save into {:}".format(time_string(), save_path))
plt.close("all")
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="NAS-Bench-X", formatter_class=argparse.ArgumentDefaultsHelpFormatter)
parser.add_argument(
"--save_dir", type=str, default="output/vis-nas-bench", help="Folder to save checkpoints and log."
)
    # used for training the model
args = parser.parse_args()
to_save_dir = Path(args.save_dir)
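    # `create` builds the benchmark API ("tss" = topology search space,
    # "size" = size search space); it is presumably imported from the
    # NATS-Bench package near the top of this file.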
datasets = ["cifar10", "cifar100", "ImageNet16-120"]
api201 = create(None, "tss", verbose=True)
for xdata in datasets:
visualize_tss_info(api201, xdata, to_save_dir)
api_sss = create(None, "size", verbose=True)
for xdata in datasets:
visualize_sss_info(api_sss, xdata, to_save_dir)
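    # The rank/correlation plots below are driven by the cached statistics
    # written above, which is presumably why api can be None here.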
visualize_info(None, to_save_dir, "tss")
visualize_info(None, to_save_dir, "sss")
visualize_rank_info(None, to_save_dir, "tss")
visualize_rank_info(None, to_save_dir, "sss")
visualize_all_rank_info(None, to_save_dir, "tss")
visualize_all_rank_info(None, to_save_dir, "sss")