add naswot
This commit is contained in:
132
graph_dit/naswot/pycls/core/plotting.py
Normal file
132
graph_dit/naswot/pycls/core/plotting.py
Normal file
@@ -0,0 +1,132 @@
|
||||
#!/usr/bin/env python3
|
||||
|
||||
# Copyright (c) Facebook, Inc. and its affiliates.
|
||||
#
|
||||
# This source code is licensed under the MIT license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
"""Plotting functions."""
|
||||
|
||||
import colorlover as cl
|
||||
import matplotlib.pyplot as plt
|
||||
import plotly.graph_objs as go
|
||||
import plotly.offline as offline
|
||||
import pycls.core.logging as logging
|
||||
|
||||
|
||||
def get_plot_colors(max_colors, color_format="pyplot"):
|
||||
"""Generate colors for plotting."""
|
||||
colors = cl.scales["11"]["qual"]["Paired"]
|
||||
if max_colors > len(colors):
|
||||
colors = cl.to_rgb(cl.interp(colors, max_colors))
|
||||
if color_format == "pyplot":
|
||||
return [[j / 255.0 for j in c] for c in cl.to_numeric(colors)]
|
||||
return colors
|
||||
|
||||
|
||||
def prepare_plot_data(log_files, names, metric="top1_err"):
|
||||
"""Load logs and extract data for plotting error curves."""
|
||||
plot_data = []
|
||||
for file, name in zip(log_files, names):
|
||||
d, data = {}, logging.sort_log_data(logging.load_log_data(file))
|
||||
for phase in ["train", "test"]:
|
||||
x = data[phase + "_epoch"]["epoch_ind"]
|
||||
y = data[phase + "_epoch"][metric]
|
||||
d["x_" + phase], d["y_" + phase] = x, y
|
||||
d[phase + "_label"] = "[{:5.2f}] ".format(min(y) if y else 0) + name
|
||||
plot_data.append(d)
|
||||
assert len(plot_data) > 0, "No data to plot"
|
||||
return plot_data
|
||||
|
||||
|
||||
def plot_error_curves_plotly(log_files, names, filename, metric="top1_err"):
|
||||
"""Plot error curves using plotly and save to file."""
|
||||
plot_data = prepare_plot_data(log_files, names, metric)
|
||||
colors = get_plot_colors(len(plot_data), "plotly")
|
||||
# Prepare data for plots (3 sets, train duplicated w and w/o legend)
|
||||
data = []
|
||||
for i, d in enumerate(plot_data):
|
||||
s = str(i)
|
||||
line_train = {"color": colors[i], "dash": "dashdot", "width": 1.5}
|
||||
line_test = {"color": colors[i], "dash": "solid", "width": 1.5}
|
||||
data.append(
|
||||
go.Scatter(
|
||||
x=d["x_train"],
|
||||
y=d["y_train"],
|
||||
mode="lines",
|
||||
name=d["train_label"],
|
||||
line=line_train,
|
||||
legendgroup=s,
|
||||
visible=True,
|
||||
showlegend=False,
|
||||
)
|
||||
)
|
||||
data.append(
|
||||
go.Scatter(
|
||||
x=d["x_test"],
|
||||
y=d["y_test"],
|
||||
mode="lines",
|
||||
name=d["test_label"],
|
||||
line=line_test,
|
||||
legendgroup=s,
|
||||
visible=True,
|
||||
showlegend=True,
|
||||
)
|
||||
)
|
||||
data.append(
|
||||
go.Scatter(
|
||||
x=d["x_train"],
|
||||
y=d["y_train"],
|
||||
mode="lines",
|
||||
name=d["train_label"],
|
||||
line=line_train,
|
||||
legendgroup=s,
|
||||
visible=False,
|
||||
showlegend=True,
|
||||
)
|
||||
)
|
||||
# Prepare layout w ability to toggle 'all', 'train', 'test'
|
||||
titlefont = {"size": 18, "color": "#7f7f7f"}
|
||||
vis = [[True, True, False], [False, False, True], [False, True, False]]
|
||||
buttons = zip(["all", "train", "test"], [[{"visible": v}] for v in vis])
|
||||
buttons = [{"label": b, "args": v, "method": "update"} for b, v in buttons]
|
||||
layout = go.Layout(
|
||||
title=metric + " vs. epoch<br>[dash=train, solid=test]",
|
||||
xaxis={"title": "epoch", "titlefont": titlefont},
|
||||
yaxis={"title": metric, "titlefont": titlefont},
|
||||
showlegend=True,
|
||||
hoverlabel={"namelength": -1},
|
||||
updatemenus=[
|
||||
{
|
||||
"buttons": buttons,
|
||||
"direction": "down",
|
||||
"showactive": True,
|
||||
"x": 1.02,
|
||||
"xanchor": "left",
|
||||
"y": 1.08,
|
||||
"yanchor": "top",
|
||||
}
|
||||
],
|
||||
)
|
||||
# Create plotly plot
|
||||
offline.plot({"data": data, "layout": layout}, filename=filename)
|
||||
|
||||
|
||||
def plot_error_curves_pyplot(log_files, names, filename=None, metric="top1_err"):
|
||||
"""Plot error curves using matplotlib.pyplot and save to file."""
|
||||
plot_data = prepare_plot_data(log_files, names, metric)
|
||||
colors = get_plot_colors(len(names))
|
||||
for ind, d in enumerate(plot_data):
|
||||
c, lbl = colors[ind], d["test_label"]
|
||||
plt.plot(d["x_train"], d["y_train"], "--", c=c, alpha=0.8)
|
||||
plt.plot(d["x_test"], d["y_test"], "-", c=c, alpha=0.8, label=lbl)
|
||||
plt.title(metric + " vs. epoch\n[dash=train, solid=test]", fontsize=14)
|
||||
plt.xlabel("epoch", fontsize=14)
|
||||
plt.ylabel(metric, fontsize=14)
|
||||
plt.grid(alpha=0.4)
|
||||
plt.legend()
|
||||
if filename:
|
||||
plt.savefig(filename)
|
||||
plt.clf()
|
||||
else:
|
||||
plt.show()
|
Reference in New Issue
Block a user