Update GeMOSA v4
This commit is contained in:
@@ -33,7 +33,9 @@ from xautodl.procedures.metric_utils import MSEMetric
|
||||
|
||||
def plot_scatter(cur_ax, xs, ys, color, alpha, linewidths, label=None):
|
||||
cur_ax.scatter([-100], [-100], color=color, linewidths=linewidths[0], label=label)
|
||||
cur_ax.scatter(xs, ys, color=color, alpha=alpha, linewidths=linewidths[1], label=None)
|
||||
cur_ax.scatter(
|
||||
xs, ys, color=color, alpha=alpha, linewidths=linewidths[1], label=None
|
||||
)
|
||||
|
||||
|
||||
def draw_multi_fig(save_dir, timestamp, scatter_list, wh, fig_title=None):
|
||||
@@ -193,16 +195,28 @@ def visualize_env(save_dir, version):
|
||||
for idx, (timestamp, (allx, ally)) in enumerate(tqdm(dynamic_env, ncols=50)):
|
||||
allxs.append(allx)
|
||||
allys.append(ally)
|
||||
if dynamic_env.meta_info['task'] == 'regression':
|
||||
if dynamic_env.meta_info["task"] == "regression":
|
||||
allxs, allys = torch.cat(allxs).view(-1), torch.cat(allys).view(-1)
|
||||
print("x - min={:.3f}, max={:.3f}".format(allxs.min().item(), allxs.max().item()))
|
||||
print("y - min={:.3f}, max={:.3f}".format(allys.min().item(), allys.max().item()))
|
||||
elif dynamic_env.meta_info['task'] == 'classification':
|
||||
print(
|
||||
"x - min={:.3f}, max={:.3f}".format(allxs.min().item(), allxs.max().item())
|
||||
)
|
||||
print(
|
||||
"y - min={:.3f}, max={:.3f}".format(allys.min().item(), allys.max().item())
|
||||
)
|
||||
elif dynamic_env.meta_info["task"] == "classification":
|
||||
allxs = torch.cat(allxs)
|
||||
print("x[0] - min={:.3f}, max={:.3f}".format(allxs[:,0].min().item(), allxs[:,0].max().item()))
|
||||
print("x[1] - min={:.3f}, max={:.3f}".format(allxs[:,1].min().item(), allxs[:,1].max().item()))
|
||||
print(
|
||||
"x[0] - min={:.3f}, max={:.3f}".format(
|
||||
allxs[:, 0].min().item(), allxs[:, 0].max().item()
|
||||
)
|
||||
)
|
||||
print(
|
||||
"x[1] - min={:.3f}, max={:.3f}".format(
|
||||
allxs[:, 1].min().item(), allxs[:, 1].max().item()
|
||||
)
|
||||
)
|
||||
else:
|
||||
raise ValueError("Unknown task".format(dynamic_env.meta_info['task']))
|
||||
raise ValueError("Unknown task".format(dynamic_env.meta_info["task"]))
|
||||
|
||||
for idx, (timestamp, (allx, ally)) in enumerate(tqdm(dynamic_env, ncols=50)):
|
||||
dpi, width, height = 30, 1800, 1400
|
||||
@@ -211,29 +225,51 @@ def visualize_env(save_dir, version):
|
||||
fig = plt.figure(figsize=figsize)
|
||||
|
||||
cur_ax = fig.add_subplot(1, 1, 1)
|
||||
if dynamic_env.meta_info['task'] == 'regression':
|
||||
if dynamic_env.meta_info["task"] == "regression":
|
||||
allx, ally = allx[:, 0].numpy(), ally[:, 0].numpy()
|
||||
plot_scatter(cur_ax, allx, ally, "k", 0.99, (15, 1.5), "timestamp={:05d}".format(idx))
|
||||
plot_scatter(
|
||||
cur_ax, allx, ally, "k", 0.99, (15, 1.5), "timestamp={:05d}".format(idx)
|
||||
)
|
||||
cur_ax.set_xlim(round(allxs.min().item(), 1), round(allxs.max().item(), 1))
|
||||
cur_ax.set_ylim(round(allys.min().item(), 1), round(allys.max().item(), 1))
|
||||
elif dynamic_env.meta_info['task'] == 'classification':
|
||||
elif dynamic_env.meta_info["task"] == "classification":
|
||||
positive, negative = ally == 1, ally == 0
|
||||
# plot_scatter(cur_ax, [1], [1], "k", 0.1, 1, "timestamp={:05d}".format(idx))
|
||||
plot_scatter(cur_ax, allx[positive,0], allx[positive,1], "r", 0.99, (20, 10), "positive")
|
||||
plot_scatter(cur_ax, allx[negative,0], allx[negative,1], "g", 0.99, (20, 10), "negative")
|
||||
cur_ax.set_xlim(round(allxs[:,0].min().item(), 1), round(allxs[:,0].max().item(), 1))
|
||||
cur_ax.set_ylim(round(allxs[:,1].min().item(), 1), round(allxs[:,1].max().item(), 1))
|
||||
plot_scatter(
|
||||
cur_ax,
|
||||
allx[positive, 0],
|
||||
allx[positive, 1],
|
||||
"r",
|
||||
0.99,
|
||||
(20, 10),
|
||||
"positive",
|
||||
)
|
||||
plot_scatter(
|
||||
cur_ax,
|
||||
allx[negative, 0],
|
||||
allx[negative, 1],
|
||||
"g",
|
||||
0.99,
|
||||
(20, 10),
|
||||
"negative",
|
||||
)
|
||||
cur_ax.set_xlim(
|
||||
round(allxs[:, 0].min().item(), 1), round(allxs[:, 0].max().item(), 1)
|
||||
)
|
||||
cur_ax.set_ylim(
|
||||
round(allxs[:, 1].min().item(), 1), round(allxs[:, 1].max().item(), 1)
|
||||
)
|
||||
else:
|
||||
raise ValueError("Unknown task".format(dynamic_env.meta_info['task']))
|
||||
raise ValueError("Unknown task".format(dynamic_env.meta_info["task"]))
|
||||
|
||||
cur_ax.set_xlabel("X", fontsize=LabelSize)
|
||||
cur_ax.set_ylabel("Y", rotation=0, fontsize=LabelSize)
|
||||
for tick in cur_ax.xaxis.get_major_ticks():
|
||||
tick.label.set_fontsize(LabelSize - font_gap)
|
||||
tick.label.set_rotation(10)
|
||||
tick.label.set_fontsize(LabelSize - font_gap)
|
||||
tick.label.set_rotation(10)
|
||||
for tick in cur_ax.yaxis.get_major_ticks():
|
||||
tick.label.set_fontsize(LabelSize - font_gap)
|
||||
cur_ax.legend(loc=1, fontsize=LegendFontsize)
|
||||
tick.label.set_fontsize(LabelSize - font_gap)
|
||||
cur_ax.legend(loc=1, fontsize=LegendFontsize)
|
||||
pdf_save_path = (
|
||||
save_dir
|
||||
/ "pdf-{:}".format(version)
|
||||
|
Reference in New Issue
Block a user