Add SuperSimpleNorm and update synthetic env
This commit is contained in:
@@ -1,9 +1,9 @@
|
||||
#####################################################
|
||||
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2021.02 #
|
||||
#####################################################
|
||||
# python exps/synthetic/baseline.py #
|
||||
#####################################################
|
||||
import os, sys, copy
|
||||
############################################################################
|
||||
# CUDA_VISIBLE_DEVICES=0 python exps/synthetic/baseline.py #
|
||||
############################################################################
|
||||
import os, sys, copy, random
|
||||
import torch
|
||||
import numpy as np
|
||||
import argparse
|
||||
@@ -28,6 +28,8 @@ from datasets import ConstantGenerator, SinGenerator, SyntheticDEnv
|
||||
from datasets import DynamicQuadraticFunc
|
||||
from datasets.synthetic_example import create_example_v1
|
||||
|
||||
from utils.temp_sync import optimize_fn, evaluate_fn
|
||||
|
||||
|
||||
def draw_fig(save_dir, timestamp, scatter_list):
|
||||
save_path = save_dir / "{:04d}".format(timestamp)
|
||||
@@ -67,28 +69,55 @@ def draw_fig(save_dir, timestamp, scatter_list):
|
||||
def main(save_dir):
|
||||
save_dir = Path(str(save_dir))
|
||||
save_dir.mkdir(parents=True, exist_ok=True)
|
||||
dynamic_env, function = create_example_v1(100, num_per_task=500)
|
||||
dynamic_env, function = create_example_v1(100, num_per_task=1000)
|
||||
|
||||
additional_xaxis = np.arange(-6, 6, 0.1)
|
||||
for timestamp, dataset in tqdm(dynamic_env, ncols=50):
|
||||
num = dataset.shape[0]
|
||||
xaxis = dataset[:, 0].numpy()
|
||||
additional_xaxis = np.arange(-6, 6, 0.2)
|
||||
models = dict()
|
||||
|
||||
for idx, (timestamp, dataset) in enumerate(tqdm(dynamic_env, ncols=50)):
|
||||
xaxis_all = dataset[:, 0].numpy()
|
||||
# xaxis_all = np.concatenate((additional_xaxis, xaxis_all))
|
||||
# compute the ground truth
|
||||
function.set_timestamp(timestamp)
|
||||
yaxis = function(xaxis)
|
||||
# xaxis = np.concatenate((additional_xaxis, xaxis))
|
||||
yaxis_all = function.noise_call(xaxis_all)
|
||||
|
||||
# split the dataset
|
||||
indexes = list(range(xaxis_all.shape[0]))
|
||||
random.shuffle(indexes)
|
||||
train_indexes = indexes[:len(indexes)//2]
|
||||
valid_indexes = indexes[len(indexes)//2:]
|
||||
train_xs, train_ys = xaxis_all[train_indexes], yaxis_all[train_indexes]
|
||||
valid_xs, valid_ys = xaxis_all[valid_indexes], yaxis_all[valid_indexes]
|
||||
|
||||
model, loss_fn, train_loss = optimize_fn(train_xs, train_ys)
|
||||
# model, loss_fn, train_loss = optimize_fn(xaxis_all, yaxis_all)
|
||||
pred_valid_ys, valid_loss = evaluate_fn(model, valid_xs, valid_ys, loss_fn)
|
||||
print("[{:03d}] T-{:03d}, train-loss={:.5f}, valid-loss={:.5f}".format(idx, timestamp, train_loss, valid_loss))
|
||||
|
||||
# the first plot
|
||||
scatter_list = []
|
||||
scatter_list.append(
|
||||
{
|
||||
"xaxis": xaxis,
|
||||
"yaxis": yaxis,
|
||||
"xaxis": valid_xs,
|
||||
"yaxis": valid_ys,
|
||||
"color": "k",
|
||||
"s": 10,
|
||||
"alpha": 0.99,
|
||||
"label": "Timestamp={:02d}".format(timestamp),
|
||||
}
|
||||
)
|
||||
|
||||
scatter_list.append(
|
||||
{
|
||||
"xaxis": valid_xs,
|
||||
"yaxis": pred_valid_ys,
|
||||
"color": "r",
|
||||
"s": 10,
|
||||
"alpha": 0.5,
|
||||
"label": "MLP at now"
|
||||
}
|
||||
)
|
||||
|
||||
draw_fig(save_dir, timestamp, scatter_list)
|
||||
print("Save all figures into {:}".format(save_dir))
|
||||
save_dir = save_dir.resolve()
|
||||
|
Reference in New Issue
Block a user