Rearrange experimental

This commit is contained in:
D-X-Y
2021-06-03 01:08:17 -07:00
parent d3d950d310
commit 6ee062a33d
22 changed files with 247 additions and 314 deletions

View File

@@ -0,0 +1,319 @@
#####################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2021.04 #
#####################################################
# python exps/GeMOSA/baselines/maml-ft.py --env_version v1 --hidden_dim 16 --inner_step 5 --device cuda
# python exps/GeMOSA/baselines/maml-ft.py --env_version v2 --hidden_dim 16 --inner_step 5 --device cuda
# python exps/GeMOSA/baselines/maml-ft.py --env_version v3 --hidden_dim 32 --inner_step 5 --device cuda
# python exps/GeMOSA/baselines/maml-ft.py --env_version v4 --hidden_dim 32 --inner_step 5 --device cuda
#####################################################
import sys, time, copy, torch, random, argparse
from tqdm import tqdm
from copy import deepcopy
from pathlib import Path
lib_dir = (Path(__file__).parent / ".." / ".." / "..").resolve()
print(lib_dir)
if str(lib_dir) not in sys.path:
sys.path.insert(0, str(lib_dir))
from xautodl.procedures import (
prepare_seed,
prepare_logger,
save_checkpoint,
copy_checkpoint,
)
from xautodl.log_utils import time_string
from xautodl.log_utils import AverageMeter, convert_secs2time
from xautodl.procedures.metric_utils import SaveMetric, MSEMetric, Top1AccMetric
from xautodl.datasets.synthetic_core import get_synthetic_env
from xautodl.models.xcore import get_model
from xautodl.xlayers import super_core
class MAML:
    """A MAML-style meta-model: the inner loop adapts a weight container by
    plain gradient descent, while the outer loop updates the raw network
    parameters with Adam."""

    def __init__(
        self, network, criterion, epochs, meta_lr, inner_lr=0.01, inner_step=1
    ):
        self.criterion = criterion
        self.network = network
        # The outer (meta) optimizer acts on the network's own parameters.
        self.meta_optimizer = torch.optim.Adam(
            self.network.parameters(), lr=meta_lr, amsgrad=True
        )
        self.inner_lr = inner_lr
        self.inner_step = inner_step
        # NOTE(review): `epochs` and `_best_info` are stored/accepted but unused.
        self._best_info = dict(state_dict=None, iepoch=None, score=None)
        print("There are {:} weights.".format(self.network.get_w_container().numel()))

    def adapt(self, x, y):
        """Run `inner_step` gradient-descent steps on a fresh weight container
        for the given data and return the adapted container."""
        weights = self.network.get_w_container()
        for _ in range(self.inner_step):
            predictions = self.network.forward_with_container(x, weights)
            inner_loss = self.criterion(predictions, y)
            gradients = torch.autograd.grad(inner_loss, weights.parameters())
            updates = [-self.inner_lr * gradient for gradient in gradients]
            weights = weights.additive(updates)
        return weights

    def predict(self, x, container=None):
        """Forward `x`, optionally through an adapted weight container."""
        if container is None:
            return self.network(x)
        return self.network.forward_with_container(x, container)

    def step(self):
        # Clip the meta-gradient before applying the outer update.
        torch.nn.utils.clip_grad_norm_(self.network.parameters(), 1.0)
        self.meta_optimizer.step()

    def zero_grad(self):
        self.meta_optimizer.zero_grad()

    def load_state_dict(self, state_dict):
        self.criterion.load_state_dict(state_dict["criterion"])
        self.network.load_state_dict(state_dict["network"])
        self.meta_optimizer.load_state_dict(state_dict["meta_optimizer"])

    def state_dict(self):
        # Bundle the three stateful members into one checkpointable dict.
        return dict(
            criterion=self.criterion.state_dict(),
            network=self.network.state_dict(),
            meta_optimizer=self.meta_optimizer.state_dict(),
        )

    def save_best(self, score):
        # Delegates best-tracking to the network; returns (success, best_score).
        return self.network.save_best(score)

    def load_best(self):
        self.network.load_best()
def main(args):
    """Meta-train MAML on the trainval split of a synthetic environment, then
    evaluate on the test split, re-adapting (fine-tuning) at every timestamp.

    Args:
        args: parsed argparse namespace (see the __main__ block for the flags).
    """
    prepare_seed(args.rand_seed)
    logger = prepare_logger(args)

    # Build every split of the synthetic environment; train/valid are only
    # logged, trainval drives meta-training, and test drives evaluation.
    train_env = get_synthetic_env(mode="train", version=args.env_version)
    valid_env = get_synthetic_env(mode="valid", version=args.env_version)
    trainval_env = get_synthetic_env(mode="trainval", version=args.env_version)
    test_env = get_synthetic_env(mode="test", version=args.env_version)
    all_env = get_synthetic_env(mode=None, version=args.env_version)
    logger.log("The training enviornment: {:}".format(train_env))
    logger.log("The validation enviornment: {:}".format(valid_env))
    logger.log("The trainval enviornment: {:}".format(trainval_env))
    logger.log("The total enviornment: {:}".format(all_env))
    logger.log("The test enviornment: {:}".format(test_env))
    # A small layer-norm MLP; input/output sizes come from the env metadata.
    model_kwargs = dict(
        config=dict(model_type="norm_mlp"),
        input_dim=all_env.meta_info["input_dim"],
        output_dim=all_env.meta_info["output_dim"],
        hidden_dims=[args.hidden_dim] * 2,
        act_cls="relu",
        norm_cls="layer_norm_1d",
    )
    model = get_model(**model_kwargs)
    model = model.to(args.device)
    # Pick the loss / metric pair that matches the task type.
    if all_env.meta_info["task"] == "regression":
        criterion = torch.nn.MSELoss()
        metric_cls = MSEMetric
    elif all_env.meta_info["task"] == "classification":
        criterion = torch.nn.CrossEntropyLoss()
        metric_cls = Top1AccMetric
    else:
        raise ValueError(
            "This task ({:}) is not supported.".format(all_env.meta_info["task"])
        )

    maml = MAML(
        model, criterion, args.epochs, args.meta_lr, args.inner_lr, args.inner_step
    )

    # meta-training
    last_success_epoch = 0
    per_epoch_time, start_time = AverageMeter(), time.time()
    for iepoch in range(args.epochs):
        need_time = "Time Left: {:}".format(
            convert_secs2time(per_epoch_time.avg * (args.epochs - iepoch), True)
        )
        head_str = (
            "[{:}] [{:04d}/{:04d}] ".format(time_string(), iepoch, args.epochs)
            + need_time
        )

        maml.zero_grad()
        meta_losses = []
        # Each meta-batch element: adapt on a history window, then evaluate on
        # a randomly drawn future timestamp.
        for ibatch in range(args.meta_batch):
            future_idx = random.randint(0, len(trainval_env) - 1)
            future_t, (future_x, future_y) = trainval_env[future_idx]
            # -->>
            seq_times = trainval_env.get_seq_times(future_idx, args.seq_length)
            _, (allxs, allys) = trainval_env.seq_call(seq_times)
            allxs, allys = allxs.view(-1, allxs.shape[-1]), allys.view(-1, 1)
            if trainval_env.meta_info["task"] == "classification":
                allys = allys.view(-1)
            historical_x, historical_y = allxs.to(args.device), allys.to(args.device)
            future_container = maml.adapt(historical_x, historical_y)
            future_x, future_y = future_x.to(args.device), future_y.to(args.device)
            future_y_hat = maml.predict(future_x, future_container)
            future_loss = maml.criterion(future_y_hat, future_y)
            meta_losses.append(future_loss)
        meta_loss = torch.stack(meta_losses).mean()
        meta_loss.backward()
        maml.step()

        logger.log(head_str + " meta-loss: {:.4f}".format(meta_loss.item()))
        # Track the best (lowest) meta-loss; checkpoint on improvement and
        # early-stop after `early_stop_thresh` epochs without one.
        success, best_score = maml.save_best(-meta_loss.item())
        if success:
            logger.log("Achieve the best with best_score = {:.3f}".format(best_score))
            save_checkpoint(maml.state_dict(), logger.path("model"), logger)
            last_success_epoch = iepoch
        if iepoch - last_success_epoch >= args.early_stop_thresh:
            logger.log("Early stop at {:}".format(iepoch))
            break

        per_epoch_time.update(time.time() - start_time)
        start_time = time.time()

    # meta-test
    maml.load_best()

    def finetune(index):
        # Adapt the meta-learned weights on the history window that precedes
        # `index` in the test environment; return the metrics on that history
        # together with the adapted weight container.
        seq_times = test_env.get_seq_times(index, args.seq_length)
        _, (allxs, allys) = test_env.seq_call(seq_times)
        allxs, allys = allxs.view(-1, allxs.shape[-1]), allys.view(-1, 1)
        if test_env.meta_info["task"] == "classification":
            allys = allys.view(-1)
        historical_x, historical_y = allxs.to(args.device), allys.to(args.device)
        future_container = maml.adapt(historical_x, historical_y)
        historical_y_hat = maml.predict(historical_x, future_container)
        train_metric = metric_cls(True)
        # model.analyze_weights()
        with torch.no_grad():
            train_metric(historical_y_hat, historical_y)
        train_results = train_metric.get_info()
        return train_results, future_container

    metric = metric_cls(True)
    per_timestamp_time, start_time = AverageMeter(), time.time()
    for idx, (future_time, (future_x, future_y)) in enumerate(test_env):
        need_time = "Time Left: {:}".format(
            convert_secs2time(per_timestamp_time.avg * (len(test_env) - idx), True)
        )
        logger.log(
            "[{:}]".format(time_string())
            + " [{:04d}/{:04d}]".format(idx, len(test_env))
            + " "
            + need_time
        )
        # build optimizer
        # Re-adapt at every test timestamp, then score the adapted model on it.
        train_results, future_container = finetune(idx)
        future_x, future_y = future_x.to(args.device), future_y.to(args.device)
        future_y_hat = maml.predict(future_x, future_container)
        future_loss = criterion(future_y_hat, future_y)
        metric(future_y_hat, future_y)
        log_str = (
            "[{:}]".format(time_string())
            + " [{:04d}/{:04d}]".format(idx, len(test_env))
            + " train-score: {:.5f}, eval-score: {:.5f}".format(
                train_results["score"], metric.get_info()["score"]
            )
        )
        logger.log(log_str)
        logger.log("")
        per_timestamp_time.update(time.time() - start_time)
        start_time = time.time()

    logger.log("-" * 200 + "\n")
    logger.close()
if __name__ == "__main__":
    # Command-line entry point: parse hyper-parameters and launch main().
    parser = argparse.ArgumentParser("Use the maml.")
    parser.add_argument(
        "--save_dir",
        type=str,
        default="./outputs/GeMOSA-synthetic/use-maml-ft",
        help="The checkpoint directory.",
    )
    parser.add_argument(
        "--env_version",
        type=str,
        required=True,
        help="The synthetic enviornment version.",
    )
    parser.add_argument(
        "--hidden_dim",
        type=int,
        default=16,
        help="The hidden dimension.",
    )
    parser.add_argument(
        "--meta_lr",
        type=float,
        default=0.02,
        help="The learning rate for the MAML optimizer (default is Adam)",
    )
    parser.add_argument(
        "--inner_lr",
        type=float,
        default=0.005,
        help="The learning rate for the inner optimization",
    )
    parser.add_argument(
        "--inner_step", type=int, default=1, help="The inner loop steps for MAML."
    )
    parser.add_argument(
        "--seq_length", type=int, default=20, help="The sequence length."
    )
    parser.add_argument(
        "--meta_batch",
        type=int,
        default=256,
        help="The batch size for the meta-model",
    )
    parser.add_argument(
        "--epochs",
        type=int,
        default=2000,
        help="The total number of epochs.",
    )
    parser.add_argument(
        "--early_stop_thresh",
        type=int,
        default=50,
        help="The maximum epochs for early stop.",
    )
    parser.add_argument(
        "--device",
        type=str,
        default="cpu",
        # BUGFIX: the help string was empty.
        help="The device to run on, e.g., cpu or cuda.",
    )
    parser.add_argument(
        "--workers",
        type=int,
        default=4,
        help="The number of data loading workers (default: 4)",
    )
    # Random Seed
    parser.add_argument("--rand_seed", type=int, default=-1, help="manual seed")
    args = parser.parse_args()
    # A negative or missing seed means: draw one at random.
    if args.rand_seed is None or args.rand_seed < 0:
        args.rand_seed = random.randint(1, 100000)
    assert args.save_dir is not None, "The save dir argument can not be None"
    # Encode the key hyper-parameters into the output directory name.
    args.save_dir = "{:}-s{:}-mlr{:}-d{:}-e{:}-env{:}".format(
        args.save_dir,
        args.inner_step,
        args.meta_lr,
        args.hidden_dim,
        args.epochs,
        args.env_version,
    )
    main(args)

View File

@@ -0,0 +1,319 @@
#####################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2021.04 #
#####################################################
# python exps/GeMOSA/baselines/maml-nof.py --env_version v1 --hidden_dim 16 --inner_step 5 --device cuda
# python exps/GeMOSA/baselines/maml-nof.py --env_version v2 --hidden_dim 16 --inner_step 5 --device cuda
# python exps/GeMOSA/baselines/maml-nof.py --env_version v3 --hidden_dim 32 --inner_step 5 --device cuda
# python exps/GeMOSA/baselines/maml-nof.py --env_version v4 --hidden_dim 32 --inner_step 5 --device cuda
#####################################################
import sys, time, copy, torch, random, argparse
from tqdm import tqdm
from copy import deepcopy
from pathlib import Path
lib_dir = (Path(__file__).parent / ".." / ".." / "..").resolve()
print(lib_dir)
if str(lib_dir) not in sys.path:
sys.path.insert(0, str(lib_dir))
from xautodl.procedures import (
prepare_seed,
prepare_logger,
save_checkpoint,
copy_checkpoint,
)
from xautodl.log_utils import time_string
from xautodl.log_utils import AverageMeter, convert_secs2time
from xautodl.procedures.metric_utils import SaveMetric, MSEMetric, Top1AccMetric
from xautodl.datasets.synthetic_core import get_synthetic_env
from xautodl.models.xcore import get_model
from xautodl.xlayers import super_core
class MAML:
    """A MAML-style meta-model: the inner loop adapts a weight container by
    plain gradient descent, while the outer loop updates the raw network
    parameters with Adam."""

    def __init__(
        self, network, criterion, epochs, meta_lr, inner_lr=0.01, inner_step=1
    ):
        self.criterion = criterion
        self.network = network
        # The outer (meta) optimizer acts on the network's own parameters.
        self.meta_optimizer = torch.optim.Adam(
            self.network.parameters(), lr=meta_lr, amsgrad=True
        )
        self.inner_lr = inner_lr
        self.inner_step = inner_step
        # NOTE(review): `epochs` and `_best_info` are stored/accepted but unused.
        self._best_info = dict(state_dict=None, iepoch=None, score=None)
        print("There are {:} weights.".format(self.network.get_w_container().numel()))

    def adapt(self, x, y):
        """Run `inner_step` gradient-descent steps on a fresh weight container
        for the given data and return the adapted container."""
        weights = self.network.get_w_container()
        for _ in range(self.inner_step):
            predictions = self.network.forward_with_container(x, weights)
            inner_loss = self.criterion(predictions, y)
            gradients = torch.autograd.grad(inner_loss, weights.parameters())
            updates = [-self.inner_lr * gradient for gradient in gradients]
            weights = weights.additive(updates)
        return weights

    def predict(self, x, container=None):
        """Forward `x`, optionally through an adapted weight container."""
        if container is None:
            return self.network(x)
        return self.network.forward_with_container(x, container)

    def step(self):
        # Clip the meta-gradient before applying the outer update.
        torch.nn.utils.clip_grad_norm_(self.network.parameters(), 1.0)
        self.meta_optimizer.step()

    def zero_grad(self):
        self.meta_optimizer.zero_grad()

    def load_state_dict(self, state_dict):
        self.criterion.load_state_dict(state_dict["criterion"])
        self.network.load_state_dict(state_dict["network"])
        self.meta_optimizer.load_state_dict(state_dict["meta_optimizer"])

    def state_dict(self):
        # Bundle the three stateful members into one checkpointable dict.
        return dict(
            criterion=self.criterion.state_dict(),
            network=self.network.state_dict(),
            meta_optimizer=self.meta_optimizer.state_dict(),
        )

    def save_best(self, score):
        # Delegates best-tracking to the network; returns (success, best_score).
        return self.network.save_best(score)

    def load_best(self):
        self.network.load_best()
