Update GeMOSA v4

This commit is contained in:
D-X-Y
2021-05-27 17:30:44 +08:00
parent 1ce0b80776
commit b6e11c6360
8 changed files with 147 additions and 39 deletions

View File

@@ -2,9 +2,9 @@
# Learning to Generate Model One Step Ahead #
#####################################################
# python exps/GeMOSA/main.py --env_version v1 --workers 0
# python exps/GeMOSA/main.py --env_version v1 --device cuda --lr 0.002 --hidden_dim 16 --meta_batch 256
# python exps/GeMOSA/main.py --env_version v2 --device cuda --lr 0.002 --hidden_dim 16 --meta_batch 256
# python exps/GeMOSA/main.py --env_version v3 --device cuda --lr 0.002 --hidden_dim 32 --meta_batch 256
# python exps/GeMOSA/main.py --env_version v1 --lr 0.002 --hidden_dim 16 --meta_batch 256 --device cuda
# python exps/GeMOSA/main.py --env_version v2 --lr 0.002 --hidden_dim 16 --meta_batch 256 --device cuda
# python exps/GeMOSA/main.py --env_version v3 --lr 0.002 --hidden_dim 32 --time_dim 32 --meta_batch 256 --device cuda
#####################################################
import sys, time, copy, torch, random, argparse
from tqdm import tqdm

View File

@@ -3,7 +3,8 @@
############################################################################
# python exps/GeMOSA/vis-synthetic.py --env_version v1 #
# python exps/GeMOSA/vis-synthetic.py --env_version v2 #
# python exps/GeMOSA/vis-synthetic.py --env_version v2 #
# python exps/GeMOSA/vis-synthetic.py --env_version v3 #
# python exps/GeMOSA/vis-synthetic.py --env_version v4 #
############################################################################
import os, sys, copy, random
import torch
@@ -31,8 +32,8 @@ from xautodl.procedures.metric_utils import MSEMetric
def plot_scatter(cur_ax, xs, ys, color, alpha, linewidths, label=None):
cur_ax.scatter([-100], [-100], color=color, linewidths=linewidths, label=label)
cur_ax.scatter(xs, ys, color=color, alpha=alpha, linewidths=1.5, label=None)
cur_ax.scatter([-100], [-100], color=color, linewidths=linewidths[0], label=label)
cur_ax.scatter(xs, ys, color=color, alpha=alpha, linewidths=linewidths[1], label=None)
def draw_multi_fig(save_dir, timestamp, scatter_list, wh, fig_title=None):
@@ -186,15 +187,23 @@ def visualize_env(save_dir, version):
sub_save_dir.mkdir(parents=True, exist_ok=True)
dynamic_env = get_synthetic_env(version=version)
print("env: {:}".format(dynamic_env))
print("oracle_map: {:}".format(dynamic_env.oracle_map))
allxs, allys = [], []
for idx, (timestamp, (allx, ally)) in enumerate(tqdm(dynamic_env, ncols=50)):
allxs.append(allx)
allys.append(ally)
allxs, allys = torch.cat(allxs).view(-1), torch.cat(allys).view(-1)
print("env: {:}".format(dynamic_env))
print("oracle_map: {:}".format(dynamic_env.oracle_map))
print("x - min={:.3f}, max={:.3f}".format(allxs.min().item(), allxs.max().item()))
print("y - min={:.3f}, max={:.3f}".format(allys.min().item(), allys.max().item()))
if dynamic_env.meta_info['task'] == 'regression':
allxs, allys = torch.cat(allxs).view(-1), torch.cat(allys).view(-1)
print("x - min={:.3f}, max={:.3f}".format(allxs.min().item(), allxs.max().item()))
print("y - min={:.3f}, max={:.3f}".format(allys.min().item(), allys.max().item()))
elif dynamic_env.meta_info['task'] == 'classification':
allxs = torch.cat(allxs)
print("x[0] - min={:.3f}, max={:.3f}".format(allxs[:,0].min().item(), allxs[:,0].max().item()))
print("x[1] - min={:.3f}, max={:.3f}".format(allxs[:,1].min().item(), allxs[:,1].max().item()))
else:
raise ValueError("Unknown task".format(dynamic_env.meta_info['task']))
for idx, (timestamp, (allx, ally)) in enumerate(tqdm(dynamic_env, ncols=50)):
dpi, width, height = 30, 1800, 1400
figsize = width / float(dpi), height / float(dpi)
@@ -202,19 +211,29 @@ def visualize_env(save_dir, version):
fig = plt.figure(figsize=figsize)
cur_ax = fig.add_subplot(1, 1, 1)
allx, ally = allx[:, 0].numpy(), ally[:, 0].numpy()
plot_scatter(cur_ax, allx, ally, "k", 0.99, 15, "timestamp={:05d}".format(idx))
if dynamic_env.meta_info['task'] == 'regression':
allx, ally = allx[:, 0].numpy(), ally[:, 0].numpy()
plot_scatter(cur_ax, allx, ally, "k", 0.99, (15, 1.5), "timestamp={:05d}".format(idx))
cur_ax.set_xlim(round(allxs.min().item(), 1), round(allxs.max().item(), 1))
cur_ax.set_ylim(round(allys.min().item(), 1), round(allys.max().item(), 1))
elif dynamic_env.meta_info['task'] == 'classification':
positive, negative = ally == 1, ally == 0
# plot_scatter(cur_ax, [1], [1], "k", 0.1, 1, "timestamp={:05d}".format(idx))
plot_scatter(cur_ax, allx[positive,0], allx[positive,1], "r", 0.99, (20, 10), "positive")
plot_scatter(cur_ax, allx[negative,0], allx[negative,1], "g", 0.99, (20, 10), "negative")
cur_ax.set_xlim(round(allxs[:,0].min().item(), 1), round(allxs[:,0].max().item(), 1))
cur_ax.set_ylim(round(allxs[:,1].min().item(), 1), round(allxs[:,1].max().item(), 1))
else:
raise ValueError("Unknown task".format(dynamic_env.meta_info['task']))
cur_ax.set_xlabel("X", fontsize=LabelSize)
cur_ax.set_ylabel("Y", rotation=0, fontsize=LabelSize)
for tick in cur_ax.xaxis.get_major_ticks():
tick.label.set_fontsize(LabelSize - font_gap)
tick.label.set_rotation(10)
tick.label.set_fontsize(LabelSize - font_gap)
tick.label.set_rotation(10)
for tick in cur_ax.yaxis.get_major_ticks():
tick.label.set_fontsize(LabelSize - font_gap)
cur_ax.set_xlim(round(allxs.min().item(), 1), round(allxs.max().item(), 1))
cur_ax.set_ylim(round(allys.min().item(), 1), round(allys.max().item(), 1))
cur_ax.legend(loc=1, fontsize=LegendFontsize)
tick.label.set_fontsize(LabelSize - font_gap)
cur_ax.legend(loc=1, fontsize=LegendFontsize)
pdf_save_path = (
save_dir
/ "pdf-{:}".format(version)
@@ -237,7 +256,7 @@ def visualize_env(save_dir, version):
os.system("{:} {xdir}/env-{ver}.webm".format(base_cmd, xdir=save_dir, ver=version))
def compare_algs(save_dir, version, alg_dir="./outputs/lfna-synthetic"):
def compare_algs(save_dir, version, alg_dir="./outputs/GeMOSA-synthetic"):
save_dir = Path(str(save_dir))
for substr in ("pdf", "png"):
sub_save_dir = save_dir / substr