Update GeMOSA v4

This commit is contained in:
D-X-Y
2021-05-27 19:27:29 +08:00
parent 16861f0f3d
commit 08337138f1
3 changed files with 130 additions and 39 deletions

View File

@@ -33,7 +33,9 @@ from xautodl.procedures.metric_utils import MSEMetric
def plot_scatter(cur_ax, xs, ys, color, alpha, linewidths, label=None):
cur_ax.scatter([-100], [-100], color=color, linewidths=linewidths[0], label=label)
cur_ax.scatter(xs, ys, color=color, alpha=alpha, linewidths=linewidths[1], label=None)
cur_ax.scatter(
xs, ys, color=color, alpha=alpha, linewidths=linewidths[1], label=None
)
def draw_multi_fig(save_dir, timestamp, scatter_list, wh, fig_title=None):
@@ -193,16 +195,28 @@ def visualize_env(save_dir, version):
for idx, (timestamp, (allx, ally)) in enumerate(tqdm(dynamic_env, ncols=50)):
allxs.append(allx)
allys.append(ally)
if dynamic_env.meta_info['task'] == 'regression':
if dynamic_env.meta_info["task"] == "regression":
allxs, allys = torch.cat(allxs).view(-1), torch.cat(allys).view(-1)
print("x - min={:.3f}, max={:.3f}".format(allxs.min().item(), allxs.max().item()))
print("y - min={:.3f}, max={:.3f}".format(allys.min().item(), allys.max().item()))
elif dynamic_env.meta_info['task'] == 'classification':
print(
"x - min={:.3f}, max={:.3f}".format(allxs.min().item(), allxs.max().item())
)
print(
"y - min={:.3f}, max={:.3f}".format(allys.min().item(), allys.max().item())
)
elif dynamic_env.meta_info["task"] == "classification":
allxs = torch.cat(allxs)
print("x[0] - min={:.3f}, max={:.3f}".format(allxs[:,0].min().item(), allxs[:,0].max().item()))
print("x[1] - min={:.3f}, max={:.3f}".format(allxs[:,1].min().item(), allxs[:,1].max().item()))
print(
"x[0] - min={:.3f}, max={:.3f}".format(
allxs[:, 0].min().item(), allxs[:, 0].max().item()
)
)
print(
"x[1] - min={:.3f}, max={:.3f}".format(
allxs[:, 1].min().item(), allxs[:, 1].max().item()
)
)
else:
raise ValueError("Unknown task".format(dynamic_env.meta_info['task']))
raise ValueError("Unknown task".format(dynamic_env.meta_info["task"]))
for idx, (timestamp, (allx, ally)) in enumerate(tqdm(dynamic_env, ncols=50)):
dpi, width, height = 30, 1800, 1400
@@ -211,29 +225,51 @@ def visualize_env(save_dir, version):
fig = plt.figure(figsize=figsize)
cur_ax = fig.add_subplot(1, 1, 1)
if dynamic_env.meta_info['task'] == 'regression':
if dynamic_env.meta_info["task"] == "regression":
allx, ally = allx[:, 0].numpy(), ally[:, 0].numpy()
plot_scatter(cur_ax, allx, ally, "k", 0.99, (15, 1.5), "timestamp={:05d}".format(idx))
plot_scatter(
cur_ax, allx, ally, "k", 0.99, (15, 1.5), "timestamp={:05d}".format(idx)
)
cur_ax.set_xlim(round(allxs.min().item(), 1), round(allxs.max().item(), 1))
cur_ax.set_ylim(round(allys.min().item(), 1), round(allys.max().item(), 1))
elif dynamic_env.meta_info['task'] == 'classification':
elif dynamic_env.meta_info["task"] == "classification":
positive, negative = ally == 1, ally == 0
# plot_scatter(cur_ax, [1], [1], "k", 0.1, 1, "timestamp={:05d}".format(idx))
plot_scatter(cur_ax, allx[positive,0], allx[positive,1], "r", 0.99, (20, 10), "positive")
plot_scatter(cur_ax, allx[negative,0], allx[negative,1], "g", 0.99, (20, 10), "negative")
cur_ax.set_xlim(round(allxs[:,0].min().item(), 1), round(allxs[:,0].max().item(), 1))
cur_ax.set_ylim(round(allxs[:,1].min().item(), 1), round(allxs[:,1].max().item(), 1))
plot_scatter(
cur_ax,
allx[positive, 0],
allx[positive, 1],
"r",
0.99,
(20, 10),
"positive",
)
plot_scatter(
cur_ax,
allx[negative, 0],
allx[negative, 1],
"g",
0.99,
(20, 10),
"negative",
)
cur_ax.set_xlim(
round(allxs[:, 0].min().item(), 1), round(allxs[:, 0].max().item(), 1)
)
cur_ax.set_ylim(
round(allxs[:, 1].min().item(), 1), round(allxs[:, 1].max().item(), 1)
)
else:
raise ValueError("Unknown task".format(dynamic_env.meta_info['task']))
raise ValueError("Unknown task".format(dynamic_env.meta_info["task"]))
cur_ax.set_xlabel("X", fontsize=LabelSize)
cur_ax.set_ylabel("Y", rotation=0, fontsize=LabelSize)
for tick in cur_ax.xaxis.get_major_ticks():
tick.label.set_fontsize(LabelSize - font_gap)
tick.label.set_rotation(10)
tick.label.set_fontsize(LabelSize - font_gap)
tick.label.set_rotation(10)
for tick in cur_ax.yaxis.get_major_ticks():
tick.label.set_fontsize(LabelSize - font_gap)
cur_ax.legend(loc=1, fontsize=LegendFontsize)
tick.label.set_fontsize(LabelSize - font_gap)
cur_ax.legend(loc=1, fontsize=LegendFontsize)
pdf_save_path = (
save_dir
/ "pdf-{:}".format(version)