def main(args):
    """Meta-train MAML on the trainval split of a synthetic environment, then
    evaluate on the test split WITHOUT per-timestamp re-adaptation: the model
    is adapted once (on the history of timestamp 0) and reused throughout.

    Args:
        args: parsed argparse namespace (see the __main__ block for the flags).
    """
    prepare_seed(args.rand_seed)
    logger = prepare_logger(args)

    # Build every split of the synthetic environment; train/valid are only
    # logged, trainval drives meta-training, and test drives evaluation.
    train_env = get_synthetic_env(mode="train", version=args.env_version)
    valid_env = get_synthetic_env(mode="valid", version=args.env_version)
    trainval_env = get_synthetic_env(mode="trainval", version=args.env_version)
    test_env = get_synthetic_env(mode="test", version=args.env_version)
    all_env = get_synthetic_env(mode=None, version=args.env_version)
    logger.log("The training enviornment: {:}".format(train_env))
    logger.log("The validation enviornment: {:}".format(valid_env))
    logger.log("The trainval enviornment: {:}".format(trainval_env))
    logger.log("The total enviornment: {:}".format(all_env))
    logger.log("The test enviornment: {:}".format(test_env))
    # A small layer-norm MLP; input/output sizes come from the env metadata.
    model_kwargs = dict(
        config=dict(model_type="norm_mlp"),
        input_dim=all_env.meta_info["input_dim"],
        output_dim=all_env.meta_info["output_dim"],
        hidden_dims=[args.hidden_dim] * 2,
        act_cls="relu",
        norm_cls="layer_norm_1d",
    )
    model = get_model(**model_kwargs)
    model = model.to(args.device)
    # Pick the loss / metric pair that matches the task type.
    if all_env.meta_info["task"] == "regression":
        criterion = torch.nn.MSELoss()
        metric_cls = MSEMetric
    elif all_env.meta_info["task"] == "classification":
        criterion = torch.nn.CrossEntropyLoss()
        metric_cls = Top1AccMetric
    else:
        raise ValueError(
            "This task ({:}) is not supported.".format(all_env.meta_info["task"])
        )

    maml = MAML(
        model, criterion, args.epochs, args.meta_lr, args.inner_lr, args.inner_step
    )

    # meta-training
    last_success_epoch = 0
    per_epoch_time, start_time = AverageMeter(), time.time()
    for iepoch in range(args.epochs):
        need_time = "Time Left: {:}".format(
            convert_secs2time(per_epoch_time.avg * (args.epochs - iepoch), True)
        )
        head_str = (
            "[{:}] [{:04d}/{:04d}] ".format(time_string(), iepoch, args.epochs)
            + need_time
        )

        maml.zero_grad()
        meta_losses = []
        # Each meta-batch element: adapt on a history window, then evaluate on
        # a randomly drawn future timestamp.
        for ibatch in range(args.meta_batch):
            future_idx = random.randint(0, len(trainval_env) - 1)
            future_t, (future_x, future_y) = trainval_env[future_idx]
            # -->>
            seq_times = trainval_env.get_seq_times(future_idx, args.seq_length)
            _, (allxs, allys) = trainval_env.seq_call(seq_times)
            allxs, allys = allxs.view(-1, allxs.shape[-1]), allys.view(-1, 1)
            if trainval_env.meta_info["task"] == "classification":
                allys = allys.view(-1)
            historical_x, historical_y = allxs.to(args.device), allys.to(args.device)
            future_container = maml.adapt(historical_x, historical_y)
            future_x, future_y = future_x.to(args.device), future_y.to(args.device)
            future_y_hat = maml.predict(future_x, future_container)
            future_loss = maml.criterion(future_y_hat, future_y)
            meta_losses.append(future_loss)
        meta_loss = torch.stack(meta_losses).mean()
        meta_loss.backward()
        maml.step()

        logger.log(head_str + " meta-loss: {:.4f}".format(meta_loss.item()))
        # Track the best (lowest) meta-loss; checkpoint on improvement and
        # early-stop after `early_stop_thresh` epochs without one.
        success, best_score = maml.save_best(-meta_loss.item())
        if success:
            logger.log("Achieve the best with best_score = {:.3f}".format(best_score))
            save_checkpoint(maml.state_dict(), logger.path("model"), logger)
            last_success_epoch = iepoch
        if iepoch - last_success_epoch >= args.early_stop_thresh:
            logger.log("Early stop at {:}".format(iepoch))
            break

        per_epoch_time.update(time.time() - start_time)
        start_time = time.time()

    # meta-test
    maml.load_best()

    def finetune(index):
        # Adapt the meta-learned weights on the history window that precedes
        # `index` in the test environment; return the metrics on that history
        # together with the adapted weight container.
        seq_times = test_env.get_seq_times(index, args.seq_length)
        _, (allxs, allys) = test_env.seq_call(seq_times)
        allxs, allys = allxs.view(-1, allxs.shape[-1]), allys.view(-1, 1)
        if test_env.meta_info["task"] == "classification":
            allys = allys.view(-1)
        historical_x, historical_y = allxs.to(args.device), allys.to(args.device)
        future_container = maml.adapt(historical_x, historical_y)
        historical_y_hat = maml.predict(historical_x, future_container)
        train_metric = metric_cls(True)
        # model.analyze_weights()
        with torch.no_grad():
            train_metric(historical_y_hat, historical_y)
        train_results = train_metric.get_info()
        return train_results, future_container

    # Adapt once on the first history window; the same container is reused for
    # every test timestamp below (the "no fine-tuning" baseline).
    train_results, future_container = finetune(0)

    metric = metric_cls(True)
    per_timestamp_time, start_time = AverageMeter(), time.time()
    for idx, (future_time, (future_x, future_y)) in enumerate(test_env):
        need_time = "Time Left: {:}".format(
            convert_secs2time(per_timestamp_time.avg * (len(test_env) - idx), True)
        )
        logger.log(
            "[{:}]".format(time_string())
            + " [{:04d}/{:04d}]".format(idx, len(test_env))
            + " "
            + need_time
        )
        # build optimizer
        future_x, future_y = future_x.to(args.device), future_y.to(args.device)
        future_y_hat = maml.predict(future_x, future_container)
        future_loss = criterion(future_y_hat, future_y)
        metric(future_y_hat, future_y)
        log_str = (
            "[{:}]".format(time_string())
            + " [{:04d}/{:04d}]".format(idx, len(test_env))
            + " train-score: {:.5f}, eval-score: {:.5f}".format(
                train_results["score"], metric.get_info()["score"]
            )
        )
        logger.log(log_str)
        logger.log("")
        per_timestamp_time.update(time.time() - start_time)
        start_time = time.time()

    logger.log("-" * 200 + "\n")
    logger.close()
if __name__ == "__main__":
    # Command-line entry point: parse hyper-parameters and launch main().
    parser = argparse.ArgumentParser("Use the maml.")
    parser.add_argument(
        "--save_dir",
        type=str,
        default="./outputs/GeMOSA-synthetic/use-maml-nft",
        help="The checkpoint directory.",
    )
    parser.add_argument(
        "--env_version",
        type=str,
        required=True,
        help="The synthetic enviornment version.",
    )
    parser.add_argument(
        "--hidden_dim",
        type=int,
        default=16,
        help="The hidden dimension.",
    )
    parser.add_argument(
        "--meta_lr",
        type=float,
        default=0.02,
        help="The learning rate for the MAML optimizer (default is Adam)",
    )
    parser.add_argument(
        "--inner_lr",
        type=float,
        default=0.005,
        help="The learning rate for the inner optimization",
    )
    parser.add_argument(
        "--inner_step", type=int, default=1, help="The inner loop steps for MAML."
    )
    parser.add_argument(
        "--seq_length", type=int, default=20, help="The sequence length."
    )
    parser.add_argument(
        "--meta_batch",
        type=int,
        default=256,
        help="The batch size for the meta-model",
    )
    parser.add_argument(
        "--epochs",
        type=int,
        default=2000,
        help="The total number of epochs.",
    )
    parser.add_argument(
        "--early_stop_thresh",
        type=int,
        default=50,
        help="The maximum epochs for early stop.",
    )
    parser.add_argument(
        "--device",
        type=str,
        default="cpu",
        # BUGFIX: the help string was empty.
        help="The device to run on, e.g., cpu or cuda.",
    )
    parser.add_argument(
        "--workers",
        type=int,
        default=4,
        help="The number of data loading workers (default: 4)",
    )
    # Random Seed
    parser.add_argument("--rand_seed", type=int, default=-1, help="manual seed")
    args = parser.parse_args()
    # A negative or missing seed means: draw one at random.
    if args.rand_seed is None or args.rand_seed < 0:
        args.rand_seed = random.randint(1, 100000)
    assert args.save_dir is not None, "The save dir argument can not be None"
    # Encode the key hyper-parameters into the output directory name.
    args.save_dir = "{:}-s{:}-mlr{:}-d{:}-e{:}-env{:}".format(
        args.save_dir,
        args.inner_step,
        args.meta_lr,
        args.hidden_dim,
        args.epochs,
        args.env_version,
    )
    main(args)

View File

@@ -0,0 +1,228 @@
#####################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2021.04 #
#####################################################
# python exps/GeMOSA/baselines/slbm-ft.py --env_version v1 --hidden_dim 16 --epochs 500 --init_lr 0.1 --device cuda
# python exps/GeMOSA/baselines/slbm-ft.py --env_version v2 --hidden_dim 16 --epochs 500 --init_lr 0.1 --device cuda
# python exps/GeMOSA/baselines/slbm-ft.py --env_version v3 --hidden_dim 32 --epochs 1000 --init_lr 0.05 --device cuda
# python exps/GeMOSA/baselines/slbm-ft.py --env_version v4 --hidden_dim 32 --epochs 1000 --init_lr 0.05 --device cuda
#####################################################
import sys, time, copy, torch, random, argparse
from tqdm import tqdm
from copy import deepcopy
from pathlib import Path
lib_dir = (Path(__file__).parent / ".." / ".." / "..").resolve()
print("LIB-DIR: {:}".format(lib_dir))
if str(lib_dir) not in sys.path:
sys.path.insert(0, str(lib_dir))
from xautodl.procedures import (
prepare_seed,
prepare_logger,
save_checkpoint,
copy_checkpoint,
)
from xautodl.log_utils import time_string
from xautodl.log_utils import AverageMeter, convert_secs2time
from xautodl.procedures.metric_utils import (
SaveMetric,
MSEMetric,
Top1AccMetric,
ComposeMetric,
)
from xautodl.datasets.synthetic_core import get_synthetic_env
from xautodl.models.xcore import get_model
from xautodl.utils import show_mean_var
def subsample(historical_x, historical_y, maxn=10000):
    """Randomly subsample (with replacement) up to `maxn` paired rows.

    If there are at most `maxn` rows, the inputs are returned unchanged.
    """
    num_rows = historical_x.size(0)
    if num_rows > maxn:
        chosen = torch.randint(low=0, high=num_rows, size=[maxn])
        return historical_x[chosen], historical_y[chosen]
    return historical_x, historical_y
def main(args):
    """Train a fresh model from scratch at every test timestamp (fine-tuned on
    the preceding history window) and report the running evaluation score.

    Args:
        args: parsed argparse namespace (see the __main__ block for the flags).

    Returns:
        The final evaluation score accumulated over all test timestamps.
    """
    prepare_seed(args.rand_seed)
    logger = prepare_logger(args)

    env = get_synthetic_env(mode="test", version=args.env_version)
    # A small layer-norm MLP; input/output sizes come from the env metadata.
    model_kwargs = dict(
        config=dict(model_type="norm_mlp"),
        input_dim=env.meta_info["input_dim"],
        output_dim=env.meta_info["output_dim"],
        hidden_dims=[args.hidden_dim] * 2,
        act_cls="relu",
        norm_cls="layer_norm_1d",
    )
    logger.log("The total enviornment: {:}".format(env))
    w_containers = dict()  # NOTE(review): never filled; saved empty at the end.

    # Pick the loss / metric pair that matches the task type.
    if env.meta_info["task"] == "regression":
        criterion = torch.nn.MSELoss()
        metric_cls = MSEMetric
    elif env.meta_info["task"] == "classification":
        criterion = torch.nn.CrossEntropyLoss()
        metric_cls = Top1AccMetric
    else:
        # BUGFIX: this previously formatted `all_env.meta_info`, but `all_env`
        # is undefined here — it would raise NameError instead of ValueError.
        raise ValueError(
            "This task ({:}) is not supported.".format(env.meta_info["task"])
        )

    def finetune(index):
        # Train a brand-new model on the history window preceding `index`;
        # return the train metrics and the trained model.
        seq_times = env.get_seq_times(index, args.seq_length)
        _, (allxs, allys) = env.seq_call(seq_times)
        allxs, allys = allxs.view(-1, allxs.shape[-1]), allys.view(-1, 1)
        if env.meta_info["task"] == "classification":
            allys = allys.view(-1)
        historical_x, historical_y = allxs.to(args.device), allys.to(args.device)
        model = get_model(**model_kwargs)
        model = model.to(args.device)
        optimizer = torch.optim.Adam(model.parameters(), lr=args.init_lr, amsgrad=True)
        # Decay the LR at 25% / 50% / 75% of the run.
        lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(
            optimizer,
            milestones=[
                int(args.epochs * 0.25),
                int(args.epochs * 0.5),
                int(args.epochs * 0.75),
            ],
            gamma=0.3,
        )
        train_metric = metric_cls(True)
        best_loss, best_param = None, None
        for _iepoch in range(args.epochs):
            preds = model(historical_x)
            optimizer.zero_grad()
            loss = criterion(preds, historical_y)
            loss.backward()
            optimizer.step()
            lr_scheduler.step()
            # save best
            if best_loss is None or best_loss > loss.item():
                best_loss = loss.item()
                best_param = copy.deepcopy(model.state_dict())
        model.load_state_dict(best_param)
        # model.analyze_weights()
        # NOTE(review): the metric is fed `preds` from the LAST epoch, not from
        # the restored best parameters — confirm this is intended.
        with torch.no_grad():
            train_metric(preds, historical_y)
        train_results = train_metric.get_info()
        return train_results, model

    metric = metric_cls(True)
    per_timestamp_time, start_time = AverageMeter(), time.time()
    for idx, (future_time, (future_x, future_y)) in enumerate(env):
        need_time = "Time Left: {:}".format(
            convert_secs2time(per_timestamp_time.avg * (len(env) - idx), True)
        )
        logger.log(
            "[{:}]".format(time_string())
            + " [{:04d}/{:04d}]".format(idx, len(env))
            + " "
            + need_time
        )
        # train the same data
        train_results, model = finetune(idx)
        # build optimizer
        xmetric = ComposeMetric(metric_cls(True), SaveMetric())
        future_x, future_y = future_x.to(args.device), future_y.to(args.device)
        future_y_hat = model(future_x)
        future_loss = criterion(future_y_hat, future_y)
        metric(future_y_hat, future_y)
        log_str = (
            "[{:}]".format(time_string())
            + " [{:04d}/{:04d}]".format(idx, len(env))
            + " train-score: {:.5f}, eval-score: {:.5f}".format(
                train_results["score"], metric.get_info()["score"]
            )
        )
        logger.log(log_str)
        logger.log("")
        per_timestamp_time.update(time.time() - start_time)
        start_time = time.time()

    save_checkpoint(
        {"w_containers": w_containers},
        logger.path(None) / "final-ckp.pth",
        logger,
    )
    logger.log("-" * 200 + "\n")
    logger.close()
    return metric.get_info()["score"]
