Add SuperSimpleNorm and update synthetic env

This commit is contained in:
D-X-Y
2021-04-23 02:12:11 -07:00
parent a5b7d986b3
commit 9b895bdf2e
13 changed files with 238 additions and 519 deletions

View File

@@ -1,9 +1,9 @@
#####################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2021.02 #
#####################################################
# python exps/synthetic/baseline.py #
#####################################################
import os, sys, copy
############################################################################
# CUDA_VISIBLE_DEVICES=0 python exps/synthetic/baseline.py #
############################################################################
import os, sys, copy, random
import torch
import numpy as np
import argparse
@@ -28,6 +28,8 @@ from datasets import ConstantGenerator, SinGenerator, SyntheticDEnv
from datasets import DynamicQuadraticFunc
from datasets.synthetic_example import create_example_v1
from utils.temp_sync import optimize_fn, evaluate_fn
def draw_fig(save_dir, timestamp, scatter_list):
save_path = save_dir / "{:04d}".format(timestamp)
@@ -67,28 +69,55 @@ def draw_fig(save_dir, timestamp, scatter_list):
def main(save_dir):
save_dir = Path(str(save_dir))
save_dir.mkdir(parents=True, exist_ok=True)
dynamic_env, function = create_example_v1(100, num_per_task=500)
dynamic_env, function = create_example_v1(100, num_per_task=1000)
additional_xaxis = np.arange(-6, 6, 0.1)
for timestamp, dataset in tqdm(dynamic_env, ncols=50):
num = dataset.shape[0]
xaxis = dataset[:, 0].numpy()
additional_xaxis = np.arange(-6, 6, 0.2)
models = dict()
for idx, (timestamp, dataset) in enumerate(tqdm(dynamic_env, ncols=50)):
xaxis_all = dataset[:, 0].numpy()
# xaxis_all = np.concatenate((additional_xaxis, xaxis_all))
# compute the ground truth
function.set_timestamp(timestamp)
yaxis = function(xaxis)
# xaxis = np.concatenate((additional_xaxis, xaxis))
yaxis_all = function.noise_call(xaxis_all)
# split the dataset
indexes = list(range(xaxis_all.shape[0]))
random.shuffle(indexes)
train_indexes = indexes[:len(indexes)//2]
valid_indexes = indexes[len(indexes)//2:]
train_xs, train_ys = xaxis_all[train_indexes], yaxis_all[train_indexes]
valid_xs, valid_ys = xaxis_all[valid_indexes], yaxis_all[valid_indexes]
model, loss_fn, train_loss = optimize_fn(train_xs, train_ys)
# model, loss_fn, train_loss = optimize_fn(xaxis_all, yaxis_all)
pred_valid_ys, valid_loss = evaluate_fn(model, valid_xs, valid_ys, loss_fn)
print("[{:03d}] T-{:03d}, train-loss={:.5f}, valid-loss={:.5f}".format(idx, timestamp, train_loss, valid_loss))
# the first plot
scatter_list = []
scatter_list.append(
{
"xaxis": xaxis,
"yaxis": yaxis,
"xaxis": valid_xs,
"yaxis": valid_ys,
"color": "k",
"s": 10,
"alpha": 0.99,
"label": "Timestamp={:02d}".format(timestamp),
}
)
scatter_list.append(
{
"xaxis": valid_xs,
"yaxis": pred_valid_ys,
"color": "r",
"s": 10,
"alpha": 0.5,
"label": "MLP at now"
}
)
draw_fig(save_dir, timestamp, scatter_list)
print("Save all figures into {:}".format(save_dir))
save_dir = save_dir.resolve()