Update GeMOSA v4
This commit is contained in:
@@ -2,9 +2,9 @@
|
||||
# Learning to Generate Model One Step Ahead #
|
||||
#####################################################
|
||||
# python exps/GeMOSA/main.py --env_version v1 --workers 0
|
||||
# python exps/GeMOSA/main.py --env_version v1 --device cuda --lr 0.002 --hidden_dim 16 --meta_batch 256
|
||||
# python exps/GeMOSA/main.py --env_version v2 --device cuda --lr 0.002 --hidden_dim 16 --meta_batch 256
|
||||
# python exps/GeMOSA/main.py --env_version v3 --device cuda --lr 0.002 --hidden_dim 32 --meta_batch 256
|
||||
# python exps/GeMOSA/main.py --env_version v1 --lr 0.002 --hidden_dim 16 --meta_batch 256 --device cuda
|
||||
# python exps/GeMOSA/main.py --env_version v2 --lr 0.002 --hidden_dim 16 --meta_batch 256 --device cuda
|
||||
# python exps/GeMOSA/main.py --env_version v3 --lr 0.002 --hidden_dim 32 --time_dim 32 --meta_batch 256 --device cuda
|
||||
#####################################################
|
||||
import sys, time, copy, torch, random, argparse
|
||||
from tqdm import tqdm
|
||||
|
@@ -3,7 +3,8 @@
|
||||
############################################################################
|
||||
# python exps/GeMOSA/vis-synthetic.py --env_version v1 #
|
||||
# python exps/GeMOSA/vis-synthetic.py --env_version v2 #
|
||||
# python exps/GeMOSA/vis-synthetic.py --env_version v2 #
|
||||
# python exps/GeMOSA/vis-synthetic.py --env_version v3 #
|
||||
# python exps/GeMOSA/vis-synthetic.py --env_version v4 #
|
||||
############################################################################
|
||||
import os, sys, copy, random
|
||||
import torch
|
||||
@@ -31,8 +32,8 @@ from xautodl.procedures.metric_utils import MSEMetric
|
||||
|
||||
|
||||
def plot_scatter(cur_ax, xs, ys, color, alpha, linewidths, label=None):
|
||||
cur_ax.scatter([-100], [-100], color=color, linewidths=linewidths, label=label)
|
||||
cur_ax.scatter(xs, ys, color=color, alpha=alpha, linewidths=1.5, label=None)
|
||||
cur_ax.scatter([-100], [-100], color=color, linewidths=linewidths[0], label=label)
|
||||
cur_ax.scatter(xs, ys, color=color, alpha=alpha, linewidths=linewidths[1], label=None)
|
||||
|
||||
|
||||
def draw_multi_fig(save_dir, timestamp, scatter_list, wh, fig_title=None):
|
||||
@@ -186,15 +187,23 @@ def visualize_env(save_dir, version):
|
||||
sub_save_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
dynamic_env = get_synthetic_env(version=version)
|
||||
print("env: {:}".format(dynamic_env))
|
||||
print("oracle_map: {:}".format(dynamic_env.oracle_map))
|
||||
allxs, allys = [], []
|
||||
for idx, (timestamp, (allx, ally)) in enumerate(tqdm(dynamic_env, ncols=50)):
|
||||
allxs.append(allx)
|
||||
allys.append(ally)
|
||||
allxs, allys = torch.cat(allxs).view(-1), torch.cat(allys).view(-1)
|
||||
print("env: {:}".format(dynamic_env))
|
||||
print("oracle_map: {:}".format(dynamic_env.oracle_map))
|
||||
print("x - min={:.3f}, max={:.3f}".format(allxs.min().item(), allxs.max().item()))
|
||||
print("y - min={:.3f}, max={:.3f}".format(allys.min().item(), allys.max().item()))
|
||||
if dynamic_env.meta_info['task'] == 'regression':
|
||||
allxs, allys = torch.cat(allxs).view(-1), torch.cat(allys).view(-1)
|
||||
print("x - min={:.3f}, max={:.3f}".format(allxs.min().item(), allxs.max().item()))
|
||||
print("y - min={:.3f}, max={:.3f}".format(allys.min().item(), allys.max().item()))
|
||||
elif dynamic_env.meta_info['task'] == 'classification':
|
||||
allxs = torch.cat(allxs)
|
||||
print("x[0] - min={:.3f}, max={:.3f}".format(allxs[:,0].min().item(), allxs[:,0].max().item()))
|
||||
print("x[1] - min={:.3f}, max={:.3f}".format(allxs[:,1].min().item(), allxs[:,1].max().item()))
|
||||
else:
|
||||
raise ValueError("Unknown task".format(dynamic_env.meta_info['task']))
|
||||
|
||||
for idx, (timestamp, (allx, ally)) in enumerate(tqdm(dynamic_env, ncols=50)):
|
||||
dpi, width, height = 30, 1800, 1400
|
||||
figsize = width / float(dpi), height / float(dpi)
|
||||
@@ -202,19 +211,29 @@ def visualize_env(save_dir, version):
|
||||
fig = plt.figure(figsize=figsize)
|
||||
|
||||
cur_ax = fig.add_subplot(1, 1, 1)
|
||||
allx, ally = allx[:, 0].numpy(), ally[:, 0].numpy()
|
||||
plot_scatter(cur_ax, allx, ally, "k", 0.99, 15, "timestamp={:05d}".format(idx))
|
||||
if dynamic_env.meta_info['task'] == 'regression':
|
||||
allx, ally = allx[:, 0].numpy(), ally[:, 0].numpy()
|
||||
plot_scatter(cur_ax, allx, ally, "k", 0.99, (15, 1.5), "timestamp={:05d}".format(idx))
|
||||
cur_ax.set_xlim(round(allxs.min().item(), 1), round(allxs.max().item(), 1))
|
||||
cur_ax.set_ylim(round(allys.min().item(), 1), round(allys.max().item(), 1))
|
||||
elif dynamic_env.meta_info['task'] == 'classification':
|
||||
positive, negative = ally == 1, ally == 0
|
||||
# plot_scatter(cur_ax, [1], [1], "k", 0.1, 1, "timestamp={:05d}".format(idx))
|
||||
plot_scatter(cur_ax, allx[positive,0], allx[positive,1], "r", 0.99, (20, 10), "positive")
|
||||
plot_scatter(cur_ax, allx[negative,0], allx[negative,1], "g", 0.99, (20, 10), "negative")
|
||||
cur_ax.set_xlim(round(allxs[:,0].min().item(), 1), round(allxs[:,0].max().item(), 1))
|
||||
cur_ax.set_ylim(round(allxs[:,1].min().item(), 1), round(allxs[:,1].max().item(), 1))
|
||||
else:
|
||||
raise ValueError("Unknown task".format(dynamic_env.meta_info['task']))
|
||||
|
||||
cur_ax.set_xlabel("X", fontsize=LabelSize)
|
||||
cur_ax.set_ylabel("Y", rotation=0, fontsize=LabelSize)
|
||||
for tick in cur_ax.xaxis.get_major_ticks():
|
||||
tick.label.set_fontsize(LabelSize - font_gap)
|
||||
tick.label.set_rotation(10)
|
||||
tick.label.set_fontsize(LabelSize - font_gap)
|
||||
tick.label.set_rotation(10)
|
||||
for tick in cur_ax.yaxis.get_major_ticks():
|
||||
tick.label.set_fontsize(LabelSize - font_gap)
|
||||
cur_ax.set_xlim(round(allxs.min().item(), 1), round(allxs.max().item(), 1))
|
||||
cur_ax.set_ylim(round(allys.min().item(), 1), round(allys.max().item(), 1))
|
||||
cur_ax.legend(loc=1, fontsize=LegendFontsize)
|
||||
|
||||
tick.label.set_fontsize(LabelSize - font_gap)
|
||||
cur_ax.legend(loc=1, fontsize=LegendFontsize)
|
||||
pdf_save_path = (
|
||||
save_dir
|
||||
/ "pdf-{:}".format(version)
|
||||
@@ -237,7 +256,7 @@ def visualize_env(save_dir, version):
|
||||
os.system("{:} {xdir}/env-{ver}.webm".format(base_cmd, xdir=save_dir, ver=version))
|
||||
|
||||
|
||||
def compare_algs(save_dir, version, alg_dir="./outputs/lfna-synthetic"):
|
||||
def compare_algs(save_dir, version, alg_dir="./outputs/GeMOSA-synthetic"):
|
||||
save_dir = Path(str(save_dir))
|
||||
for substr in ("pdf", "png"):
|
||||
sub_save_dir = save_dir / substr
|
||||
|
Reference in New Issue
Block a user