if __name__ == "__main__":
    # Command-line entry point: parse hyper-parameters and launch main().
    parser = argparse.ArgumentParser("Use the data in the past.")
    parser.add_argument(
        "--save_dir",
        type=str,
        default="./outputs/GeMOSA-synthetic/use-same-ft-timestamp",
        help="The checkpoint directory.",
    )
    parser.add_argument(
        "--env_version",
        type=str,
        required=True,
        help="The synthetic enviornment version.",
    )
    parser.add_argument(
        "--hidden_dim",
        type=int,
        required=True,
        help="The hidden dimension.",
    )
    parser.add_argument(
        "--init_lr",
        type=float,
        default=0.1,
        help="The initial learning rate for the optimizer (default is Adam)",
    )
    parser.add_argument(
        "--seq_length", type=int, default=20, help="The sequence length."
    )
    parser.add_argument(
        "--batch_size",
        type=int,
        default=512,
        help="The batch size",
    )
    parser.add_argument(
        "--epochs",
        type=int,
        default=300,
        help="The total number of epochs.",
    )
    parser.add_argument(
        "--device",
        type=str,
        default="cpu",
        # BUGFIX: the help string was empty.
        help="The device to run on, e.g., cpu or cuda.",
    )
    parser.add_argument(
        "--workers",
        type=int,
        default=4,
        help="The number of data loading workers (default: 4)",
    )
    # Random Seed
    parser.add_argument("--rand_seed", type=int, default=-1, help="manual seed")
    args = parser.parse_args()
    # Encode the key hyper-parameters into the output directory name.
    args.save_dir = "{:}-d{:}_e{:}_lr{:}-env{:}".format(
        args.save_dir, args.hidden_dim, args.epochs, args.init_lr, args.env_version
    )
    # A negative or missing seed means: run 3 random seeds and report mean/var.
    if args.rand_seed is None or args.rand_seed < 0:
        results = []
        for iseed in range(3):
            args.rand_seed = random.randint(1, 100000)
            result = main(args)
            results.append(result)
        show_mean_var(results)
    else:
        main(args)

View File

@@ -0,0 +1,227 @@
#####################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2021.04 #
#####################################################
# python exps/GeMOSA/baselines/slbm-nof.py --env_version v1 --hidden_dim 16 --epochs 500 --init_lr 0.1 --device cuda
# python exps/GeMOSA/baselines/slbm-nof.py --env_version v2 --hidden_dim 16 --epochs 500 --init_lr 0.1 --device cuda
# python exps/GeMOSA/baselines/slbm-nof.py --env_version v3 --hidden_dim 32 --epochs 1000 --init_lr 0.05 --device cuda
# python exps/GeMOSA/baselines/slbm-nof.py --env_version v4 --hidden_dim 32 --epochs 1000 --init_lr 0.05 --device cuda
#####################################################
import sys, time, copy, torch, random, argparse
from tqdm import tqdm
from copy import deepcopy
from pathlib import Path
lib_dir = (Path(__file__).parent / ".." / ".." / "..").resolve()
print("LIB-DIR: {:}".format(lib_dir))
if str(lib_dir) not in sys.path:
sys.path.insert(0, str(lib_dir))
from xautodl.procedures import (
prepare_seed,
prepare_logger,
save_checkpoint,
copy_checkpoint,
)
from xautodl.log_utils import time_string
from xautodl.log_utils import AverageMeter, convert_secs2time
from xautodl.procedures.metric_utils import (
SaveMetric,
MSEMetric,
Top1AccMetric,
ComposeMetric,
)
from xautodl.datasets.synthetic_core import get_synthetic_env
from xautodl.models.xcore import get_model
from xautodl.utils import show_mean_var
def subsample(historical_x, historical_y, maxn=10000):
    """Randomly subsample (with replacement) up to `maxn` paired rows.

    If there are at most `maxn` rows, the inputs are returned unchanged.
    """
    num_rows = historical_x.size(0)
    if num_rows > maxn:
        chosen = torch.randint(low=0, high=num_rows, size=[maxn])
        return historical_x[chosen], historical_y[chosen]
    return historical_x, historical_y
def main(args):
    """Train ONE model on the history preceding test timestamp 0, then evaluate
    that single (never re-trained) model on every test timestamp.

    Args:
        args: parsed argparse namespace (see the __main__ block for the flags).

    Returns:
        The final evaluation score accumulated over all test timestamps.
    """
    prepare_seed(args.rand_seed)
    logger = prepare_logger(args)

    env = get_synthetic_env(mode="test", version=args.env_version)
    # A small layer-norm MLP; input/output sizes come from the env metadata.
    model_kwargs = dict(
        config=dict(model_type="norm_mlp"),
        input_dim=env.meta_info["input_dim"],
        output_dim=env.meta_info["output_dim"],
        hidden_dims=[args.hidden_dim] * 2,
        act_cls="relu",
        norm_cls="layer_norm_1d",
    )
    logger.log("The total enviornment: {:}".format(env))
    w_containers = dict()  # NOTE(review): never filled; saved empty at the end.

    # Pick the loss / metric pair that matches the task type.
    if env.meta_info["task"] == "regression":
        criterion = torch.nn.MSELoss()
        metric_cls = MSEMetric
    elif env.meta_info["task"] == "classification":
        criterion = torch.nn.CrossEntropyLoss()
        metric_cls = Top1AccMetric
    else:
        # BUGFIX: this previously formatted `all_env.meta_info`, but `all_env`
        # is undefined here — it would raise NameError instead of ValueError.
        raise ValueError(
            "This task ({:}) is not supported.".format(env.meta_info["task"])
        )

    # Train once on the history window that precedes timestamp 0.
    seq_times = env.get_seq_times(0, args.seq_length)
    _, (allxs, allys) = env.seq_call(seq_times)
    allxs, allys = allxs.view(-1, allxs.shape[-1]), allys.view(-1, 1)
    if env.meta_info["task"] == "classification":
        allys = allys.view(-1)
    historical_x, historical_y = allxs.to(args.device), allys.to(args.device)
    model = get_model(**model_kwargs)
    model = model.to(args.device)
    optimizer = torch.optim.Adam(model.parameters(), lr=args.init_lr, amsgrad=True)
    # Decay the LR at 25% / 50% / 75% of the run.
    lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(
        optimizer,
        milestones=[
            int(args.epochs * 0.25),
            int(args.epochs * 0.5),
            int(args.epochs * 0.75),
        ],
        gamma=0.3,
    )
    train_metric = metric_cls(True)
    best_loss, best_param = None, None
    for _iepoch in range(args.epochs):
        preds = model(historical_x)
        optimizer.zero_grad()
        loss = criterion(preds, historical_y)
        loss.backward()
        optimizer.step()
        lr_scheduler.step()
        # save best
        if best_loss is None or best_loss > loss.item():
            best_loss = loss.item()
            best_param = copy.deepcopy(model.state_dict())
    model.load_state_dict(best_param)
    model.analyze_weights()
    # NOTE(review): the metric is fed `preds` from the LAST epoch, not from the
    # restored best parameters — confirm this is intended.
    with torch.no_grad():
        train_metric(preds, historical_y)
    train_results = train_metric.get_info()
    print(train_results)

    metric = metric_cls(True)
    per_timestamp_time, start_time = AverageMeter(), time.time()
    for idx, (future_time, (future_x, future_y)) in enumerate(env):
        need_time = "Time Left: {:}".format(
            convert_secs2time(per_timestamp_time.avg * (len(env) - idx), True)
        )
        logger.log(
            "[{:}]".format(time_string())
            + " [{:04d}/{:04d}]".format(idx, len(env))
            + " "
            + need_time
        )
        # train the same data
        # build optimizer
        xmetric = ComposeMetric(metric_cls(True), SaveMetric())
        future_x, future_y = future_x.to(args.device), future_y.to(args.device)
        future_y_hat = model(future_x)
        future_loss = criterion(future_y_hat, future_y)
        metric(future_y_hat, future_y)
        log_str = (
            "[{:}]".format(time_string())
            + " [{:04d}/{:04d}]".format(idx, len(env))
            + " train-score: {:.5f}, eval-score: {:.5f}".format(
                train_results["score"], metric.get_info()["score"]
            )
        )
        logger.log(log_str)
        logger.log("")
        per_timestamp_time.update(time.time() - start_time)
        start_time = time.time()

    save_checkpoint(
        {"w_containers": w_containers},
        logger.path(None) / "final-ckp.pth",
        logger,
    )
    logger.log("-" * 200 + "\n")
    logger.close()
    return metric.get_info()["score"]
if __name__ == "__main__":
    # Command-line entry point: parse hyper-parameters and launch main().
    parser = argparse.ArgumentParser("Use the data in the past.")
    parser.add_argument(
        "--save_dir",
        type=str,
        default="./outputs/GeMOSA-synthetic/use-same-nof-timestamp",
        help="The checkpoint directory.",
    )
    parser.add_argument(
        "--env_version",
        type=str,
        required=True,
        help="The synthetic enviornment version.",
    )
    parser.add_argument(
        "--hidden_dim",
        type=int,
        required=True,
        help="The hidden dimension.",
    )
    parser.add_argument(
        "--seq_length", type=int, default=20, help="The sequence length."
    )
    parser.add_argument(
        "--init_lr",
        type=float,
        default=0.1,
        help="The initial learning rate for the optimizer (default is Adam)",
    )
    parser.add_argument(
        "--batch_size",
        type=int,
        default=512,
        help="The batch size",
    )
    parser.add_argument(
        "--epochs",
        type=int,
        default=300,
        help="The total number of epochs.",
    )
    parser.add_argument(
        "--device",
        type=str,
        default="cpu",
        # BUGFIX: the help string was empty.
        help="The device to run on, e.g., cpu or cuda.",
    )
    parser.add_argument(
        "--workers",
        type=int,
        default=4,
        help="The number of data loading workers (default: 4)",
    )
    # Random Seed
    parser.add_argument("--rand_seed", type=int, default=-1, help="manual seed")
    args = parser.parse_args()
    # Encode the key hyper-parameters into the output directory name.
    args.save_dir = "{:}-d{:}_e{:}_lr{:}-env{:}".format(
        args.save_dir, args.hidden_dim, args.epochs, args.init_lr, args.env_version
    )
    # A negative or missing seed means: run 3 random seeds and report mean/var.
    if args.rand_seed is None or args.rand_seed < 0:
        results = []
        for iseed in range(3):
            args.rand_seed = random.randint(1, 100000)
            result = main(args)
            results.append(result)
        show_mean_var(results)
    else:
        main(args)

View File

@@ -0,0 +1,206 @@
#####################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2021.04 #
#####################################################
# python exps/LFNA/basic-his.py --srange 1-999 --env_version v1 --hidden_dim 16
#####################################################
import sys, time, copy, torch, random, argparse
from tqdm import tqdm
from copy import deepcopy
from xautodl.procedures import (
prepare_seed,
prepare_logger,
save_checkpoint,
copy_checkpoint,
)
from xautodl.log_utils import time_string
from xautodl.log_utils import AverageMeter, convert_secs2time
from xautodl.utils import split_str2indexes
from xautodl.procedures.advanced_main import basic_train_fn, basic_eval_fn
from xautodl.procedures.metric_utils import SaveMetric, MSEMetric, ComposeMetric
from xautodl.datasets.synthetic_core import get_synthetic_env
from xautodl.models.xcore import get_model
from lfna_utils import lfna_setup
def subsample(historical_x, historical_y, maxn=10000):
    """Cap the training set at `maxn` points.

    Returns the inputs untouched when they already fit; otherwise draws
    `maxn` row indices uniformly (with replacement) and gathers both tensors.
    """
    num_rows = historical_x.size(0)
    if num_rows > maxn:
        picked = torch.randint(low=0, high=num_rows, size=[maxn])
        return historical_x[picked], historical_y[picked]
    return historical_x, historical_y
def main(args):
    """Train one MLP per requested timestamp on ALL previously-seen data.

    For each index in ``args.srange`` the model is fit from scratch on the
    concatenation of every earlier timestamp's (x, y) pairs, then evaluated
    on the current timestamp.  A checkpoint is written per timestamp plus a
    final summary checkpoint of all weight containers.
    """
    logger, env_info, model_kwargs = lfna_setup(args)

    # check indexes to be evaluated
    to_evaluate_indexes = split_str2indexes(args.srange, env_info["total"], None)
    logger.log(
        "Evaluate {:}, which has {:} timestamps in total.".format(
            args.srange, len(to_evaluate_indexes)
        )
    )

    w_container_per_epoch = dict()

    per_timestamp_time, start_time = AverageMeter(), time.time()
    for i, idx in enumerate(to_evaluate_indexes):
        need_time = "Time Left: {:}".format(
            convert_secs2time(
                per_timestamp_time.avg * (len(to_evaluate_indexes) - i), True
            )
        )
        logger.log(
            "[{:}]".format(time_string())
            + " [{:04d}/{:04d}][{:04d}]".format(i, len(to_evaluate_indexes), idx)
            + " "
            + need_time
        )
        # train the same data
        assert idx != 0  # timestamp 0 has no history to train on
        # gather every earlier timestamp's data as the training set
        historical_x, historical_y = [], []
        for past_i in range(idx):
            historical_x.append(env_info["{:}-x".format(past_i)])
            historical_y.append(env_info["{:}-y".format(past_i)])
        historical_x, historical_y = torch.cat(historical_x), torch.cat(historical_y)
        # cap the training-set size (random sampling with replacement)
        historical_x, historical_y = subsample(historical_x, historical_y)
        # build model
        model = get_model(dict(model_type="simple_mlp"), **model_kwargs)
        # build optimizer
        optimizer = torch.optim.Adam(model.parameters(), lr=args.init_lr, amsgrad=True)
        criterion = torch.nn.MSELoss()
        lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(
            optimizer,
            milestones=[
                int(args.epochs * 0.25),
                int(args.epochs * 0.5),
                int(args.epochs * 0.75),
            ],
            gamma=0.3,
        )
        train_metric = MSEMetric()
        best_loss, best_param = None, None
        # full-batch training; keep the weights with the lowest training loss
        for _iepoch in range(args.epochs):
            preds = model(historical_x)
            optimizer.zero_grad()
            loss = criterion(preds, historical_y)
            loss.backward()
            optimizer.step()
            lr_scheduler.step()
            # save best
            if best_loss is None or best_loss > loss.item():
                best_loss = loss.item()
                best_param = copy.deepcopy(model.state_dict())
        model.load_state_dict(best_param)
        with torch.no_grad():
            # NOTE(review): `preds` is from the LAST epoch, not from the
            # reloaded best weights -- confirm this is intentional.
            train_metric(preds, historical_y)
        train_results = train_metric.get_info()

        # evaluate on the current timestamp's data
        metric = ComposeMetric(MSEMetric(), SaveMetric())
        eval_dataset = torch.utils.data.TensorDataset(
            env_info["{:}-x".format(idx)], env_info["{:}-y".format(idx)]
        )
        eval_loader = torch.utils.data.DataLoader(
            eval_dataset, batch_size=args.batch_size, shuffle=False, num_workers=0
        )
        results = basic_eval_fn(eval_loader, model, metric, logger)
        log_str = (
            "[{:}]".format(time_string())
            + " [{:04d}/{:04d}]".format(idx, env_info["total"])
            + " train-mse: {:.5f}, eval-mse: {:.5f}".format(
                train_results["mse"], results["mse"]
            )
        )
        logger.log(log_str)

        save_path = logger.path(None) / "{:04d}-{:04d}.pth".format(
            idx, env_info["total"]
        )
        w_container_per_epoch[idx] = model.get_w_container().no_grad_clone()
        save_checkpoint(
            {
                "model_state_dict": model.state_dict(),
                "model": model,
                "index": idx,
                "timestamp": env_info["{:}-timestamp".format(idx)],
            },
            save_path,
            logger,
        )
        logger.log("")
        per_timestamp_time.update(time.time() - start_time)
        start_time = time.time()

    # one summary checkpoint with every per-timestamp weight container
    save_checkpoint(
        {"w_container_per_epoch": w_container_per_epoch},
        logger.path(None) / "final-ckp.pth",
        logger,
    )
    logger.log("-" * 200 + "\n")
    logger.close()
if __name__ == "__main__":
    parser = argparse.ArgumentParser("Use all the past data to train.")
    parser.add_argument("--save_dir", type=str, default="./outputs/lfna-synthetic/use-all-past-data", help="The checkpoint directory.")
    parser.add_argument("--env_version", type=str, required=True, help="The synthetic enviornment version.")
    parser.add_argument("--hidden_dim", type=int, required=True, help="The hidden dimension.")
    parser.add_argument("--init_lr", type=float, default=0.1, help="The initial learning rate for the optimizer (default is Adam)")
    parser.add_argument("--batch_size", type=int, default=512, help="The batch size")
    parser.add_argument("--epochs", type=int, default=1000, help="The total number of epochs.")
    parser.add_argument("--srange", type=str, required=True, help="The range of models to be evaluated")
    parser.add_argument("--workers", type=int, default=4, help="The number of data loading workers (default: 4)")
    # Random Seed
    parser.add_argument("--rand_seed", type=int, default=-1, help="manual seed")
    args = parser.parse_args()
    # pick a random seed when none (or a negative one) was supplied
    if args.rand_seed is None or args.rand_seed < 0:
        args.rand_seed = random.randint(1, 100000)
    assert args.save_dir is not None, "The save dir argument can not be None"
    args.save_dir = "{:}-{:}-d{:}".format(
        args.save_dir, args.env_version, args.hidden_dim
    )
    main(args)

View File

@@ -0,0 +1,207 @@
#####################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2021.04 #
#####################################################
# python exps/GeMOSA/basic-prev.py --env_version v1 --prev_time 5 --hidden_dim 16 --epochs 500 --init_lr 0.1
# python exps/GeMOSA/basic-prev.py --env_version v2 --hidden_dim 16 --epochs 1000 --init_lr 0.05
#####################################################
import sys, time, copy, torch, random, argparse
from tqdm import tqdm
from copy import deepcopy
from pathlib import Path
lib_dir = (Path(__file__).parent / ".." / "..").resolve()
if str(lib_dir) not in sys.path:
sys.path.insert(0, str(lib_dir))
from xautodl.procedures import (
prepare_seed,
prepare_logger,
save_checkpoint,
copy_checkpoint,
)
from xautodl.log_utils import time_string
from xautodl.log_utils import AverageMeter, convert_secs2time
from xautodl.utils import split_str2indexes
from xautodl.procedures.advanced_main import basic_train_fn, basic_eval_fn
from xautodl.procedures.metric_utils import SaveMetric, MSEMetric, ComposeMetric
from xautodl.datasets.synthetic_core import get_synthetic_env
from xautodl.models.xcore import get_model
from lfna_utils import lfna_setup
def subsample(historical_x, historical_y, maxn=10000):
    # Keep everything when it already fits; otherwise draw `maxn` row
    # indices uniformly (with replacement) and gather both tensors.
    if historical_x.size(0) <= maxn:
        return historical_x, historical_y
    chosen = torch.randint(low=0, high=historical_x.size(0), size=[maxn])
    return historical_x[chosen], historical_y[chosen]
def main(args):
    """Train one MLP per timestamp using only the data from `prev_time`
    steps earlier, then evaluate it on the current timestamp.

    Writes a checkpoint per timestamp and a final summary checkpoint of
    all weight containers.
    """
    # FIX: the body reads `env_info` (e.g. env_info["total"]) but the old
    # two-way unpack never defined it -> guaranteed NameError.  The sibling
    # scripts unpack three values from lfna_setup, so do the same here.
    logger, env_info, model_kwargs = lfna_setup(args)

    w_containers = dict()

    per_timestamp_time, start_time = AverageMeter(), time.time()
    # the first `prev_time` timestamps have no data that far back
    for idx in range(args.prev_time, env_info["total"]):
        need_time = "Time Left: {:}".format(
            convert_secs2time(per_timestamp_time.avg * (env_info["total"] - idx), True)
        )
        logger.log(
            "[{:}]".format(time_string())
            + " [{:04d}/{:04d}]".format(idx, env_info["total"])
            + " "
            + need_time
        )
        # train the same data
        historical_x = env_info["{:}-x".format(idx - args.prev_time)]
        historical_y = env_info["{:}-y".format(idx - args.prev_time)]
        # build model
        model = get_model(**model_kwargs)
        print(model)
        # build optimizer
        optimizer = torch.optim.Adam(model.parameters(), lr=args.init_lr, amsgrad=True)
        criterion = torch.nn.MSELoss()
        lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(
            optimizer,
            milestones=[
                int(args.epochs * 0.25),
                int(args.epochs * 0.5),
                int(args.epochs * 0.75),
            ],
            gamma=0.3,
        )
        train_metric = MSEMetric()
        best_loss, best_param = None, None
        # full-batch training; keep the weights with the lowest training loss
        for _iepoch in range(args.epochs):
            preds = model(historical_x)
            optimizer.zero_grad()
            loss = criterion(preds, historical_y)
            loss.backward()
            optimizer.step()
            lr_scheduler.step()
            # save best
            if best_loss is None or best_loss > loss.item():
                best_loss = loss.item()
                best_param = copy.deepcopy(model.state_dict())
        model.load_state_dict(best_param)
        model.analyze_weights()
        with torch.no_grad():
            # NOTE(review): `preds` is from the LAST epoch, not from the
            # reloaded best weights -- confirm this is intentional.
            train_metric(preds, historical_y)
        train_results = train_metric.get_info()

        # evaluate on the current timestamp's data
        metric = ComposeMetric(MSEMetric(), SaveMetric())
        eval_dataset = torch.utils.data.TensorDataset(
            env_info["{:}-x".format(idx)], env_info["{:}-y".format(idx)]
        )
        eval_loader = torch.utils.data.DataLoader(
            eval_dataset, batch_size=args.batch_size, shuffle=False, num_workers=0
        )
        results = basic_eval_fn(eval_loader, model, metric, logger)
        log_str = (
            "[{:}]".format(time_string())
            + " [{:04d}/{:04d}]".format(idx, env_info["total"])
            + " train-mse: {:.5f}, eval-mse: {:.5f}".format(
                train_results["mse"], results["mse"]
            )
        )
        logger.log(log_str)

        save_path = logger.path(None) / "{:04d}-{:04d}.pth".format(
            idx, env_info["total"]
        )
        w_containers[idx] = model.get_w_container().no_grad_clone()
        save_checkpoint(
            {
                "model_state_dict": model.state_dict(),
                "model": model,
                "index": idx,
                "timestamp": env_info["{:}-timestamp".format(idx)],
            },
            save_path,
            logger,
        )
        logger.log("")
        per_timestamp_time.update(time.time() - start_time)
        start_time = time.time()

    # one summary checkpoint with every per-timestamp weight container
    save_checkpoint(
        {"w_containers": w_containers},
        logger.path(None) / "final-ckp.pth",
        logger,
    )
    logger.log("-" * 200 + "\n")
    logger.close()
if __name__ == "__main__":
    parser = argparse.ArgumentParser("Use the data in the last timestamp.")
    parser.add_argument("--save_dir", type=str, default="./outputs/lfna-synthetic/use-prev-timestamp", help="The checkpoint directory.")
    parser.add_argument("--env_version", type=str, required=True, help="The synthetic enviornment version.")
    parser.add_argument("--hidden_dim", type=int, required=True, help="The hidden dimension.")
    parser.add_argument("--init_lr", type=float, default=0.1, help="The initial learning rate for the optimizer (default is Adam)")
    parser.add_argument("--prev_time", type=int, default=5, help="The gap between prev_time and current_timestamp")
    parser.add_argument("--batch_size", type=int, default=512, help="The batch size")
    parser.add_argument("--epochs", type=int, default=300, help="The total number of epochs.")
    parser.add_argument("--workers", type=int, default=4, help="The number of data loading workers (default: 4)")
    # Random Seed
    parser.add_argument("--rand_seed", type=int, default=-1, help="manual seed")
    args = parser.parse_args()
    # pick a random seed when none (or a negative one) was supplied
    if args.rand_seed is None or args.rand_seed < 0:
        args.rand_seed = random.randint(1, 100000)
    assert args.save_dir is not None, "The save dir argument can not be None"
    # encode the key hyper-parameters into the save directory name
    args.save_dir = "{:}-d{:}_e{:}_lr{:}-prev{:}-env{:}".format(
        args.save_dir,
        args.hidden_dim,
        args.epochs,
        args.init_lr,
        args.prev_time,
        args.env_version,
    )
    main(args)

View File

@@ -0,0 +1,228 @@
#####################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2021.04 #
#####################################################
# python exps/GeMOSA/basic-same.py --env_version v1 --hidden_dim 16 --epochs 500 --init_lr 0.1 --device cuda
# python exps/GeMOSA/basic-same.py --env_version v2 --hidden_dim 16 --epochs 500 --init_lr 0.1 --device cuda
# python exps/GeMOSA/basic-same.py --env_version v3 --hidden_dim 32 --epochs 1000 --init_lr 0.05 --device cuda
# python exps/GeMOSA/basic-same.py --env_version v4 --hidden_dim 32 --epochs 1000 --init_lr 0.05 --device cuda
#####################################################
import sys, time, copy, torch, random, argparse
from tqdm import tqdm
from copy import deepcopy
from pathlib import Path
lib_dir = (Path(__file__).parent / ".." / "..").resolve()
print("LIB-DIR: {:}".format(lib_dir))
if str(lib_dir) not in sys.path:
sys.path.insert(0, str(lib_dir))
from xautodl.procedures import (
prepare_seed,
prepare_logger,
save_checkpoint,
copy_checkpoint,
)
from xautodl.log_utils import time_string
from xautodl.log_utils import AverageMeter, convert_secs2time
from xautodl.utils import split_str2indexes
from xautodl.procedures.metric_utils import (
SaveMetric,
MSEMetric,
Top1AccMetric,
ComposeMetric,
)
from xautodl.datasets.synthetic_core import get_synthetic_env
from xautodl.models.xcore import get_model
def subsample(historical_x, historical_y, maxn=10000):
    """Return (x, y) unchanged when small enough, else `maxn` rows sampled
    uniformly with replacement."""
    too_large = historical_x.size(0) > maxn
    if not too_large:
        return historical_x, historical_y
    idxs = torch.randint(low=0, high=historical_x.size(0), size=[maxn])
    return historical_x[idxs], historical_y[idxs]
def main(args):
    """Train one model per timestamp on that SAME timestamp's data, then
    evaluate it on the same data -- the 'oracle' baseline.

    Writes a checkpoint per timestamp and a final summary checkpoint.
    """
    # FIX: `basic_eval_fn` is used below but this script never imported it
    # at module level (unlike its sibling scripts) -> NameError at runtime.
    from xautodl.procedures.advanced_main import basic_eval_fn

    prepare_seed(args.rand_seed)
    logger = prepare_logger(args)
    env = get_synthetic_env(mode=None, version=args.env_version)
    model_kwargs = dict(
        config=dict(model_type="norm_mlp"),
        input_dim=env.meta_info["input_dim"],
        output_dim=env.meta_info["output_dim"],
        hidden_dims=[args.hidden_dim] * 2,
        act_cls="relu",
        norm_cls="layer_norm_1d",
    )
    logger.log("The total enviornment: {:}".format(env))
    w_containers = dict()

    # pick the loss / metric pair matching the task type
    if env.meta_info["task"] == "regression":
        criterion = torch.nn.MSELoss()
        metric_cls = MSEMetric
    elif env.meta_info["task"] == "classification":
        criterion = torch.nn.CrossEntropyLoss()
        metric_cls = Top1AccMetric
    else:
        # FIX: this script defines `env`, not `all_env` -- the old message
        # raised a NameError instead of the intended ValueError.
        raise ValueError(
            "This task ({:}) is not supported.".format(env.meta_info["task"])
        )

    per_timestamp_time, start_time = AverageMeter(), time.time()
    for idx, (future_time, (future_x, future_y)) in enumerate(env):
        need_time = "Time Left: {:}".format(
            convert_secs2time(per_timestamp_time.avg * (len(env) - idx), True)
        )
        logger.log(
            "[{:}]".format(time_string())
            + " [{:04d}/{:04d}]".format(idx, len(env))
            + " "
            + need_time
        )
        # train the same data
        historical_x = future_x.to(args.device)
        historical_y = future_y.to(args.device)
        # build model
        model = get_model(**model_kwargs)
        model = model.to(args.device)
        if idx == 0:
            print(model)
        # build optimizer
        optimizer = torch.optim.Adam(model.parameters(), lr=args.init_lr, amsgrad=True)
        lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(
            optimizer,
            milestones=[
                int(args.epochs * 0.25),
                int(args.epochs * 0.5),
                int(args.epochs * 0.75),
            ],
            gamma=0.3,
        )
        train_metric = metric_cls(True)
        best_loss, best_param = None, None
        # full-batch training; keep the weights with the lowest training loss
        for _iepoch in range(args.epochs):
            preds = model(historical_x)
            optimizer.zero_grad()
            loss = criterion(preds, historical_y)
            loss.backward()
            optimizer.step()
            lr_scheduler.step()
            # save best
            if best_loss is None or best_loss > loss.item():
                best_loss = loss.item()
                best_param = copy.deepcopy(model.state_dict())
        model.load_state_dict(best_param)
        model.analyze_weights()
        with torch.no_grad():
            # NOTE(review): `preds` is from the LAST epoch, not from the
            # reloaded best weights -- confirm this is intentional.
            train_metric(preds, historical_y)
        train_results = train_metric.get_info()

        # evaluate on the same timestamp's data
        xmetric = ComposeMetric(metric_cls(True), SaveMetric())
        eval_dataset = torch.utils.data.TensorDataset(
            future_x.to(args.device), future_y.to(args.device)
        )
        eval_loader = torch.utils.data.DataLoader(
            eval_dataset, batch_size=args.batch_size, shuffle=False, num_workers=0
        )
        results = basic_eval_fn(eval_loader, model, xmetric, logger)
        log_str = (
            "[{:}]".format(time_string())
            + " [{:04d}/{:04d}]".format(idx, len(env))
            + " train-score: {:.5f}, eval-score: {:.5f}".format(
                train_results["score"], results["score"]
            )
        )
        logger.log(log_str)

        save_path = logger.path(None) / "{:04d}-{:04d}.pth".format(idx, len(env))
        w_containers[idx] = model.get_w_container().no_grad_clone()
        save_checkpoint(
            {
                "model_state_dict": model.state_dict(),
                "model": model,
                "index": idx,
                "timestamp": future_time.item(),
            },
            save_path,
            logger,
        )
        logger.log("")
        per_timestamp_time.update(time.time() - start_time)
        start_time = time.time()

    # one summary checkpoint with every per-timestamp weight container
    save_checkpoint(
        {"w_containers": w_containers},
        logger.path(None) / "final-ckp.pth",
        logger,
    )
    logger.log("-" * 200 + "\n")
    logger.close()
if __name__ == "__main__":
    parser = argparse.ArgumentParser("Use the data in the past.")
    parser.add_argument("--save_dir", type=str, default="./outputs/GeMOSA-synthetic/use-same-timestamp", help="The checkpoint directory.")
    parser.add_argument("--env_version", type=str, required=True, help="The synthetic enviornment version.")
    parser.add_argument("--hidden_dim", type=int, required=True, help="The hidden dimension.")
    parser.add_argument("--init_lr", type=float, default=0.1, help="The initial learning rate for the optimizer (default is Adam)")
    parser.add_argument("--batch_size", type=int, default=512, help="The batch size")
    parser.add_argument("--epochs", type=int, default=300, help="The total number of epochs.")
    parser.add_argument("--device", type=str, default="cpu", help="")
    parser.add_argument("--workers", type=int, default=4, help="The number of data loading workers (default: 4)")
    # Random Seed
    parser.add_argument("--rand_seed", type=int, default=-1, help="manual seed")
    args = parser.parse_args()
    # pick a random seed when none (or a negative one) was supplied
    if args.rand_seed is None or args.rand_seed < 0:
        args.rand_seed = random.randint(1, 100000)
    assert args.save_dir is not None, "The save dir argument can not be None"
    # encode the key hyper-parameters into the save directory name
    args.save_dir = "{:}-d{:}_e{:}_lr{:}-env{:}".format(
        args.save_dir, args.hidden_dim, args.epochs, args.init_lr, args.env_version
    )
    main(args)

View File

@@ -0,0 +1,438 @@
##########################################################
# Learning to Efficiently Generate Models One Step Ahead #
##########################################################
# <----> run on CPU
# python exps/GeMOSA/main.py --env_version v1 --workers 0
# <----> run on a GPU
# python exps/GeMOSA/main.py --env_version v1 --lr 0.002 --hidden_dim 16 --meta_batch 256 --device cuda
# python exps/GeMOSA/main.py --env_version v2 --lr 0.002 --hidden_dim 16 --meta_batch 256 --device cuda
# python exps/GeMOSA/main.py --env_version v3 --lr 0.002 --hidden_dim 32 --time_dim 32 --meta_batch 256 --device cuda
# python exps/GeMOSA/main.py --env_version v4 --lr 0.002 --hidden_dim 32 --time_dim 32 --meta_batch 256 --device cuda
# <----> ablation commands
# python exps/GeMOSA/main.py --env_version v1 --lr 0.002 --hidden_dim 16 --meta_batch 256 --ablation old --device cuda
# python exps/GeMOSA/main.py --env_version v2 --lr 0.002 --hidden_dim 16 --meta_batch 256 --ablation old --device cuda
# python exps/GeMOSA/main.py --env_version v3 --lr 0.002 --hidden_dim 32 --time_dim 32 --meta_batch 256 --ablation old --device cuda
# python exps/GeMOSA/main.py --env_version v4 --lr 0.002 --hidden_dim 32 --time_dim 32 --meta_batch 256 --ablation old --device cuda
##########################################################
import sys, time, copy, torch, random, argparse
from tqdm import tqdm
from copy import deepcopy
from pathlib import Path
from torch.nn import functional as F
lib_dir = (Path(__file__).parent / ".." / "..").resolve()
print("LIB-DIR: {:}".format(lib_dir))
if str(lib_dir) not in sys.path:
sys.path.insert(0, str(lib_dir))
from xautodl.procedures import (
prepare_seed,
prepare_logger,
save_checkpoint,
copy_checkpoint,
)
from xautodl.log_utils import time_string
from xautodl.log_utils import AverageMeter, convert_secs2time
from xautodl.utils import split_str2indexes
from xautodl.procedures.advanced_main import basic_train_fn, basic_eval_fn
from xautodl.procedures.metric_utils import SaveMetric, MSEMetric, ComposeMetric
from xautodl.datasets.synthetic_core import get_synthetic_env
from xautodl.models.xcore import get_model
from xautodl.procedures.metric_utils import MSEMetric, Top1AccMetric
from meta_model import MetaModelV1
from meta_model_ablation import MetaModel_TraditionalAtt
def online_evaluate(
    env,
    meta_model,
    base_model,
    criterion,
    metric,
    args,
    logger,
    save=False,
    easy_adapt=False,
):
    """Sequentially evaluate the meta-model over every timestamp of `env`.

    For each timestamp: generate a weight container for the base model from
    the timestamp embedding, score the prediction, then let the meta-model
    adapt to the just-seen data (either `easy_adapt`, which only stores the
    embedding, or the refine-based `adapt`).  Returns
    (w_containers, mean loss, metric score); `w_containers` is only
    populated when `save` is True.
    """
    logger.log("Online evaluate: {:}".format(env))
    metric.reset()
    loss_meter = AverageMeter()
    w_containers = dict()
    for idx, (future_time, (future_x, future_y)) in enumerate(env):
        with torch.no_grad():
            # pure inference: embed the timestamp, generate weights, predict
            meta_model.eval()
            base_model.eval()
            future_time_embed = meta_model.gen_time_embed(
                future_time.to(args.device).view(-1)
            )
            [future_container] = meta_model.gen_model(future_time_embed)
            if save:
                w_containers[idx] = future_container.no_grad_clone()
            future_x, future_y = future_x.to(args.device), future_y.to(args.device)
            future_y_hat = base_model.forward_with_container(future_x, future_container)
            future_loss = criterion(future_y_hat, future_y)
            loss_meter.update(future_loss.item())
            # accumulate the metric scores
            score = metric(future_y_hat, future_y)
        if easy_adapt:
            # just remember the generated embedding for this timestamp
            meta_model.easy_adapt(future_time.item(), future_time_embed)
            refine, post_refine_loss = False, -1
        else:
            # gradient-based refinement of the embedding on the observed data
            refine, post_refine_loss = meta_model.adapt(
                base_model,
                criterion,
                future_time.item(),
                future_x,
                future_y,
                args.refine_lr,
                args.refine_epochs,
                {"param": future_time_embed, "loss": future_loss.item()},
            )
        logger.log(
            "[ONLINE] [{:03d}/{:03d}] loss={:.4f}, score={:.4f}".format(
                idx, len(env), future_loss.item(), score
            )
            + ", post-loss={:.4f}".format(post_refine_loss if refine else -1)
        )
    # drop all per-evaluation state so the meta-model is reusable
    meta_model.clear_fixed()
    meta_model.clear_learnt()
    return w_containers, loss_meter.avg, metric.get_info()["score"]
def meta_train_procedure(base_model, meta_model, criterion, xenv, args, logger):
    """Pre-train the meta-model (hypernetwork) on the `xenv` timestamps.

    Each epoch samples a batch of timestamps and minimizes:
    future loss (weights generated from re-encoded time embeddings) +
    present loss (weights generated from the learned embeddings) +
    an L1 regularizer tying the two embedding sets together.
    Uses the meta-model's own best-checkpoint bookkeeping and early stops
    after `args.pretrain_early_stop_thresh` epochs without improvement.
    """
    base_model.train()
    meta_model.train()
    optimizer = torch.optim.Adam(
        meta_model.get_parameters(True, True, True),
        lr=args.lr,
        weight_decay=args.weight_decay,
        amsgrad=True,
    )
    logger.log("Pre-train the meta-model")
    logger.log("Using the optimizer: {:}".format(optimizer))

    meta_model.set_best_dir(logger.path(None) / "ckps-pretrain-v2")
    final_best_name = "final-pretrain-{:}.pth".format(args.rand_seed)
    if meta_model.has_best(final_best_name):
        # a finished pre-training run exists for this seed: reuse it
        meta_model.load_best(final_best_name)
        logger.log("Directly load the best model from {:}".format(final_best_name))
        return

    total_indexes = list(range(meta_model.meta_length))
    meta_model.set_best_name("pretrain-{:}.pth".format(args.rand_seed))
    last_success_epoch, early_stop_thresh = 0, args.pretrain_early_stop_thresh
    per_epoch_time, start_time = AverageMeter(), time.time()
    device = args.device
    for iepoch in range(args.epochs):
        left_time = "Time Left: {:}".format(
            convert_secs2time(per_epoch_time.avg * (args.epochs - iepoch), True)
        )
        optimizer.zero_grad()

        # re-encode every known timestamp, then sample a training batch
        generated_time_embeds = meta_model.gen_time_embed(meta_model.meta_timestamps)

        batch_indexes = random.choices(total_indexes, k=args.meta_batch)

        raw_time_steps = meta_model.meta_timestamps[batch_indexes]

        # keep the generated embeddings close to the learned ones
        regularization_loss = F.l1_loss(
            generated_time_embeds, meta_model.super_meta_embed, reduction="mean"
        )
        # future loss
        total_future_losses, total_present_losses = [], []
        future_containers = meta_model.gen_model(generated_time_embeds[batch_indexes])
        present_containers = meta_model.gen_model(
            meta_model.super_meta_embed[batch_indexes]
        )
        for ibatch, time_step in enumerate(raw_time_steps.cpu().tolist()):
            _, (inputs, targets) = xenv(time_step)
            inputs, targets = inputs.to(device), targets.to(device)

            predictions = base_model.forward_with_container(
                inputs, future_containers[ibatch]
            )
            total_future_losses.append(criterion(predictions, targets))

            predictions = base_model.forward_with_container(
                inputs, present_containers[ibatch]
            )
            total_present_losses.append(criterion(predictions, targets))

        with torch.no_grad():
            meta_std = torch.stack(total_future_losses).std().item()
        loss_future = torch.stack(total_future_losses).mean()
        loss_present = torch.stack(total_present_losses).mean()
        total_loss = loss_future + loss_present + regularization_loss
        total_loss.backward()
        optimizer.step()
        # success
        # save_best keeps the largest score, so the negated loss is passed
        success, best_score = meta_model.save_best(-total_loss.item())
        logger.log(
            "{:} [META {:04d}/{:}] loss : {:.4f} +- {:.4f} = {:.4f} + {:.4f} + {:.4f}".format(
                time_string(),
                iepoch,
                args.epochs,
                total_loss.item(),
                meta_std,
                loss_future.item(),
                loss_present.item(),
                regularization_loss.item(),
            )
            + ", batch={:}".format(len(total_future_losses))
            + ", success={:}, best={:.4f}".format(success, -best_score)
            + ", LS={:}/{:}".format(iepoch - last_success_epoch, early_stop_thresh)
            + ", {:}".format(left_time)
        )
        if success:
            last_success_epoch = iepoch
        if iepoch - last_success_epoch >= early_stop_thresh:
            logger.log("Early stop the pre-training at {:}".format(iepoch))
            break
        per_epoch_time.update(time.time() - start_time)
        start_time = time.time()
    meta_model.load_best()
    # save to the final model
    # NOTE(review): `best_score` is only bound inside the loop -- with
    # args.epochs == 0 this would raise NameError; confirm epochs >= 1.
    meta_model.set_best_name(final_best_name)
    success, _ = meta_model.save_best(best_score + 1e-6)
    assert success
    logger.log("Save the best model into {:}".format(final_best_name))
def main(args):
    """Full GeMOSA pipeline: build the base model and meta-model, pre-train
    the meta-model on trainval, then evaluate online adaptation on the test
    split (both refine-based and easy adaptation) and checkpoint the results.
    """
    prepare_seed(args.rand_seed)
    logger = prepare_logger(args)
    # the env splits share one underlying synthetic sequence
    train_env = get_synthetic_env(mode="train", version=args.env_version)
    valid_env = get_synthetic_env(mode="valid", version=args.env_version)
    trainval_env = get_synthetic_env(mode="trainval", version=args.env_version)
    test_env = get_synthetic_env(mode="test", version=args.env_version)
    all_env = get_synthetic_env(mode=None, version=args.env_version)
    logger.log("The training enviornment: {:}".format(train_env))
    logger.log("The validation enviornment: {:}".format(valid_env))
    logger.log("The trainval enviornment: {:}".format(trainval_env))
    logger.log("The total enviornment: {:}".format(all_env))
    logger.log("The test enviornment: {:}".format(test_env))
    model_kwargs = dict(
        config=dict(model_type="norm_mlp"),
        input_dim=all_env.meta_info["input_dim"],
        output_dim=all_env.meta_info["output_dim"],
        hidden_dims=[args.hidden_dim] * 2,
        act_cls="relu",
        norm_cls="layer_norm_1d",
    )

    base_model = get_model(**model_kwargs)
    base_model = base_model.to(args.device)
    # pick the loss / metric pair matching the task type
    if all_env.meta_info["task"] == "regression":
        criterion = torch.nn.MSELoss()
        metric = MSEMetric(True)
    elif all_env.meta_info["task"] == "classification":
        criterion = torch.nn.CrossEntropyLoss()
        metric = Top1AccMetric(True)
    else:
        raise ValueError(
            "This task ({:}) is not supported.".format(all_env.meta_info["task"])
        )

    shape_container = base_model.get_w_container().to_shape_container()

    # pre-train the hypernetwork
    timestamps = trainval_env.get_timestamp(None)
    # choose the meta-model variant (default vs. traditional-attention ablation)
    if args.ablation is None:
        MetaModel_cls = MetaModelV1
    elif args.ablation == "old":
        MetaModel_cls = MetaModel_TraditionalAtt
    else:
        raise ValueError("Unknown ablation : {:}".format(args.ablation))
    meta_model = MetaModel_cls(
        shape_container,
        args.layer_dim,
        args.time_dim,
        timestamps,
        seq_length=args.seq_length,
        interval=trainval_env.time_interval,
    )
    meta_model = meta_model.to(args.device)

    logger.log("The base-model has {:} weights.".format(base_model.numel()))
    logger.log("The meta-model has {:} weights.".format(meta_model.numel()))
    logger.log("The base-model is\n{:}".format(base_model))
    logger.log("The meta-model is\n{:}".format(meta_model))

    meta_train_procedure(base_model, meta_model, criterion, trainval_env, args, logger)

    # try to evaluate once
    # online_evaluate(train_env, meta_model, base_model, criterion, args, logger)
    # online_evaluate(valid_env, meta_model, base_model, criterion, args, logger)
    """
    w_containers, loss_meter = online_evaluate(
        all_env, meta_model, base_model, criterion, args, logger, True
    )
    logger.log("In this enviornment, the total loss-meter is {:}".format(loss_meter))
    """
    # evaluate twice on the test split: with refinement, then easy-adapt only
    w_containers_care_adapt, loss_adapt_v1, metric_adapt_v1 = online_evaluate(
        test_env, meta_model, base_model, criterion, metric, args, logger, True, False
    )
    w_containers_easy_adapt, loss_adapt_v2, metric_adapt_v2 = online_evaluate(
        test_env, meta_model, base_model, criterion, metric, args, logger, True, True
    )
    logger.log(
        "[Refine-Adapt] loss = {:.6f}, metric = {:.6f}".format(
            loss_adapt_v1, metric_adapt_v1
        )
    )
    logger.log(
        "[Easy-Adapt] loss = {:.6f}, metric = {:.6f}".format(
            loss_adapt_v2, metric_adapt_v2
        )
    )

    save_checkpoint(
        {
            "w_containers_care_adapt": w_containers_care_adapt,
            "w_containers_easy_adapt": w_containers_easy_adapt,
            "test_loss_adapt_v1": loss_adapt_v1,
            "test_loss_adapt_v2": loss_adapt_v2,
            "test_metric_adapt_v1": metric_adapt_v1,
            "test_metric_adapt_v2": metric_adapt_v2,
        },
        logger.path(None) / "final-ckp-{:}.pth".format(args.rand_seed),
        logger,
    )

    logger.log("-" * 200 + "\n")
    logger.close()
if __name__ == "__main__":
    parser = argparse.ArgumentParser(".")
    parser.add_argument("--save_dir", type=str, default="./outputs/GeMOSA-synthetic/GeMOSA", help="The checkpoint directory.")
    parser.add_argument("--env_version", type=str, required=True, help="The synthetic enviornment version.")
    parser.add_argument("--hidden_dim", type=int, default=16, help="The hidden dimension.")
    parser.add_argument("--layer_dim", type=int, default=16, help="The layer chunk dimension.")
    parser.add_argument("--time_dim", type=int, default=16, help="The timestamp dimension.")
    #####
    parser.add_argument("--lr", type=float, default=0.002, help="The initial learning rate for the optimizer (default is Adam)")
    parser.add_argument("--weight_decay", type=float, default=0.00001, help="The weight decay for the optimizer (default is Adam)")
    parser.add_argument("--meta_batch", type=int, default=64, help="The batch size for the meta-model")
    parser.add_argument("--sampler_enlarge", type=int, default=5, help="Enlarge the #iterations for an epoch")
    parser.add_argument("--epochs", type=int, default=10000, help="The total #epochs.")
    parser.add_argument("--refine_lr", type=float, default=0.001, help="The learning rate for the optimizer, during refine")
    parser.add_argument("--refine_epochs", type=int, default=150, help="The final refine #epochs.")
    parser.add_argument("--early_stop_thresh", type=int, default=20, help="The #epochs for early stop.")
    parser.add_argument("--pretrain_early_stop_thresh", type=int, default=300, help="The #epochs for early stop.")
    parser.add_argument("--seq_length", type=int, default=10, help="The sequence length.")
    parser.add_argument("--workers", type=int, default=4, help="The number of workers in parallel.")
    parser.add_argument("--ablation", type=str, default=None, help="The ablation indicator.")
    parser.add_argument("--device", type=str, default="cpu", help="")
    # Random Seed
    parser.add_argument("--rand_seed", type=int, default=-1, help="manual seed")
    args = parser.parse_args()
    # pick a random seed when none (or a negative one) was supplied
    if args.rand_seed is None or args.rand_seed < 0:
        args.rand_seed = random.randint(1, 100000)
    assert args.save_dir is not None, "The save dir argument can not be None"
    # build the save-dir name from a shared prefix; the ablation tag is
    # spliced in between the epoch and environment fields when present
    name_prefix = "{:}-bs{:}-d{:}_{:}_{:}-s{:}-lr{:}-wd{:}-e{:}".format(
        args.save_dir,
        args.meta_batch,
        args.hidden_dim,
        args.layer_dim,
        args.time_dim,
        args.seq_length,
        args.lr,
        args.weight_decay,
        args.epochs,
    )
    if args.ablation is None:
        args.save_dir = "{:}-env{:}".format(name_prefix, args.env_version)
    else:
        args.save_dir = "{:}-ab{:}-env{:}".format(
            name_prefix, args.ablation, args.env_version
        )
    main(args)

View File

@@ -0,0 +1,257 @@
import torch
import torch.nn.functional as F
from xautodl.xlayers import super_core
from xautodl.xlayers import trunc_normal_
from xautodl.models.xcore import get_model
class MetaModelV1(super_core.SuperModule):
"""Learning to Generate Models One Step Ahead (Meta Model Design)."""
def __init__(
    self,
    shape_container,
    layer_dim,
    time_dim,
    meta_timestamps,
    dropout: float = 0.1,
    seq_length: int = None,
    interval: float = None,
    thresh: float = None,
):
    """Build the meta-model.

    Args:
        shape_container: per-layer weight shapes of the base model; one
            generated chunk per layer.
        layer_dim: dimension of the learnable per-layer embedding.
        time_dim: dimension of the timestamp embedding.
        meta_timestamps: the known (training) timestamps.
        dropout: dropout rate used in the attention projection and generator.
        seq_length: number of timestamps attended over; also sets the
            default `thresh` (= interval * seq_length).
        interval: time gap between consecutive timestamps (required).
        thresh: maximum time distance considered "close"; defaults to
            interval * seq_length.
    """
    super(MetaModelV1, self).__init__()
    self._shape_container = shape_container
    self._num_layers = len(shape_container)
    # number of scalar weights each base-model layer needs
    self._numel_per_layer = []
    for ilayer in range(self._num_layers):
        self._numel_per_layer.append(shape_container[ilayer].numel())
    self._raw_meta_timestamps = meta_timestamps
    assert interval is not None
    self._interval = interval
    self._thresh = interval * seq_length if thresh is None else thresh

    # learnable embeddings: one per base-model layer, one per timestamp
    self.register_parameter(
        "_super_layer_embed",
        torch.nn.Parameter(torch.Tensor(self._num_layers, layer_dim)),
    )
    self.register_parameter(
        "_super_meta_embed",
        torch.nn.Parameter(torch.Tensor(len(meta_timestamps), time_dim)),
    )
    self.register_buffer("_meta_timestamps", torch.Tensor(meta_timestamps))
    self._time_embed_dim = time_dim
    # timestamps/embeddings appended at evaluation time; "fixed" entries are
    # frozen, "learnt" ones are optimized during adaptation
    self._append_meta_embed = dict(fixed=None, learnt=None)
    self._append_meta_timestamps = dict(fixed=None, learnt=None)

    # sinusoidal-style positional encoding over (scaled) time scalars
    self._tscalar_embed = super_core.SuperDynamicPositionE(
        time_dim, scale=1 / interval
    )

    # build transformer
    self._trans_att = super_core.SuperQKVAttentionV2(
        qk_att_dim=time_dim,
        in_v_dim=time_dim,
        hidden_dim=time_dim,
        num_heads=4,
        proj_dim=time_dim,
        qkv_bias=True,
        attn_drop=None,
        proj_drop=dropout,
    )

    # MLP that maps (layer embed | time embed) -> a layer's flat weights;
    # output is sized for the LARGEST layer and presumably sliced per layer
    model_kwargs = dict(
        config=dict(model_type="dual_norm_mlp"),
        input_dim=layer_dim + time_dim,
        output_dim=max(self._numel_per_layer),
        hidden_dims=[(layer_dim + time_dim) * 2] * 3,
        act_cls="gelu",
        norm_cls="layer_norm_1d",
        dropout=dropout,
    )
    self._generator = get_model(**model_kwargs)

    # initialization
    trunc_normal_(
        [self._super_layer_embed, self._super_meta_embed],
        std=0.02,
    )
def get_parameters(self, time_embed, attention, generator):
parameters = []
if time_embed:
parameters.append(self._super_meta_embed)
if attention:
parameters.extend(list(self._trans_att.parameters()))
if generator:
parameters.append(self._super_layer_embed)
parameters.extend(list(self._generator.parameters()))
return parameters
@property
def meta_timestamps(self):
with torch.no_grad():
meta_timestamps = [self._meta_timestamps]
for key in ("fixed", "learnt"):
if self._append_meta_timestamps[key] is not None:
meta_timestamps.append(self._append_meta_timestamps[key])
return torch.cat(meta_timestamps)
@property
def super_meta_embed(self):
meta_embed = [self._super_meta_embed]
for key in ("fixed", "learnt"):
if self._append_meta_embed[key] is not None:
meta_embed.append(self._append_meta_embed[key])
return torch.cat(meta_embed)
def create_meta_embed(self):
param = torch.Tensor(1, self._time_embed_dim)
trunc_normal_(param, std=0.02)
param = param.to(self._super_meta_embed.device)
param = torch.nn.Parameter(param, True)
return param
def get_closest_meta_distance(self, timestamp):
with torch.no_grad():
distances = torch.abs(self.meta_timestamps - timestamp)
return torch.min(distances).item()
def replace_append_learnt(self, timestamp, meta_embed):
self._append_meta_timestamps["learnt"] = timestamp
self._append_meta_embed["learnt"] = meta_embed
@property
def meta_length(self):
return self.meta_timestamps.numel()
def clear_fixed(self):
self._append_meta_timestamps["fixed"] = None
self._append_meta_embed["fixed"] = None
def clear_learnt(self):
self.replace_append_learnt(None, None)
def append_fixed(self, timestamp, meta_embed):
with torch.no_grad():
device = self._super_meta_embed.device
timestamp = timestamp.detach().clone().to(device)
meta_embed = meta_embed.detach().clone().to(device)
if self._append_meta_timestamps["fixed"] is None:
self._append_meta_timestamps["fixed"] = timestamp
else:
self._append_meta_timestamps["fixed"] = torch.cat(
(self._append_meta_timestamps["fixed"], timestamp), dim=0
)
if self._append_meta_embed["fixed"] is None:
self._append_meta_embed["fixed"] = meta_embed
else:
self._append_meta_embed["fixed"] = torch.cat(
(self._append_meta_embed["fixed"], meta_embed), dim=0
)
def gen_time_embed(self, timestamps):
# timestamps is a batch of timestamps
[B] = timestamps.shape
# batch, seq = timestamps.shape
timestamps = timestamps.view(-1, 1)
meta_timestamps, meta_embeds = self.meta_timestamps, self.super_meta_embed
timestamp_v_embed = meta_embeds.unsqueeze(dim=0)
timestamp_qk_att_embed = self._tscalar_embed(
torch.unsqueeze(timestamps, dim=-1) - meta_timestamps
)
# create the mask
mask = (
torch.unsqueeze(timestamps, dim=-1) <= meta_timestamps.view(1, 1, -1)
) | (
torch.abs(
torch.unsqueeze(timestamps, dim=-1) - meta_timestamps.view(1, 1, -1)
)
> self._thresh
)
timestamp_embeds = self._trans_att(
timestamp_qk_att_embed,
timestamp_v_embed,
mask,
)
return timestamp_embeds[:, -1, :]
def gen_model(self, time_embeds):
B, _ = time_embeds.shape
# create joint embed
num_layer, _ = self._super_layer_embed.shape
# The shape of `joint_embed` is batch * num-layers * input-dim
joint_embeds = torch.cat(
(
time_embeds.view(B, 1, -1).expand(-1, num_layer, -1),
self._super_layer_embed.view(1, num_layer, -1).expand(B, -1, -1),
),
dim=-1,
)
batch_weights = self._generator(joint_embeds)
batch_containers = []
for weights in torch.split(batch_weights, 1):
batch_containers.append(
self._shape_container.translate(torch.split(weights.squeeze(0), 1))
)
return batch_containers
def forward_raw(self, timestamps, time_embeds, tembed_only=False):
raise NotImplementedError
def forward_candidate(self, input):
raise NotImplementedError
def easy_adapt(self, timestamp, time_embed):
with torch.no_grad():
timestamp = torch.Tensor([timestamp]).to(self._meta_timestamps.device)
self.replace_append_learnt(None, None)
self.append_fixed(timestamp, time_embed)
def adapt(self, base_model, criterion, timestamp, x, y, lr, epochs, init_info):
distance = self.get_closest_meta_distance(timestamp)
if distance + self._interval * 1e-2 <= self._interval:
return False, None
x, y = x.to(self._meta_timestamps.device), y.to(self._meta_timestamps.device)
with torch.set_grad_enabled(True):
new_param = self.create_meta_embed()
optimizer = torch.optim.Adam(
[new_param], lr=lr, weight_decay=1e-5, amsgrad=True
)
timestamp = torch.Tensor([timestamp]).to(new_param.device)
self.replace_append_learnt(timestamp, new_param)
self.train()
base_model.train()
if init_info is not None:
best_loss = init_info["loss"]
new_param.data.copy_(init_info["param"].data)
else:
best_loss = 1e9
with torch.no_grad():
best_new_param = new_param.detach().clone()
for iepoch in range(epochs):
optimizer.zero_grad()
time_embed = self.gen_time_embed(timestamp.view(1))
match_loss = F.l1_loss(new_param, time_embed)
[container] = self.gen_model(new_param.view(1, -1))
y_hat = base_model.forward_with_container(x, container)
meta_loss = criterion(y_hat, y)
loss = meta_loss + match_loss
loss.backward()
optimizer.step()
if meta_loss.item() < best_loss:
with torch.no_grad():
best_loss = meta_loss.item()
best_new_param = new_param.detach().clone()
self.easy_adapt(timestamp, best_new_param)
return True, best_loss
def extra_repr(self) -> str:
return "(_super_layer_embed): {:}, (_super_meta_embed): {:}, (_meta_timestamps): {:}".format(
list(self._super_layer_embed.shape),
list(self._super_meta_embed.shape),
list(self._meta_timestamps.shape),
)

View File

@@ -0,0 +1,260 @@
#
# This is used for the ablation studies:
# The meta-model in this file uses the traditional attention in
# transformer.
#
import torch
import torch.nn.functional as F
from xautodl.xlayers import super_core
from xautodl.xlayers import trunc_normal_
from xautodl.models.xcore import get_model
class MetaModel_TraditionalAtt(super_core.SuperModule):
    """Learning to Generate Models One Step Ahead (Meta Model Design).

    Ablation variant of the meta-model that uses a traditional QKV
    attention (separate query and key embeddings of the absolute
    timestamps) instead of the relative-difference attention; everything
    else mirrors MetaModelV1.
    """

    def __init__(
        self,
        shape_container,
        layer_dim,
        time_dim,
        meta_timestamps,
        dropout: float = 0.1,
        seq_length: int = None,
        interval: float = None,
        thresh: float = None,
    ):
        """Build the meta-model.

        Args:
            shape_container: per-layer weight shapes of the base model.
            layer_dim: dimension of each learnable layer embedding.
            time_dim: dimension of each timestamp embedding.
            meta_timestamps: initially known (training) timestamps.
            dropout: dropout rate for the attention projection and generator.
            seq_length: used only to derive the default attention window.
            interval: expected spacing between consecutive timestamps.
            thresh: attention window; defaults to interval * seq_length.
        """
        super(MetaModel_TraditionalAtt, self).__init__()
        self._shape_container = shape_container
        self._num_layers = len(shape_container)
        self._numel_per_layer = []
        for ilayer in range(self._num_layers):
            self._numel_per_layer.append(shape_container[ilayer].numel())
        self._raw_meta_timestamps = meta_timestamps
        assert interval is not None
        self._interval = interval
        # Only meta-timestamps within `_thresh` of a query can be attended to.
        self._thresh = interval * seq_length if thresh is None else thresh
        # One embedding per base-model layer.
        self.register_parameter(
            "_super_layer_embed",
            torch.nn.Parameter(torch.Tensor(self._num_layers, layer_dim)),
        )
        # One embedding per known meta-timestamp.
        self.register_parameter(
            "_super_meta_embed",
            torch.nn.Parameter(torch.Tensor(len(meta_timestamps), time_dim)),
        )
        self.register_buffer("_meta_timestamps", torch.Tensor(meta_timestamps))
        self._time_embed_dim = time_dim
        # Extra (timestamp, embedding) pairs appended after construction:
        # "fixed" entries are frozen; "learnt" is the one currently optimized.
        self._append_meta_embed = dict(fixed=None, learnt=None)
        self._append_meta_timestamps = dict(fixed=None, learnt=None)
        # Positional encoding of scalar timestamps, scaled by interval.
        self._tscalar_embed = super_core.SuperDynamicPositionE(
            time_dim, scale=1 / interval
        )
        # build transformer
        # Traditional attention: separate query and key inputs.
        self._trans_att = super_core.SuperQKVAttention(
            in_q_dim=time_dim,
            in_k_dim=time_dim,
            in_v_dim=time_dim,
            num_heads=4,
            proj_dim=time_dim,
            qkv_bias=True,
            attn_drop=None,
            proj_drop=dropout,
        )
        # The generator maps (time embedding, layer embedding) -> flat weights;
        # output_dim is the largest layer so one MLP serves every layer.
        model_kwargs = dict(
            config=dict(model_type="dual_norm_mlp"),
            input_dim=layer_dim + time_dim,
            output_dim=max(self._numel_per_layer),
            hidden_dims=[(layer_dim + time_dim) * 2] * 3,
            act_cls="gelu",
            norm_cls="layer_norm_1d",
            dropout=dropout,
        )
        self._generator = get_model(**model_kwargs)
        # initialization
        trunc_normal_(
            [self._super_layer_embed, self._super_meta_embed],
            std=0.02,
        )

    def get_parameters(self, time_embed, attention, generator):
        """Collect the parameter groups selected by the three boolean flags."""
        parameters = []
        if time_embed:
            parameters.append(self._super_meta_embed)
        if attention:
            parameters.extend(list(self._trans_att.parameters()))
        if generator:
            parameters.append(self._super_layer_embed)
            parameters.extend(list(self._generator.parameters()))
        return parameters

    @property
    def meta_timestamps(self):
        """All known timestamps: the originals plus any appended ones."""
        with torch.no_grad():
            meta_timestamps = [self._meta_timestamps]
            for key in ("fixed", "learnt"):
                if self._append_meta_timestamps[key] is not None:
                    meta_timestamps.append(self._append_meta_timestamps[key])
        return torch.cat(meta_timestamps)

    @property
    def super_meta_embed(self):
        """All timestamp embeddings, aligned with `meta_timestamps`."""
        meta_embed = [self._super_meta_embed]
        for key in ("fixed", "learnt"):
            if self._append_meta_embed[key] is not None:
                meta_embed.append(self._append_meta_embed[key])
        return torch.cat(meta_embed)

    def create_meta_embed(self):
        """Create a fresh learnable (1, time_dim) timestamp embedding."""
        param = torch.Tensor(1, self._time_embed_dim)
        trunc_normal_(param, std=0.02)
        param = param.to(self._super_meta_embed.device)
        param = torch.nn.Parameter(param, True)
        return param

    def get_closest_meta_distance(self, timestamp):
        """Distance from `timestamp` to the nearest known meta-timestamp."""
        with torch.no_grad():
            distances = torch.abs(self.meta_timestamps - timestamp)
            return torch.min(distances).item()

    def replace_append_learnt(self, timestamp, meta_embed):
        # Overwrite (or clear, when both are None) the learnt appended pair.
        self._append_meta_timestamps["learnt"] = timestamp
        self._append_meta_embed["learnt"] = meta_embed

    @property
    def meta_length(self):
        # Total number of known timestamps (original + appended).
        return self.meta_timestamps.numel()

    def clear_fixed(self):
        self._append_meta_timestamps["fixed"] = None
        self._append_meta_embed["fixed"] = None

    def clear_learnt(self):
        self.replace_append_learnt(None, None)

    def append_fixed(self, timestamp, meta_embed):
        """Permanently append a detached (timestamp, embedding) pair."""
        with torch.no_grad():
            device = self._super_meta_embed.device
            timestamp = timestamp.detach().clone().to(device)
            meta_embed = meta_embed.detach().clone().to(device)
            if self._append_meta_timestamps["fixed"] is None:
                self._append_meta_timestamps["fixed"] = timestamp
            else:
                self._append_meta_timestamps["fixed"] = torch.cat(
                    (self._append_meta_timestamps["fixed"], timestamp), dim=0
                )
            if self._append_meta_embed["fixed"] is None:
                self._append_meta_embed["fixed"] = meta_embed
            else:
                self._append_meta_embed["fixed"] = torch.cat(
                    (self._append_meta_embed["fixed"], meta_embed), dim=0
                )

    def gen_time_embed(self, timestamps):
        """Generate one time embedding per scalar timestamp in the batch."""
        # timestamps is a batch of timestamps
        [B] = timestamps.shape
        # batch, seq = timestamps.shape
        timestamps = timestamps.view(-1, 1)
        meta_timestamps, meta_embeds = self.meta_timestamps, self.super_meta_embed
        timestamp_v_embed = meta_embeds.unsqueeze(dim=0)
        # Traditional attention: queries encode the query timestamps,
        # keys encode the absolute meta-timestamps.
        timestamp_q_embed = self._tscalar_embed(timestamps)
        timestamp_k_embed = self._tscalar_embed(meta_timestamps.view(1, -1))
        # create the mask
        # Mask covers meta-timestamps that are not strictly in the past, or
        # farther than `_thresh` from the query.
        # NOTE(review): assumes True entries are masked OUT by the attention
        # module — confirm against SuperQKVAttention.
        mask = (
            torch.unsqueeze(timestamps, dim=-1) <= meta_timestamps.view(1, 1, -1)
        ) | (
            torch.abs(
                torch.unsqueeze(timestamps, dim=-1) - meta_timestamps.view(1, 1, -1)
            )
            > self._thresh
        )
        timestamp_embeds = self._trans_att(
            timestamp_q_embed, timestamp_k_embed, timestamp_v_embed, mask
        )
        return timestamp_embeds[:, -1, :]

    def gen_model(self, time_embeds):
        """Generate one base-model weight container per time embedding."""
        B, _ = time_embeds.shape
        # create joint embed
        num_layer, _ = self._super_layer_embed.shape
        # The shape of `joint_embed` is batch * num-layers * input-dim
        joint_embeds = torch.cat(
            (
                time_embeds.view(B, 1, -1).expand(-1, num_layer, -1),
                self._super_layer_embed.view(1, num_layer, -1).expand(B, -1, -1),
            ),
            dim=-1,
        )
        batch_weights = self._generator(joint_embeds)
        batch_containers = []
        for weights in torch.split(batch_weights, 1):
            batch_containers.append(
                self._shape_container.translate(torch.split(weights.squeeze(0), 1))
            )
        return batch_containers

    def forward_raw(self, timestamps, time_embeds, tembed_only=False):
        # Not used: callers go through gen_time_embed / gen_model directly.
        raise NotImplementedError

    def forward_candidate(self, input):
        raise NotImplementedError

    def easy_adapt(self, timestamp, time_embed):
        """Freeze `time_embed` as the fixed entry for `timestamp`."""
        with torch.no_grad():
            timestamp = torch.Tensor([timestamp]).to(self._meta_timestamps.device)
            self.replace_append_learnt(None, None)
            self.append_fixed(timestamp, time_embed)

    def adapt(self, base_model, criterion, timestamp, x, y, lr, epochs, init_info):
        """Optimize a new timestamp embedding for `timestamp` on data (x, y).

        Returns (False, None) when `timestamp` is already (almost) covered
        by an existing meta-timestamp; otherwise trains a fresh embedding
        for `epochs` steps and returns (True, best_meta_loss).
        """
        distance = self.get_closest_meta_distance(timestamp)
        # Skip adaptation when we effectively sit on a known timestamp.
        if distance + self._interval * 1e-2 <= self._interval:
            return False, None
        x, y = x.to(self._meta_timestamps.device), y.to(self._meta_timestamps.device)
        with torch.set_grad_enabled(True):
            new_param = self.create_meta_embed()
            optimizer = torch.optim.Adam(
                [new_param], lr=lr, weight_decay=1e-5, amsgrad=True
            )
            timestamp = torch.Tensor([timestamp]).to(new_param.device)
            self.replace_append_learnt(timestamp, new_param)
            self.train()
            base_model.train()
            if init_info is not None:
                # Warm-start from a previous adaptation.
                best_loss = init_info["loss"]
                new_param.data.copy_(init_info["param"].data)
            else:
                best_loss = 1e9
            with torch.no_grad():
                best_new_param = new_param.detach().clone()
            for iepoch in range(epochs):
                optimizer.zero_grad()
                time_embed = self.gen_time_embed(timestamp.view(1))
                # Regularize the learnt embedding towards the attention
                # prediction for this timestamp.
                match_loss = F.l1_loss(new_param, time_embed)
                [container] = self.gen_model(new_param.view(1, -1))
                y_hat = base_model.forward_with_container(x, container)
                meta_loss = criterion(y_hat, y)
                loss = meta_loss + match_loss
                loss.backward()
                optimizer.step()
                # Track the embedding with the best data-fit (meta) loss.
                if meta_loss.item() < best_loss:
                    with torch.no_grad():
                        best_loss = meta_loss.item()
                        best_new_param = new_param.detach().clone()
        # Freeze the best embedding found during the inner loop.
        self.easy_adapt(timestamp, best_new_param)
        return True, best_loss

    def extra_repr(self) -> str:
        return "(_super_layer_embed): {:}, (_super_meta_embed): {:}, (_meta_timestamps): {:}".format(
            list(self._super_layer_embed.shape),
            list(self._super_meta_embed.shape),
            list(self._meta_timestamps.shape),
        )

View File

@@ -0,0 +1,441 @@
#####################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2021.02 #
############################################################################
# python exps/GeMOSA/vis-synthetic.py --env_version v1 #
# python exps/GeMOSA/vis-synthetic.py --env_version v2 #
# python exps/GeMOSA/vis-synthetic.py --env_version v3 #
# python exps/GeMOSA/vis-synthetic.py --env_version v4 #
############################################################################
import os, sys, copy, random
import torch
import numpy as np
import argparse
from collections import OrderedDict, defaultdict
from pathlib import Path
from tqdm import tqdm
from pprint import pprint
import matplotlib
from matplotlib import cm
matplotlib.use("agg")
import matplotlib.pyplot as plt
import matplotlib.ticker as ticker
lib_dir = (Path(__file__).parent / ".." / "..").resolve()
if str(lib_dir) not in sys.path:
sys.path.insert(0, str(lib_dir))
from xautodl.models.xcore import get_model
from xautodl.datasets.synthetic_core import get_synthetic_env
from xautodl.procedures.metric_utils import MSEMetric
def plot_scatter(cur_ax, xs, ys, color, alpha, linewidths, label=None):
    """Scatter-plot (xs, ys) on ``cur_ax`` with a clean legend marker.

    A dummy point is drawn far off-screen at (-100, -100) carrying the
    legend label at full opacity, then the real data is drawn unlabeled,
    so the legend shows exactly one solid marker per call.

    Args:
        cur_ax: matplotlib axes to draw on.
        xs, ys: data coordinates.
        color: marker color.
        alpha: opacity of the real data points.
        linewidths: either a (legend_linewidth, data_linewidth) pair or a
            single number used for both. Several call sites in this file
            (compare_cl, compare_algs) pass a plain int, which crashed the
            original tuple-indexing implementation.
        label: legend label attached to the dummy point only.
    """
    if isinstance(linewidths, (list, tuple)):
        legend_lw, data_lw = linewidths[0], linewidths[1]
    else:
        legend_lw = data_lw = linewidths
    cur_ax.scatter([-100], [-100], color=color, linewidths=legend_lw, label=label)
    cur_ax.scatter(
        xs, ys, color=color, alpha=alpha, linewidths=data_lw, label=None
    )
def draw_multi_fig(save_dir, timestamp, scatter_list, wh, fig_title=None):
    """Render each config in ``scatter_list`` as one stacked subplot and
    save the whole figure under ``save_dir`` as both PDF and PNG, named by
    the zero-padded ``timestamp``."""
    out_prefix = save_dir / "{:04d}".format(timestamp)
    dpi = 40
    width, height = wh[0], wh[1]
    figsize = width / float(dpi), height / float(dpi)
    LabelSize, LegendFontsize, font_gap = 80, 80, 5
    fig = plt.figure(figsize=figsize)
    if fig_title is not None:
        fig.suptitle(
            fig_title, fontsize=LegendFontsize, fontweight="bold", x=0.5, y=0.92
        )
    num_rows = len(scatter_list)
    for row, cfg in enumerate(scatter_list):
        axis = fig.add_subplot(num_rows, 1, row + 1)
        plot_scatter(
            axis,
            cfg["xaxis"],
            cfg["yaxis"],
            cfg["color"],
            cfg["alpha"],
            cfg["linewidths"],
            cfg["label"],
        )
        axis.set_xlabel("X", fontsize=LabelSize)
        axis.set_ylabel("Y", rotation=0, fontsize=LabelSize)
        axis.set_xlim(cfg["xlim"][0], cfg["xlim"][1])
        axis.set_ylim(cfg["ylim"][0], cfg["ylim"][1])
        for tick in axis.xaxis.get_major_ticks():
            tick.label.set_fontsize(LabelSize - font_gap)
            tick.label.set_rotation(10)
        for tick in axis.yaxis.get_major_ticks():
            tick.label.set_fontsize(LabelSize - font_gap)
        axis.legend(loc=1, fontsize=LegendFontsize)
    fig.savefig(str(out_prefix) + ".pdf", dpi=dpi, bbox_inches="tight", format="pdf")
    fig.savefig(str(out_prefix) + ".png", dpi=dpi, bbox_inches="tight", format="png")
    plt.close("all")
def find_min(cur, others):
    """Return the running minimum of ``cur`` and ``others``; ``cur`` may be
    None when no value has been seen yet."""
    candidate = float(others)
    return candidate if cur is None else float(min(cur, candidate))
def find_max(cur, others):
    """Return the running maximum of ``cur`` and ``others``.

    ``cur`` may be None when no value has been seen yet. ``others`` may be
    a plain number or an array-like with a ``.max()`` method (numpy/torch);
    the original always called ``others.max()`` in the None branch, which
    crashed on plain Python numbers and was inconsistent with find_min.
    Array-likes are reduced with ``.max()`` first, preserving the previous
    behavior for numpy scalars/arrays.
    """
    # Reduce array-likes to a scalar; keep plain numbers as-is.
    value = others.max() if hasattr(others, "max") else others
    if cur is None:
        return float(value)
    else:
        return float(max(cur, value))
def compare_cl(save_dir):
    """Render a per-timestamp comparison between the dynamic-environment data
    and a continual-learning-style sweep of the same function, then encode
    the frames into mp4/webm videos via ffmpeg.

    NOTE(review): `create_example_v1` is neither defined nor imported in
    this file, so calling this function raises NameError as-is; it appears
    to be legacy code (the call in __main__ is commented out). Verify the
    import before re-enabling.
    """
    save_dir = Path(str(save_dir))
    save_dir.mkdir(parents=True, exist_ok=True)
    dynamic_env, cl_function = create_example_v1(
        # timestamp_config=dict(num=200, min_timestamp=-1, max_timestamp=1.0),
        timestamp_config=dict(num=200),
        num_per_task=1000,
    )
    models = dict()  # NOTE(review): unused — left for parity with history.
    cl_function.set_timestamp(0)
    # Running x-range of the continual-learning sweep.
    cl_xaxis_min = None
    cl_xaxis_max = None
    all_data = OrderedDict()
    # First pass: collect per-timestamp data and the global x-range.
    for idx, (timestamp, dataset) in enumerate(tqdm(dynamic_env, ncols=50)):
        xaxis_all = dataset[0][:, 0].numpy()
        yaxis_all = dataset[1][:, 0].numpy()
        current_data = dict()
        current_data["lfna_xaxis_all"] = xaxis_all
        current_data["lfna_yaxis_all"] = yaxis_all
        # compute cl-min
        cl_xaxis_min = find_min(cl_xaxis_min, xaxis_all.mean() - xaxis_all.std())
        cl_xaxis_max = find_max(cl_xaxis_max, xaxis_all.mean() + xaxis_all.std())
        all_data[timestamp] = current_data
    global_cl_xaxis_all = np.arange(cl_xaxis_min, cl_xaxis_max, step=0.1)
    global_cl_yaxis_all = cl_function.noise_call(global_cl_xaxis_all)
    # Second pass: one figure per timestamp, with the CL sweep growing
    # left-to-right proportionally to the timestamp index.
    for idx, (timestamp, xdata) in enumerate(tqdm(all_data.items(), ncols=50)):
        scatter_list = []
        scatter_list.append(
            {
                "xaxis": xdata["lfna_xaxis_all"],
                "yaxis": xdata["lfna_yaxis_all"],
                "color": "k",
                "linewidths": 15,
                "alpha": 0.99,
                "xlim": (-6, 6),
                "ylim": (-40, 40),
                "label": "LFNA",
            }
        )
        cur_cl_xaxis_min = cl_xaxis_min
        cur_cl_xaxis_max = cl_xaxis_min + (cl_xaxis_max - cl_xaxis_min) * (
            idx + 1
        ) / len(all_data)
        cl_xaxis_all = np.arange(cur_cl_xaxis_min, cur_cl_xaxis_max, step=0.01)
        cl_yaxis_all = cl_function.noise_call(cl_xaxis_all, std=0.2)
        scatter_list.append(
            {
                "xaxis": cl_xaxis_all,
                "yaxis": cl_yaxis_all,
                "color": "k",
                "linewidths": 15,
                "xlim": (round(cl_xaxis_min, 1), round(cl_xaxis_max, 1)),
                "ylim": (-20, 6),
                "alpha": 0.99,
                "label": "Continual Learning",
            }
        )
        draw_multi_fig(
            save_dir,
            idx,
            scatter_list,
            wh=(2200, 1800),
            fig_title="Timestamp={:03d}".format(idx),
        )
    print("Save all figures into {:}".format(save_dir))
    save_dir = save_dir.resolve()
    # Stitch the numbered PNG frames into videos.
    base_cmd = (
        "ffmpeg -y -i {xdir}/%04d.png -vf fps=1 -vf scale=2200:1800 -vb 5000k".format(
            xdir=save_dir
        )
    )
    video_cmd = "{:} -pix_fmt yuv420p {xdir}/compare-cl.mp4".format(
        base_cmd, xdir=save_dir
    )
    print(video_cmd + "\n")
    os.system(video_cmd)
    os.system(
        "{:} -pix_fmt yuv420p {xdir}/compare-cl.webm".format(base_cmd, xdir=save_dir)
    )
def visualize_env(save_dir, version):
    """Plot every timestamp of the synthetic environment `version`.

    Writes one PDF and one PNG per timestamp under
    ``save_dir/{pdf,png}-{version}/`` and then stitches the PNG frames into
    mp4/webm videos with ffmpeg. Handles both regression (1-D x/y scatter)
    and binary-classification (2-D x colored by label) environments.

    Fix over the original: the "Unknown task" ValueErrors used
    ``"Unknown task".format(...)`` — a format string with no placeholder —
    so the offending task name was silently dropped from the message.
    """
    save_dir = Path(str(save_dir))
    for substr in ("pdf", "png"):
        sub_save_dir = save_dir / "{:}-{:}".format(substr, version)
        sub_save_dir.mkdir(parents=True, exist_ok=True)
    dynamic_env = get_synthetic_env(version=version)
    print("env: {:}".format(dynamic_env))
    print("oracle_map: {:}".format(dynamic_env.oracle_map))
    # First pass: gather all data to fix global axis ranges.
    allxs, allys = [], []
    for idx, (timestamp, (allx, ally)) in enumerate(tqdm(dynamic_env, ncols=50)):
        allxs.append(allx)
        allys.append(ally)
    if dynamic_env.meta_info["task"] == "regression":
        allxs, allys = torch.cat(allxs).view(-1), torch.cat(allys).view(-1)
        print(
            "x - min={:.3f}, max={:.3f}".format(allxs.min().item(), allxs.max().item())
        )
        print(
            "y - min={:.3f}, max={:.3f}".format(allys.min().item(), allys.max().item())
        )
    elif dynamic_env.meta_info["task"] == "classification":
        allxs = torch.cat(allxs)
        print(
            "x[0] - min={:.3f}, max={:.3f}".format(
                allxs[:, 0].min().item(), allxs[:, 0].max().item()
            )
        )
        print(
            "x[1] - min={:.3f}, max={:.3f}".format(
                allxs[:, 1].min().item(), allxs[:, 1].max().item()
            )
        )
    else:
        # FIX: the original format string had no placeholder.
        raise ValueError("Unknown task: {:}".format(dynamic_env.meta_info["task"]))
    # Second pass: one figure per timestamp.
    for idx, (timestamp, (allx, ally)) in enumerate(tqdm(dynamic_env, ncols=50)):
        dpi, width, height = 30, 1800, 1400
        figsize = width / float(dpi), height / float(dpi)
        LabelSize, LegendFontsize, font_gap = 80, 80, 5
        fig = plt.figure(figsize=figsize)
        cur_ax = fig.add_subplot(1, 1, 1)
        if dynamic_env.meta_info["task"] == "regression":
            allx, ally = allx[:, 0].numpy(), ally[:, 0].numpy()
            plot_scatter(
                cur_ax, allx, ally, "k", 0.99, (15, 1.5), "timestamp={:05d}".format(idx)
            )
            cur_ax.set_xlim(round(allxs.min().item(), 1), round(allxs.max().item(), 1))
            cur_ax.set_ylim(round(allys.min().item(), 1), round(allys.max().item(), 1))
        elif dynamic_env.meta_info["task"] == "classification":
            positive, negative = ally == 1, ally == 0
            # plot_scatter(cur_ax, [1], [1], "k", 0.1, 1, "timestamp={:05d}".format(idx))
            plot_scatter(
                cur_ax,
                allx[positive, 0],
                allx[positive, 1],
                "r",
                0.99,
                (20, 10),
                "positive",
            )
            plot_scatter(
                cur_ax,
                allx[negative, 0],
                allx[negative, 1],
                "g",
                0.99,
                (20, 10),
                "negative",
            )
            cur_ax.set_xlim(
                round(allxs[:, 0].min().item(), 1), round(allxs[:, 0].max().item(), 1)
            )
            cur_ax.set_ylim(
                round(allxs[:, 1].min().item(), 1), round(allxs[:, 1].max().item(), 1)
            )
        else:
            # FIX: the original format string had no placeholder.
            raise ValueError("Unknown task: {:}".format(dynamic_env.meta_info["task"]))
        cur_ax.set_xlabel("X", fontsize=LabelSize)
        cur_ax.set_ylabel("Y", rotation=0, fontsize=LabelSize)
        for tick in cur_ax.xaxis.get_major_ticks():
            tick.label.set_fontsize(LabelSize - font_gap)
            tick.label.set_rotation(10)
        for tick in cur_ax.yaxis.get_major_ticks():
            tick.label.set_fontsize(LabelSize - font_gap)
        cur_ax.legend(loc=1, fontsize=LegendFontsize)
        pdf_save_path = (
            save_dir
            / "pdf-{:}".format(version)
            / "v{:}-{:05d}.pdf".format(version, idx)
        )
        fig.savefig(str(pdf_save_path), dpi=dpi, bbox_inches="tight", format="pdf")
        png_save_path = (
            save_dir
            / "png-{:}".format(version)
            / "v{:}-{:05d}.png".format(version, idx)
        )
        fig.savefig(str(png_save_path), dpi=dpi, bbox_inches="tight", format="png")
        plt.close("all")
    save_dir = save_dir.resolve()
    # Stitch the numbered PNG frames into videos.
    base_cmd = "ffmpeg -y -i {xdir}/v{version}-%05d.png -vf scale=1800:1400 -pix_fmt yuv420p -vb 5000k".format(
        xdir=save_dir / "png-{:}".format(version), version=version
    )
    print(base_cmd)
    os.system("{:} {xdir}/env-{ver}.mp4".format(base_cmd, xdir=save_dir, ver=version))
    os.system("{:} {xdir}/env-{ver}.webm".format(base_cmd, xdir=save_dir, ver=version))
def compare_algs(save_dir, version, alg_dir="./outputs/GeMOSA-synthetic"):
    """Compare saved algorithms' predictions and MSE on every timestamp.

    For each algorithm listed in ``alg_name2dir``, loads its per-timestamp
    weight containers (``w_containers``) from ``final-ckp.pth`` under
    ``alg_dir``, re-runs a shared base MLP with those weights on the
    synthetic environment, and writes one two-panel figure per timestamp
    (top: data vs. predictions; bottom: MSE trajectory), then encodes the
    frames into mp4/webm videos.
    """
    save_dir = Path(str(save_dir))
    for substr in ("pdf", "png"):
        sub_save_dir = save_dir / substr
        sub_save_dir.mkdir(parents=True, exist_ok=True)
    dpi, width, height = 30, 3200, 2000
    figsize = width / float(dpi), height / float(dpi)
    LabelSize, LegendFontsize, font_gap = 80, 80, 5
    dynamic_env = get_synthetic_env(mode=None, version=version)
    # First pass: gather all data to fix global axis ranges.
    allxs, allys = [], []
    for idx, (timestamp, (allx, ally)) in enumerate(tqdm(dynamic_env, ncols=50)):
        allxs.append(allx)
        allys.append(ally)
    allxs, allys = torch.cat(allxs).view(-1), torch.cat(allys).view(-1)
    # Map legend name -> checkpoint sub-directory.
    alg_name2dir = OrderedDict()
    # alg_name2dir["Supervised Learning (History Data)"] = "use-all-past-data"
    # alg_name2dir["MAML"] = "use-maml-s1"
    # alg_name2dir["LFNA (fix init)"] = "lfna-fix-init"
    if version == "v1":
        # alg_name2dir["Optimal"] = "use-same-timestamp"
        alg_name2dir[
            "GMOA"
        ] = "lfna-battle-bs128-d16_16_16-s16-lr0.002-wd1e-05-e10000-envv1"
    else:
        raise ValueError("Invalid version: {:}".format(version))
    alg_name2all_containers = OrderedDict()
    for idx_alg, (alg, xdir) in enumerate(alg_name2dir.items()):
        ckp_path = Path(alg_dir) / str(xdir) / "final-ckp.pth"
        xdata = torch.load(ckp_path, map_location="cpu")
        alg_name2all_containers[alg] = xdata["w_containers"]
    # load the basic model
    # NOTE(review): this architecture must match the one the checkpoints
    # were trained with — confirm against the training scripts.
    model = get_model(
        dict(model_type="norm_mlp"),
        input_dim=1,
        output_dim=1,
        hidden_dims=[16] * 2,
        act_cls="gelu",
        norm_cls="layer_norm_1d",
    )
    # Accumulated MSE trajectories per algorithm.
    alg2xs, alg2ys = defaultdict(list), defaultdict(list)
    colors = ["r", "g", "b", "m", "y"]
    linewidths, skip = 10, 5
    for idx, (timestamp, (ori_allx, ori_ally)) in enumerate(
        tqdm(dynamic_env, ncols=50)
    ):
        # Skip the first `skip` timestamps.
        if idx <= skip:
            continue
        fig = plt.figure(figsize=figsize)
        cur_ax = fig.add_subplot(2, 1, 1)
        # the data
        allx, ally = ori_allx[:, 0].numpy(), ori_ally[:, 0].numpy()
        plot_scatter(cur_ax, allx, ally, "k", 0.99, linewidths, "Raw Data")
        # Each algorithm's prediction with its saved per-timestamp weights.
        for idx_alg, (alg, xdir) in enumerate(alg_name2dir.items()):
            with torch.no_grad():
                predicts = model.forward_with_container(
                    ori_allx, alg_name2all_containers[alg][idx]
                )
                predicts = predicts.cpu()
                # keep data
                metric = MSEMetric()
                metric(predicts, ori_ally)
                predicts = predicts.view(-1).numpy()
                alg2xs[alg].append(idx)
                alg2ys[alg].append(metric.get_info()["mse"])
            plot_scatter(cur_ax, allx, predicts, colors[idx_alg], 0.99, linewidths, alg)
        cur_ax.set_xlabel("X", fontsize=LabelSize)
        cur_ax.set_ylabel("Y", rotation=0, fontsize=LabelSize)
        for tick in cur_ax.xaxis.get_major_ticks():
            tick.label.set_fontsize(LabelSize - font_gap)
            tick.label.set_rotation(10)
        for tick in cur_ax.yaxis.get_major_ticks():
            tick.label.set_fontsize(LabelSize - font_gap)
        cur_ax.set_xlim(round(allxs.min().item(), 1), round(allxs.max().item(), 1))
        cur_ax.set_ylim(round(allys.min().item(), 1), round(allys.max().item(), 1))
        cur_ax.legend(loc=1, fontsize=LegendFontsize)
        # the trajectory data
        cur_ax = fig.add_subplot(2, 1, 2)
        for idx_alg, (alg, xdir) in enumerate(alg_name2dir.items()):
            # plot_scatter(cur_ax, alg2xs[alg], alg2ys[alg], olors[idx_alg], 0.99, linewidths, alg)
            cur_ax.plot(
                alg2xs[alg],
                alg2ys[alg],
                color=colors[idx_alg],
                linestyle="-",
                linewidth=5,
                label=alg,
            )
            cur_ax.legend(loc=1, fontsize=LegendFontsize)
        cur_ax.set_xlabel("Timestamp", fontsize=LabelSize)
        cur_ax.set_ylabel("MSE", fontsize=LabelSize)
        for tick in cur_ax.xaxis.get_major_ticks():
            tick.label.set_fontsize(LabelSize - font_gap)
            tick.label.set_rotation(10)
        for tick in cur_ax.yaxis.get_major_ticks():
            tick.label.set_fontsize(LabelSize - font_gap)
        cur_ax.set_xlim(1, len(dynamic_env))
        cur_ax.set_ylim(0, 10)
        cur_ax.legend(loc=1, fontsize=LegendFontsize)
        pdf_save_path = save_dir / "pdf" / "v{:}-{:05d}.pdf".format(version, idx - skip)
        fig.savefig(str(pdf_save_path), dpi=dpi, bbox_inches="tight", format="pdf")
        png_save_path = save_dir / "png" / "v{:}-{:05d}.png".format(version, idx - skip)
        fig.savefig(str(png_save_path), dpi=dpi, bbox_inches="tight", format="png")
        plt.close("all")
    save_dir = save_dir.resolve()
    # Stitch the numbered PNG frames into videos.
    base_cmd = "ffmpeg -y -i {xdir}/v{ver}-%05d.png -vf scale={w}:{h} -pix_fmt yuv420p -vb 5000k".format(
        xdir=save_dir / "png", w=width, h=height, ver=version
    )
    os.system(
        "{:} {xdir}/com-alg-{ver}.mp4".format(base_cmd, xdir=save_dir, ver=version)
    )
    os.system(
        "{:} {xdir}/com-alg-{ver}.webm".format(base_cmd, xdir=save_dir, ver=version)
    )
if __name__ == "__main__":
    parser = argparse.ArgumentParser("Visualize synthetic data.")
    parser.add_argument(
        "--save_dir",
        type=str,
        default="./outputs/vis-synthetic",
        help="The save directory.",
    )
    parser.add_argument(
        "--env_version",
        type=str,
        required=True,
        # FIX: corrected the "enviornment" typo in the user-facing help text.
        help="The synthetic environment version.",
    )
    args = parser.parse_args()
    visualize_env(os.path.join(args.save_dir, "vis-env"), args.env_version)
    # visualize_env(os.path.join(args.save_dir, "vis-env"), "v2")
    # compare_algs(os.path.join(args.save_dir, "compare-alg"), args.env_version)
    # compare_cl(os.path.join(args.save_dir, "compare-cl"))