Update codes

This commit is contained in:
D-X-Y
2021-05-26 02:41:36 +00:00
parent f8350d00ed
commit 9057011781
11 changed files with 42 additions and 800 deletions

206
exps/GeMOSA/basic-his.py Normal file
View File

@@ -0,0 +1,206 @@
#####################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2021.04 #
#####################################################
# python exps/GeMOSA/basic-his.py --srange 1-999 --env_version v1 --hidden_dim 16
#####################################################
import sys, time, copy, torch, random, argparse
from tqdm import tqdm
from copy import deepcopy
from xautodl.procedures import (
prepare_seed,
prepare_logger,
save_checkpoint,
copy_checkpoint,
)
from xautodl.log_utils import time_string
from xautodl.log_utils import AverageMeter, convert_secs2time
from xautodl.utils import split_str2indexes
from xautodl.procedures.advanced_main import basic_train_fn, basic_eval_fn
from xautodl.procedures.metric_utils import SaveMetric, MSEMetric, ComposeMetric
from xautodl.datasets.synthetic_core import get_synthetic_env
from xautodl.models.xcore import get_model
from lfna_utils import lfna_setup
def subsample(historical_x, historical_y, maxn=10000):
    """Cap a paired (x, y) sample set at ``maxn`` randomly chosen rows.

    Args:
        historical_x: tensor whose first dimension indexes the samples.
        historical_y: tensor aligned with ``historical_x`` along dimension 0.
        maxn: maximum number of rows to keep.

    Returns:
        The two inputs unchanged when they hold at most ``maxn`` rows;
        otherwise ``maxn`` distinct rows, selected with the same indexes for
        both tensors so the pairs stay aligned.
    """
    total = historical_x.size(0)
    if total <= maxn:
        return historical_x, historical_y
    # Draw without replacement: `torch.randint` may pick the same row several
    # times, which silently reduces the number of distinct training samples.
    indexes = torch.randperm(total)[:maxn]
    return historical_x[indexes], historical_y[indexes]
def main(args):
    """Fit a fresh model for each selected timestamp using all earlier data.

    For every index selected through ``args.srange`` the data of timestamps
    ``0 .. idx-1`` is concatenated (and capped by ``subsample``), a
    ``simple_mlp`` model is trained on it with full-batch Adam, and the model
    is evaluated on the data of timestamp ``idx``.  One checkpoint is written
    per timestamp, plus a final checkpoint with all weight containers.

    Args:
        args: parsed command-line namespace.  This function reads ``srange``,
            ``init_lr``, ``epochs`` and ``batch_size``; the whole namespace is
            also handed to ``lfna_setup``.
    """
    logger, env_info, model_kwargs = lfna_setup(args)

    # check indexes to be evaluated
    to_evaluate_indexes = split_str2indexes(args.srange, env_info["total"], None)
    logger.log(
        "Evaluate {:}, which has {:} timestamps in total.".format(
            args.srange, len(to_evaluate_indexes)
        )
    )

    w_container_per_epoch = dict()

    per_timestamp_time, start_time = AverageMeter(), time.time()
    for i, idx in enumerate(to_evaluate_indexes):

        need_time = "Time Left: {:}".format(
            convert_secs2time(
                per_timestamp_time.avg * (len(to_evaluate_indexes) - i), True
            )
        )
        logger.log(
            "[{:}]".format(time_string())
            + " [{:04d}/{:04d}][{:04d}]".format(i, len(to_evaluate_indexes), idx)
            + " "
            + need_time
        )
        # train on the data of every earlier timestamp; index 0 has no history
        assert idx != 0, "timestamp 0 has no historical data to train on"
        historical_x = torch.cat(
            [env_info["{:}-x".format(past_i)] for past_i in range(idx)]
        )
        historical_y = torch.cat(
            [env_info["{:}-y".format(past_i)] for past_i in range(idx)]
        )
        historical_x, historical_y = subsample(historical_x, historical_y)
        # build model
        model = get_model(dict(model_type="simple_mlp"), **model_kwargs)
        # build optimizer
        optimizer = torch.optim.Adam(model.parameters(), lr=args.init_lr, amsgrad=True)
        criterion = torch.nn.MSELoss()
        lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(
            optimizer,
            milestones=[
                int(args.epochs * 0.25),
                int(args.epochs * 0.5),
                int(args.epochs * 0.75),
            ],
            gamma=0.3,
        )
        train_metric = MSEMetric()
        best_loss, best_param = None, None
        for _iepoch in range(args.epochs):
            preds = model(historical_x)
            optimizer.zero_grad()
            loss = criterion(preds, historical_y)
            loss.backward()
            # save best: snapshot *before* the update, because `loss` was
            # measured with the current weights, not with the stepped ones.
            if best_loss is None or best_loss > loss.item():
                best_loss = loss.item()
                best_param = copy.deepcopy(model.state_dict())
            optimizer.step()
            lr_scheduler.step()
        if best_param is not None:  # stays None only when args.epochs <= 0
            model.load_state_dict(best_param)
        with torch.no_grad():
            # Score the restored weights; the last `preds` of the loop was
            # produced by a different set of weights.
            train_metric(model(historical_x), historical_y)
        train_results = train_metric.get_info()

        metric = ComposeMetric(MSEMetric(), SaveMetric())
        eval_dataset = torch.utils.data.TensorDataset(
            env_info["{:}-x".format(idx)], env_info["{:}-y".format(idx)]
        )
        eval_loader = torch.utils.data.DataLoader(
            eval_dataset, batch_size=args.batch_size, shuffle=False, num_workers=0
        )
        results = basic_eval_fn(eval_loader, model, metric, logger)
        log_str = (
            "[{:}]".format(time_string())
            + " [{:04d}/{:04d}]".format(idx, env_info["total"])
            + " train-mse: {:.5f}, eval-mse: {:.5f}".format(
                train_results["mse"], results["mse"]
            )
        )
        logger.log(log_str)

        save_path = logger.path(None) / "{:04d}-{:04d}.pth".format(
            idx, env_info["total"]
        )
        w_container_per_epoch[idx] = model.get_w_container().no_grad_clone()
        save_checkpoint(
            {
                "model_state_dict": model.state_dict(),
                "model": model,
                "index": idx,
                "timestamp": env_info["{:}-timestamp".format(idx)],
            },
            save_path,
            logger,
        )
        logger.log("")

        per_timestamp_time.update(time.time() - start_time)
        start_time = time.time()

    save_checkpoint(
        {"w_container_per_epoch": w_container_per_epoch},
        logger.path(None) / "final-ckp.pth",
        logger,
    )

    logger.log("-" * 200 + "\n")
    logger.close()
if __name__ == "__main__":
    # Command-line entry point: parse the options, finalize the random seed
    # and the output directory, then run `main`.
    # The text is passed as `description=` on purpose: the first positional
    # parameter of `ArgumentParser` is `prog`, i.e. the program name printed
    # in the usage line, not the description.
    parser = argparse.ArgumentParser(description="Use all the past data to train.")
    parser.add_argument(
        "--save_dir",
        type=str,
        default="./outputs/lfna-synthetic/use-all-past-data",
        help="The checkpoint directory.",
    )
    parser.add_argument(
        "--env_version",
        type=str,
        required=True,
        help="The synthetic enviornment version.",
    )
    parser.add_argument(
        "--hidden_dim",
        type=int,
        required=True,
        help="The hidden dimension.",
    )
    parser.add_argument(
        "--init_lr",
        type=float,
        default=0.1,
        help="The initial learning rate for the optimizer (default is Adam)",
    )
    parser.add_argument(
        "--batch_size",
        type=int,
        default=512,
        help="The batch size",
    )
    parser.add_argument(
        "--epochs",
        type=int,
        default=1000,
        help="The total number of epochs.",
    )
    parser.add_argument(
        "--srange", type=str, required=True, help="The range of models to be evaluated"
    )
    parser.add_argument(
        "--workers",
        type=int,
        default=4,
        help="The number of data loading workers (default: 4)",
    )
    # Random Seed
    parser.add_argument("--rand_seed", type=int, default=-1, help="manual seed")
    args = parser.parse_args()
    # A negative (or missing) seed means "pick one at random".
    if args.rand_seed is None or args.rand_seed < 0:
        args.rand_seed = random.randint(1, 100000)
    assert args.save_dir is not None, "The save dir argument can not be None"
    # Encode the main hyper-parameters into the output directory name.
    args.save_dir = "{:}-{:}-d{:}".format(
        args.save_dir, args.env_version, args.hidden_dim
    )
    main(args)

271
exps/GeMOSA/basic-maml.py Normal file
View File

@@ -0,0 +1,271 @@
#####################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2021.04 #
#####################################################
# python exps/GeMOSA/basic-maml.py --env_version v1 --inner_step 5
# python exps/GeMOSA/basic-maml.py --env_version v2
#####################################################
import sys, time, copy, torch, random, argparse
from tqdm import tqdm
from copy import deepcopy
from pathlib import Path
lib_dir = (Path(__file__).parent / ".." / ".." / "lib").resolve()
if str(lib_dir) not in sys.path:
sys.path.insert(0, str(lib_dir))
from procedures import prepare_seed, prepare_logger, save_checkpoint, copy_checkpoint
from log_utils import time_string
from log_utils import AverageMeter, convert_secs2time
from utils import split_str2indexes
from procedures.advanced_main import basic_train_fn, basic_eval_fn
from procedures.metric_utils import SaveMetric, MSEMetric, ComposeMetric
from datasets.synthetic_core import get_synthetic_env, EnvSampler
from models.xcore import get_model
from xlayers import super_core
from lfna_utils import lfna_setup, TimeData
class MAML:
    """MAML-style baseline: a shared network adapted by a few gradient steps.

    The wrapper owns the meta optimizer (Adam) and its multi-step learning-rate
    schedule for ``network`` and exposes the inner-loop adaptation through
    :meth:`adapt`.
    """

    def __init__(
        self, network, criterion, epochs, meta_lr, inner_lr=0.01, inner_step=1
    ):
        """Store the components and build the meta optimizer and scheduler.

        Args:
            network: model to meta-train; it must provide ``get_w_container``
                and ``forward_with_container``.
            criterion: loss module used in both the inner and the outer loop.
            epochs: total number of meta epochs; the learning rate decays at
                80% and 90% of this value.
            meta_lr: learning rate of the meta optimizer.
            inner_lr: step size of the inner-loop updates.
            inner_step: number of inner-loop updates per adaptation.
        """
        self.criterion = criterion
        self.network = network
        self.inner_lr = inner_lr
        self.inner_step = inner_step
        self.meta_optimizer = torch.optim.Adam(
            self.network.parameters(), lr=meta_lr, amsgrad=True
        )
        decay_points = [int(epochs * 0.8), int(epochs * 0.9)]
        self.meta_lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(
            self.meta_optimizer, milestones=decay_points, gamma=0.1
        )
        self._best_info = dict(state_dict=None, iepoch=None, score=None)
        num_weights = self.network.get_w_container().numel()
        print("There are {:} weights.".format(num_weights))

    def adapt(self, dataset):
        """Run ``inner_step`` gradient updates on ``dataset`` and return the weights.

        Args:
            dataset: object with ``x`` and ``y`` attributes (inputs / targets).

        Returns:
            The weight container obtained after the inner-loop updates.
        """
        # start from the current weights of the shared network
        container = self.network.get_w_container()
        for _ in range(self.inner_step):
            prediction = self.network.forward_with_container(dataset.x, container)
            inner_loss = self.criterion(prediction, dataset.y)
            gradients = torch.autograd.grad(inner_loss, container.parameters())
            updates = [-self.inner_lr * gradient for gradient in gradients]
            container = container.additive(updates)
        return container

    def predict(self, x, container=None):
        """Forward ``x`` with ``container`` when given, else with the network's own weights."""
        if container is None:
            return self.network(x)
        return self.network.forward_with_container(x, container)

    def step(self):
        """Clip the meta gradients to norm 1.0, then advance optimizer and scheduler."""
        torch.nn.utils.clip_grad_norm_(self.network.parameters(), 1.0)
        self.meta_optimizer.step()
        self.meta_lr_scheduler.step()

    def zero_grad(self):
        """Reset the gradients tracked by the meta optimizer."""
        self.meta_optimizer.zero_grad()

    def load_state_dict(self, state_dict):
        """Restore criterion, network, optimizer and scheduler from ``state_dict``."""
        self.criterion.load_state_dict(state_dict["criterion"])
        self.network.load_state_dict(state_dict["network"])
        self.meta_optimizer.load_state_dict(state_dict["meta_optimizer"])
        self.meta_lr_scheduler.load_state_dict(state_dict["meta_lr_scheduler"])

    def state_dict(self):
        """Collect the state of criterion, network, optimizer and scheduler."""
        return dict(
            criterion=self.criterion.state_dict(),
            network=self.network.state_dict(),
            meta_optimizer=self.meta_optimizer.state_dict(),
            meta_lr_scheduler=self.meta_lr_scheduler.state_dict(),
        )

    def save_best(self, score):
        """Forward ``score`` to ``network.save_best`` and return ``(success, best_score)``."""
        improved, best = self.network.save_best(score)
        return improved, best

    def load_best(self):
        """Ask the network to reload its best recorded weights."""
        self.network.load_best()
def main(args):
    """Meta-train the MAML baseline on the synthetic environment, then meta-test it.

    Meta-training samples ``args.meta_batch`` random timestamps per epoch,
    adapts on the data ``args.prev_time`` intervals earlier and optimizes the
    loss on the sampled timestamp.  Training stops early when no new best
    score was recorded for ``args.early_stop_thresh`` epochs.  Meta-testing
    adapts once per evaluation timestamp and stores the resulting weight
    containers in ``final-ckp.pth``.

    Args:
        args: parsed command-line namespace.
    """
    logger, env_info, model_kwargs = lfna_setup(args)
    model = get_model(**model_kwargs)

    dynamic_env = get_synthetic_env(mode="train", version=args.env_version)

    criterion = torch.nn.MSELoss()

    maml = MAML(
        model, criterion, args.epochs, args.meta_lr, args.inner_lr, args.inner_step
    )

    # ---- meta-training ----
    last_success_epoch = 0
    epoch_timer = AverageMeter()
    tick = time.time()
    for iepoch in range(args.epochs):
        remaining = convert_secs2time(epoch_timer.avg * (args.epochs - iepoch), True)
        need_time = "Time Left: {:}".format(remaining)
        progress = "[{:}] [{:04d}/{:04d}] ".format(time_string(), iepoch, args.epochs)
        head_str = progress + need_time

        maml.zero_grad()
        batch_losses = []
        for _ibatch in range(args.meta_batch):
            query_time = dynamic_env.random_timestamp()
            _, (query_x, query_y) = dynamic_env(query_time)
            # adapt on the data that lies `prev_time` intervals in the past
            support_time = (
                query_time - args.prev_time * dynamic_env.timestamp_interval
            )
            _, (support_x, support_y) = dynamic_env(support_time)

            adapted = maml.adapt(TimeData(support_time, support_x, support_y))
            query_pred = maml.predict(query_x, adapted)
            batch_losses.append(maml.criterion(query_pred, query_y))
        meta_loss = torch.stack(batch_losses).mean()
        meta_loss.backward()
        maml.step()

        logger.log(head_str + " meta-loss: {:.4f}".format(meta_loss.item()))
        # the score is the negated loss, so a larger score is better
        success, best_score = maml.save_best(-meta_loss.item())
        if success:
            logger.log("Achieve the best with best_score = {:.3f}".format(best_score))
            save_checkpoint(maml.state_dict(), logger.path("model"), logger)
            last_success_epoch = iepoch
        stalled_epochs = iepoch - last_success_epoch
        if stalled_epochs >= args.early_stop_thresh:
            logger.log("Early stop at {:}".format(iepoch))
            break

        epoch_timer.update(time.time() - tick)
        tick = time.time()

    # ---- meta-test ----
    maml.load_best()
    eval_env = env_info["dynamic_env"]
    assert eval_env.timestamp_interval == dynamic_env.timestamp_interval
    w_container_per_epoch = dict()
    for idx in range(args.prev_time, len(eval_env)):
        query_time, (query_x, query_y) = eval_env[idx]
        support_time = (
            query_time.item() - args.prev_time * eval_env.timestamp_interval
        )
        _, (support_x, support_y) = eval_env(support_time)
        adapted = maml.adapt(TimeData(support_time, support_x, support_y))
        w_container_per_epoch[idx] = adapted.no_grad_clone()
        with torch.no_grad():
            query_pred = maml.predict(query_x, w_container_per_epoch[idx])
            query_loss = maml.criterion(query_pred, query_y)
        logger.log("meta-test: [{:03d}] -> loss={:.4f}".format(idx, query_loss.item()))
    final_ckp_path = logger.path(None) / "final-ckp.pth"
    save_checkpoint(
        {"w_container_per_epoch": w_container_per_epoch},
        final_ckp_path,
        logger,
    )

    logger.log("-" * 200 + "\n")
    logger.close()
if __name__ == "__main__":
    # Command-line entry point: parse the options, finalize the random seed
    # and the output directory, then run `main`.
    # The text is passed as `description=` on purpose: the first positional
    # parameter of `ArgumentParser` is `prog`, i.e. the program name printed
    # in the usage line, not the description.
    parser = argparse.ArgumentParser(description="Use the data in the past.")
    parser.add_argument(
        "--save_dir",
        type=str,
        default="./outputs/lfna-synthetic/use-maml",
        help="The checkpoint directory.",
    )
    parser.add_argument(
        "--env_version",
        type=str,
        required=True,
        help="The synthetic enviornment version.",
    )
    parser.add_argument(
        "--hidden_dim",
        type=int,
        default=16,
        help="The hidden dimension.",
    )
    parser.add_argument(
        "--meta_lr",
        type=float,
        default=0.01,
        help="The learning rate for the MAML optimizer (default is Adam)",
    )
    parser.add_argument(
        "--fail_thresh",
        type=float,
        default=1000,
        help="The threshold for the failure, which we reuse the previous best model",
    )
    parser.add_argument(
        "--inner_lr",
        type=float,
        default=0.005,
        help="The learning rate for the inner optimization",
    )
    parser.add_argument(
        "--inner_step", type=int, default=1, help="The inner loop steps for MAML."
    )
    parser.add_argument(
        "--prev_time",
        type=int,
        default=5,
        help="The gap between prev_time and current_timestamp",
    )
    parser.add_argument(
        "--meta_batch",
        type=int,
        default=64,
        help="The batch size for the meta-model",
    )
    parser.add_argument(
        "--epochs",
        type=int,
        default=2000,
        help="The total number of epochs.",
    )
    parser.add_argument(
        "--early_stop_thresh",
        type=int,
        default=50,
        help="The maximum epochs for early stop.",
    )
    parser.add_argument(
        "--workers",
        type=int,
        default=4,
        help="The number of data loading workers (default: 4)",
    )
    # Random Seed
    parser.add_argument("--rand_seed", type=int, default=-1, help="manual seed")
    args = parser.parse_args()
    # A negative (or missing) seed means "pick one at random".
    if args.rand_seed is None or args.rand_seed < 0:
        args.rand_seed = random.randint(1, 100000)
    assert args.save_dir is not None, "The save dir argument can not be None"
    # Encode the main hyper-parameters into the output directory name.
    args.save_dir = "{:}-s{:}-mlr{:}-d{:}-prev{:}-e{:}-env{:}".format(
        args.save_dir,
        args.inner_step,
        args.meta_lr,
        args.hidden_dim,
        args.prev_time,
        args.epochs,
        args.env_version,
    )
    main(args)

203
exps/GeMOSA/basic-prev.py Normal file
View File

@@ -0,0 +1,203 @@
#####################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2021.04 #
#####################################################
# python exps/GeMOSA/basic-prev.py --env_version v1 --prev_time 5 --hidden_dim 16 --epochs 500 --init_lr 0.1
# python exps/GeMOSA/basic-prev.py --env_version v2 --hidden_dim 16 --epochs 1000 --init_lr 0.05
#####################################################
import sys, time, copy, torch, random, argparse
from tqdm import tqdm
from copy import deepcopy
from pathlib import Path
from xautodl.procedures import (
prepare_seed,
prepare_logger,
save_checkpoint,
copy_checkpoint,
)
from xautodl.log_utils import time_string
from xautodl.log_utils import AverageMeter, convert_secs2time
from xautodl.utils import split_str2indexes
from xautodl.procedures.advanced_main import basic_train_fn, basic_eval_fn
from xautodl.procedures.metric_utils import SaveMetric, MSEMetric, ComposeMetric
from xautodl.datasets.synthetic_core import get_synthetic_env
from xautodl.models.xcore import get_model
from lfna_utils import lfna_setup
def subsample(historical_x, historical_y, maxn=10000):
    """Cap a paired (x, y) sample set at ``maxn`` randomly chosen rows.

    Args:
        historical_x: tensor whose first dimension indexes the samples.
        historical_y: tensor aligned with ``historical_x`` along dimension 0.
        maxn: maximum number of rows to keep.

    Returns:
        The two inputs unchanged when they hold at most ``maxn`` rows;
        otherwise ``maxn`` distinct rows, selected with the same indexes for
        both tensors so the pairs stay aligned.
    """
    total = historical_x.size(0)
    if total <= maxn:
        return historical_x, historical_y
    # Draw without replacement: `torch.randint` may pick the same row several
    # times, which silently reduces the number of distinct training samples.
    indexes = torch.randperm(total)[:maxn]
    return historical_x[indexes], historical_y[indexes]
def main(args):
    """Fit a fresh model per timestamp on the data from ``prev_time`` steps earlier.

    For every index ``idx`` in ``[args.prev_time, env_info["total"])`` a model
    is trained with full-batch Adam on the data of timestamp
    ``idx - args.prev_time`` and evaluated on the data of timestamp ``idx``.
    One checkpoint is written per timestamp, plus a final checkpoint with all
    weight containers.

    Args:
        args: parsed command-line namespace.  This function reads
            ``prev_time``, ``init_lr``, ``epochs`` and ``batch_size``; the
            whole namespace is also handed to ``lfna_setup``.
    """
    logger, env_info, model_kwargs = lfna_setup(args)

    w_container_per_epoch = dict()

    per_timestamp_time, start_time = AverageMeter(), time.time()
    for idx in range(args.prev_time, env_info["total"]):

        need_time = "Time Left: {:}".format(
            convert_secs2time(per_timestamp_time.avg * (env_info["total"] - idx), True)
        )
        logger.log(
            "[{:}]".format(time_string())
            + " [{:04d}/{:04d}]".format(idx, env_info["total"])
            + " "
            + need_time
        )
        # train on the data observed `prev_time` timestamps earlier
        historical_x = env_info["{:}-x".format(idx - args.prev_time)]
        historical_y = env_info["{:}-y".format(idx - args.prev_time)]
        # build model
        model = get_model(**model_kwargs)
        print(model)
        # build optimizer
        optimizer = torch.optim.Adam(model.parameters(), lr=args.init_lr, amsgrad=True)
        criterion = torch.nn.MSELoss()
        lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(
            optimizer,
            milestones=[
                int(args.epochs * 0.25),
                int(args.epochs * 0.5),
                int(args.epochs * 0.75),
            ],
            gamma=0.3,
        )
        train_metric = MSEMetric()
        best_loss, best_param = None, None
        for _iepoch in range(args.epochs):
            preds = model(historical_x)
            optimizer.zero_grad()
            loss = criterion(preds, historical_y)
            loss.backward()
            # save best: snapshot *before* the update, because `loss` was
            # measured with the current weights, not with the stepped ones.
            if best_loss is None or best_loss > loss.item():
                best_loss = loss.item()
                best_param = copy.deepcopy(model.state_dict())
            optimizer.step()
            lr_scheduler.step()
        if best_param is not None:  # stays None only when args.epochs <= 0
            model.load_state_dict(best_param)
        model.analyze_weights()
        with torch.no_grad():
            # Score the restored weights; the last `preds` of the loop was
            # produced by a different set of weights.
            train_metric(model(historical_x), historical_y)
        train_results = train_metric.get_info()

        metric = ComposeMetric(MSEMetric(), SaveMetric())
        eval_dataset = torch.utils.data.TensorDataset(
            env_info["{:}-x".format(idx)], env_info["{:}-y".format(idx)]
        )
        eval_loader = torch.utils.data.DataLoader(
            eval_dataset, batch_size=args.batch_size, shuffle=False, num_workers=0
        )
        results = basic_eval_fn(eval_loader, model, metric, logger)
        log_str = (
            "[{:}]".format(time_string())
            + " [{:04d}/{:04d}]".format(idx, env_info["total"])
            + " train-mse: {:.5f}, eval-mse: {:.5f}".format(
                train_results["mse"], results["mse"]
            )
        )
        logger.log(log_str)

        save_path = logger.path(None) / "{:04d}-{:04d}.pth".format(
            idx, env_info["total"]
        )
        w_container_per_epoch[idx] = model.get_w_container().no_grad_clone()
        save_checkpoint(
            {
                "model_state_dict": model.state_dict(),
                "model": model,
                "index": idx,
                "timestamp": env_info["{:}-timestamp".format(idx)],
            },
            save_path,
            logger,
        )
        logger.log("")

        per_timestamp_time.update(time.time() - start_time)
        start_time = time.time()

    save_checkpoint(
        {"w_container_per_epoch": w_container_per_epoch},
        logger.path(None) / "final-ckp.pth",
        logger,
    )

    logger.log("-" * 200 + "\n")
    logger.close()
if __name__ == "__main__":
    # Command-line entry point: parse the options, finalize the random seed
    # and the output directory, then run `main`.
    # The text is passed as `description=` on purpose: the first positional
    # parameter of `ArgumentParser` is `prog`, i.e. the program name printed
    # in the usage line, not the description.
    parser = argparse.ArgumentParser(
        description="Use the data in the last timestamp."
    )
    parser.add_argument(
        "--save_dir",
        type=str,
        default="./outputs/lfna-synthetic/use-prev-timestamp",
        help="The checkpoint directory.",
    )
    parser.add_argument(
        "--env_version",
        type=str,
        required=True,
        help="The synthetic enviornment version.",
    )
    parser.add_argument(
        "--hidden_dim",
        type=int,
        required=True,
        help="The hidden dimension.",
    )
    parser.add_argument(
        "--init_lr",
        type=float,
        default=0.1,
        help="The initial learning rate for the optimizer (default is Adam)",
    )
    parser.add_argument(
        "--prev_time",
        type=int,
        default=5,
        help="The gap between prev_time and current_timestamp",
    )
    parser.add_argument(
        "--batch_size",
        type=int,
        default=512,
        help="The batch size",
    )
    parser.add_argument(
        "--epochs",
        type=int,
        default=300,
        help="The total number of epochs.",
    )
    parser.add_argument(
        "--workers",
        type=int,
        default=4,
        help="The number of data loading workers (default: 4)",
    )
    # Random Seed
    parser.add_argument("--rand_seed", type=int, default=-1, help="manual seed")
    args = parser.parse_args()
    # A negative (or missing) seed means "pick one at random".
    if args.rand_seed is None or args.rand_seed < 0:
        args.rand_seed = random.randint(1, 100000)
    assert args.save_dir is not None, "The save dir argument can not be None"
    # Encode the main hyper-parameters into the output directory name.
    args.save_dir = "{:}-d{:}_e{:}_lr{:}-prev{:}-env{:}".format(
        args.save_dir,
        args.hidden_dim,
        args.epochs,
        args.init_lr,
        args.prev_time,
        args.env_version,
    )
    main(args)

204
exps/GeMOSA/basic-same.py Normal file
View File

@@ -0,0 +1,204 @@
#####################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2021.04 #
#####################################################
# python exps/GeMOSA/basic-same.py --env_version v1 --hidden_dim 16 --epochs 500 --init_lr 0.1
# python exps/GeMOSA/basic-same.py --env_version v2 --hidden_dim 16 --epochs 1000 --init_lr 0.05
#####################################################
import sys, time, copy, torch, random, argparse
from tqdm import tqdm
from copy import deepcopy
from pathlib import Path
lib_dir = (Path(__file__).parent / ".." / "..").resolve()
print("LIB-DIR: {:}".format(lib_dir))
if str(lib_dir) not in sys.path:
sys.path.insert(0, str(lib_dir))
from xautodl.procedures import (
prepare_seed,
prepare_logger,
save_checkpoint,
copy_checkpoint,
)
from xautodl.log_utils import time_string
from xautodl.log_utils import AverageMeter, convert_secs2time
from xautodl.utils import split_str2indexes
from xautodl.procedures.advanced_main import basic_train_fn, basic_eval_fn
from xautodl.procedures.metric_utils import SaveMetric, MSEMetric, ComposeMetric
from xautodl.datasets.synthetic_core import get_synthetic_env
from xautodl.models.xcore import get_model
from lfna_utils import lfna_setup
def subsample(historical_x, historical_y, maxn=10000):
    """Cap a paired (x, y) sample set at ``maxn`` randomly chosen rows.

    Args:
        historical_x: tensor whose first dimension indexes the samples.
        historical_y: tensor aligned with ``historical_x`` along dimension 0.
        maxn: maximum number of rows to keep.

    Returns:
        The two inputs unchanged when they hold at most ``maxn`` rows;
        otherwise ``maxn`` distinct rows, selected with the same indexes for
        both tensors so the pairs stay aligned.
    """
    total = historical_x.size(0)
    if total <= maxn:
        return historical_x, historical_y
    # Draw without replacement: `torch.randint` may pick the same row several
    # times, which silently reduces the number of distinct training samples.
    indexes = torch.randperm(total)[:maxn]
    return historical_x[indexes], historical_y[indexes]
def main(args):
    """Fit and evaluate a fresh model on each timestamp's own data.

    The synthetic environment is iterated timestamp by timestamp; for each
    one a model is trained with full-batch Adam on that timestamp's data and
    evaluated on the same data.  One checkpoint is written per timestamp,
    plus a final checkpoint with all weight containers.

    Args:
        args: parsed command-line namespace.  This function reads
            ``env_version``, ``device``, ``init_lr``, ``epochs`` and
            ``batch_size``; the whole namespace is also handed to
            ``lfna_setup``.
    """
    logger, model_kwargs = lfna_setup(args)

    env = get_synthetic_env(mode=None, version=args.env_version)
    logger.log("The total enviornment: {:}".format(env))
    w_containers = dict()

    per_timestamp_time, start_time = AverageMeter(), time.time()
    for idx, (future_time, (future_x, future_y)) in enumerate(env):

        need_time = "Time Left: {:}".format(
            convert_secs2time(per_timestamp_time.avg * (len(env) - idx), True)
        )
        logger.log(
            "[{:}]".format(time_string())
            + " [{:04d}/{:04d}]".format(idx, len(env))
            + " "
            + need_time
        )
        # train the same data
        historical_x = future_x.to(args.device)
        historical_y = future_y.to(args.device)
        # build model
        model = get_model(**model_kwargs)
        model = model.to(args.device)
        # build optimizer
        optimizer = torch.optim.Adam(model.parameters(), lr=args.init_lr, amsgrad=True)
        criterion = torch.nn.MSELoss()
        lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(
            optimizer,
            milestones=[
                int(args.epochs * 0.25),
                int(args.epochs * 0.5),
                int(args.epochs * 0.75),
            ],
            gamma=0.3,
        )
        train_metric = MSEMetric()
        best_loss, best_param = None, None
        for _iepoch in range(args.epochs):
            preds = model(historical_x)
            optimizer.zero_grad()
            loss = criterion(preds, historical_y)
            loss.backward()
            # save best: snapshot *before* the update, because `loss` was
            # measured with the current weights, not with the stepped ones.
            if best_loss is None or best_loss > loss.item():
                best_loss = loss.item()
                best_param = copy.deepcopy(model.state_dict())
            optimizer.step()
            lr_scheduler.step()
        if best_param is not None:  # stays None only when args.epochs <= 0
            model.load_state_dict(best_param)
        model.analyze_weights()
        with torch.no_grad():
            # Score the restored weights; the last `preds` of the loop was
            # produced by a different set of weights.
            train_metric(model(historical_x), historical_y)
        train_results = train_metric.get_info()

        metric = ComposeMetric(MSEMetric(), SaveMetric())
        # The evaluation data is the training data of this timestamp, which
        # already lives on the target device; no second transfer is needed.
        eval_dataset = torch.utils.data.TensorDataset(historical_x, historical_y)
        eval_loader = torch.utils.data.DataLoader(
            eval_dataset, batch_size=args.batch_size, shuffle=False, num_workers=0
        )
        results = basic_eval_fn(eval_loader, model, metric, logger)
        log_str = (
            "[{:}]".format(time_string())
            + " [{:04d}/{:04d}]".format(idx, len(env))
            + " train-mse: {:.5f}, eval-mse: {:.5f}".format(
                train_results["mse"], results["mse"]
            )
        )
        logger.log(log_str)

        save_path = logger.path(None) / "{:04d}-{:04d}.pth".format(idx, len(env))
        w_containers[idx] = model.get_w_container().no_grad_clone()
        save_checkpoint(
            {
                "model_state_dict": model.state_dict(),
                "model": model,
                "index": idx,
                "timestamp": future_time.item(),
            },
            save_path,
            logger,
        )
        logger.log("")

        per_timestamp_time.update(time.time() - start_time)
        start_time = time.time()

    save_checkpoint(
        {"w_containers": w_containers},
        logger.path(None) / "final-ckp.pth",
        logger,
    )

    logger.log("-" * 200 + "\n")
    logger.close()
if __name__ == "__main__":
    # Command-line entry point: parse the options, finalize the random seed
    # and the output directory, then run `main`.
    # The text is passed as `description=` on purpose: the first positional
    # parameter of `ArgumentParser` is `prog`, i.e. the program name printed
    # in the usage line, not the description.  The wording now matches what
    # this script does (train and evaluate on the same timestamp).
    parser = argparse.ArgumentParser(
        description="Use the data in the same timestamp."
    )
    parser.add_argument(
        "--save_dir",
        type=str,
        default="./outputs/lfna-synthetic/use-same-timestamp",
        help="The checkpoint directory.",
    )
    parser.add_argument(
        "--env_version",
        type=str,
        required=True,
        help="The synthetic enviornment version.",
    )
    parser.add_argument(
        "--hidden_dim",
        type=int,
        required=True,
        help="The hidden dimension.",
    )
    parser.add_argument(
        "--init_lr",
        type=float,
        default=0.1,
        help="The initial learning rate for the optimizer (default is Adam)",
    )
    parser.add_argument(
        "--batch_size",
        type=int,
        default=512,
        help="The batch size",
    )
    parser.add_argument(
        "--epochs",
        type=int,
        default=300,
        help="The total number of epochs.",
    )
    parser.add_argument(
        "--device",
        type=str,
        default="cpu",
        help="The device to place the model and the data on (default: cpu)",
    )
    parser.add_argument(
        "--workers",
        type=int,
        default=4,
        help="The number of data loading workers (default: 4)",
    )
    # Random Seed
    parser.add_argument("--rand_seed", type=int, default=-1, help="manual seed")
    args = parser.parse_args()
    # A negative (or missing) seed means "pick one at random".
    if args.rand_seed is None or args.rand_seed < 0:
        args.rand_seed = random.randint(1, 100000)
    assert args.save_dir is not None, "The save dir argument can not be None"
    # Encode the main hyper-parameters into the output directory name.
    args.save_dir = "{:}-d{:}_e{:}_lr{:}-env{:}".format(
        args.save_dir, args.hidden_dim, args.epochs, args.init_lr, args.env_version
    )
    main(args)

View File

@@ -0,0 +1,265 @@
#####################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2021.04 #
#####################################################
import torch
import torch.nn.functional as F
from xautodl.xlayers import super_core
from xautodl.xlayers import trunc_normal_
from xautodl.models.xcore import get_model
class MetaModelV1(super_core.SuperModule):
"""Learning to Generate Models One Step Ahead (Meta Model Design)."""
def __init__(
self,
shape_container,
layer_dim,
time_dim,
meta_timestamps,
dropout: float = 0.1,
seq_length: int = 10,
interval: float = None,
thresh: float = None,
):
super(MetaModelV1, self).__init__()
self._shape_container = shape_container
self._num_layers = len(shape_container)
self._numel_per_layer = []
for ilayer in range(self._num_layers):
self._numel_per_layer.append(shape_container[ilayer].numel())
self._raw_meta_timestamps = meta_timestamps
assert interval is not None
self._interval = interval
self._seq_length = seq_length
self._thresh = interval * 50 if thresh is None else thresh
self.register_parameter(
"_super_layer_embed",
torch.nn.Parameter(torch.Tensor(self._num_layers, layer_dim)),
)
self.register_parameter(
"_super_meta_embed",
torch.nn.Parameter(torch.Tensor(len(meta_timestamps), time_dim)),
)
self.register_buffer("_meta_timestamps", torch.Tensor(meta_timestamps))
# register a time difference buffer
time_interval = [-i * self._interval for i in range(self._seq_length)]
time_interval.reverse()
self.register_buffer("_time_interval", torch.Tensor(time_interval))
self._time_embed_dim = time_dim
self._append_meta_embed = dict(fixed=None, learnt=None)
self._append_meta_timestamps = dict(fixed=None, learnt=None)
self._tscalar_embed = super_core.SuperDynamicPositionE(
time_dim, scale=1 / interval
)
# build transformer
self._trans_att = super_core.SuperQKVAttentionV2(
qk_att_dim=time_dim,
in_v_dim=time_dim,
hidden_dim=time_dim,
num_heads=4,
proj_dim=time_dim,
qkv_bias=True,
attn_drop=None,
proj_drop=dropout,
)
model_kwargs = dict(
config=dict(model_type="dual_norm_mlp"),
input_dim=layer_dim + time_dim,
output_dim=max(self._numel_per_layer),
hidden_dims=[(layer_dim + time_dim) * 2] * 3,
act_cls="gelu",
norm_cls="layer_norm_1d",
dropout=dropout,
)
self._generator = get_model(**model_kwargs)
# initialization
trunc_normal_(
[self._super_layer_embed, self._super_meta_embed],
std=0.02,
)
def get_parameters(self, time_embed, attention, generator):
parameters = []
if time_embed:
parameters.append(self._super_meta_embed)
if attention:
parameters.extend(list(self._trans_att.parameters()))
if generator:
parameters.append(self._super_layer_embed)
parameters.extend(list(self._generator.parameters()))
return parameters
@property
def meta_timestamps(self):
with torch.no_grad():
meta_timestamps = [self._meta_timestamps]
for key in ("fixed", "learnt"):
if self._append_meta_timestamps[key] is not None:
meta_timestamps.append(self._append_meta_timestamps[key])
return torch.cat(meta_timestamps)
@property
def super_meta_embed(self):
meta_embed = [self._super_meta_embed]
for key in ("fixed", "learnt"):
if self._append_meta_embed[key] is not None:
meta_embed.append(self._append_meta_embed[key])
return torch.cat(meta_embed)
def create_meta_embed(self):
param = torch.Tensor(1, self._time_embed_dim)
trunc_normal_(param, std=0.02)
param = param.to(self._super_meta_embed.device)
param = torch.nn.Parameter(param, True)
return param
def get_closest_meta_distance(self, timestamp):
with torch.no_grad():
distances = torch.abs(self.meta_timestamps - timestamp)
return torch.min(distances).item()
def replace_append_learnt(self, timestamp, meta_embed):
self._append_meta_timestamps["learnt"] = timestamp
self._append_meta_embed["learnt"] = meta_embed
@property
def meta_length(self):
return self.meta_timestamps.numel()
def clear_fixed(self):
self._append_meta_timestamps["fixed"] = None
self._append_meta_embed["fixed"] = None
def clear_learnt(self):
self.replace_append_learnt(None, None)
def append_fixed(self, timestamp, meta_embed):
with torch.no_grad():
device = self._super_meta_embed.device
timestamp = timestamp.detach().clone().to(device)
meta_embed = meta_embed.detach().clone().to(device)
if self._append_meta_timestamps["fixed"] is None:
self._append_meta_timestamps["fixed"] = timestamp
else:
self._append_meta_timestamps["fixed"] = torch.cat(
(self._append_meta_timestamps["fixed"], timestamp), dim=0
)
if self._append_meta_embed["fixed"] is None:
self._append_meta_embed["fixed"] = meta_embed
else:
self._append_meta_embed["fixed"] = torch.cat(
(self._append_meta_embed["fixed"], meta_embed), dim=0
)
def _obtain_time_embed(self, timestamps):
# timestamps is a batch of sequence of timestamps
batch, seq = timestamps.shape
meta_timestamps, meta_embeds = self.meta_timestamps, self.super_meta_embed
timestamp_v_embed = meta_embeds.unsqueeze(dim=0)
timestamp_qk_att_embed = self._tscalar_embed(
torch.unsqueeze(timestamps, dim=-1) - meta_timestamps
)
# create the mask
mask = (
torch.unsqueeze(timestamps, dim=-1) <= meta_timestamps.view(1, 1, -1)
) | (
torch.abs(
torch.unsqueeze(timestamps, dim=-1) - meta_timestamps.view(1, 1, -1)
)
> self._thresh
)
timestamp_embeds = self._trans_att(
timestamp_qk_att_embed,
timestamp_v_embed,
mask,
)
return timestamp_embeds[:, -1, :]
def forward_raw(self, timestamps, time_embeds, tembed_only=False):
if time_embeds is None:
time_seq = timestamps.view(-1, 1) + self._time_interval.view(1, -1)
B, S = time_seq.shape
time_embeds = self._obtain_time_embed(time_seq)
else: # use the hyper-net only
time_seq = None
B, _ = time_embeds.shape
if tembed_only:
return time_embeds
# create joint embed
num_layer, _ = self._super_layer_embed.shape
# The shape of `joint_embed` is batch * num-layers * input-dim
joint_embeds = torch.cat(
(
time_embeds.view(B, 1, -1).expand(-1, num_layer, -1),
self._super_layer_embed.view(1, num_layer, -1).expand(B, -1, -1),
),
dim=-1,
)
batch_weights = self._generator(joint_embeds)
batch_containers = []
for weights in torch.split(batch_weights, 1):
batch_containers.append(
self._shape_container.translate(torch.split(weights.squeeze(0), 1))
)
return time_seq, batch_containers, time_embeds
def forward_candidate(self, input):
raise NotImplementedError
def adapt(self, base_model, criterion, timestamp, x, y, lr, epochs, init_info):
distance = self.get_closest_meta_distance(timestamp)
if distance + self._interval * 1e-2 <= self._interval:
return False, None
x, y = x.to(self._meta_timestamps.device), y.to(self._meta_timestamps.device)
with torch.set_grad_enabled(True):
new_param = self.create_meta_embed()
optimizer = torch.optim.Adam(
[new_param], lr=lr, weight_decay=1e-5, amsgrad=True
)
timestamp = torch.Tensor([timestamp]).to(new_param.device)
self.replace_append_learnt(timestamp, new_param)
self.train()
base_model.train()
if init_info is not None:
best_loss = init_info["loss"]
new_param.data.copy_(init_info["param"].data)
else:
best_loss = 1e9
with torch.no_grad():
best_new_param = new_param.detach().clone()
for iepoch in range(epochs):
optimizer.zero_grad()
_, [_], time_embed = self(timestamp.view(1, 1), None)
match_loss = criterion(new_param, time_embed)
_, [container], time_embed = self(None, new_param.view(1, 1, -1))
y_hat = base_model.forward_with_container(x, container)
meta_loss = criterion(y_hat, y)
loss = meta_loss + match_loss
loss.backward()
optimizer.step()
# print("{:03d}/{:03d} : loss : {:.4f} = {:.4f} + {:.4f}".format(iepoch, epochs, loss.item(), meta_loss.item(), match_loss.item()))
if meta_loss.item() < best_loss:
with torch.no_grad():
best_loss = meta_loss.item()
best_new_param = new_param.detach().clone()
with torch.no_grad():
self.replace_append_learnt(None, None)
self.append_fixed(timestamp, best_new_param)
return True, best_loss
def extra_repr(self) -> str:
    """Describe the shapes of the layer/meta embeddings and the meta timestamps."""
    shapes = [
        list(tensor.shape)
        for tensor in (
            self._super_layer_embed,
            self._super_meta_embed,
            self._meta_timestamps,
        )
    ]
    template = "(_super_layer_embed): {:}, (_super_meta_embed): {:}, (_meta_timestamps): {:}"
    return template.format(*shapes)

117
exps/GeMOSA/lfna_models.py Normal file
View File

@@ -0,0 +1,117 @@
#####################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2021.04 #
#####################################################
import copy
import torch
import torch.nn.functional as F
from xlayers import super_core
from xlayers import trunc_normal_
from models.xcore import get_model
class HyperNet(super_core.SuperModule):
    """The hyper-network.

    Generates one weight row per layer of `shape_container` from the
    concatenation of a task embedding and a per-layer embedding.
    """

    def __init__(
        self,
        shape_container,
        layer_embeding,
        task_embedding,
        num_tasks,
        return_container=True,
    ):
        super().__init__()
        self._shape_container = shape_container
        self._num_layers = len(shape_container)
        self._numel_per_layer = [
            shape_container[index].numel() for index in range(self._num_layers)
        ]

        # Learnable tables: one row per layer and one row per task.
        layer_table = torch.nn.Parameter(
            torch.Tensor(self._num_layers, layer_embeding)
        )
        self.register_parameter("_super_layer_embed", layer_table)
        task_table = torch.nn.Parameter(torch.Tensor(num_tasks, task_embedding))
        self.register_parameter("_super_task_embed", task_table)
        trunc_normal_(self._super_layer_embed, std=0.02)
        trunc_normal_(self._super_task_embed, std=0.02)

        # The generator emits as many values as the largest layer holds.
        joint_dim = layer_embeding + task_embedding
        self._generator = get_model(
            config=dict(model_type="dual_norm_mlp"),
            input_dim=joint_dim,
            output_dim=max(self._numel_per_layer),
            hidden_dims=[joint_dim * 2] * 3,
            act_cls="gelu",
            norm_cls="layer_norm_1d",
            dropout=0.2,
        )
        self._return_container = return_container
        print("generator: {:}".format(self._generator))

    def forward_raw(self, task_embed_id):
        """Generate the weights for the task stored at row `task_embed_id`.

        Returns the raw generator output, or (when `return_container` was set)
        the result of `shape_container.translate` on its per-layer rows.
        """
        task_row = self._super_task_embed[task_embed_id].view(1, -1)
        task_embed = task_row.expand(self._num_layers, -1)
        joint_embed = torch.cat((task_embed, self._super_layer_embed), dim=-1)
        weights = self._generator(joint_embed)
        if not self._return_container:
            return weights
        return self._shape_container.translate(torch.split(weights, 1))

    def forward_candidate(self, input):
        """Candidate-mode forward pass; not implemented."""
        raise NotImplementedError()

    def extra_repr(self) -> str:
        """Describe the shape of the layer-embedding table."""
        layer_shape = list(self._super_layer_embed.shape)
        return "(_super_layer_embed): {:}".format(layer_shape)
class HyperNet_VX(super_core.SuperModule):
    """Hyper-network variant that generates weights from layer embeddings only."""

    def __init__(self, shape_container, input_embeding, return_container=True):
        super().__init__()
        self._shape_container = shape_container
        self._num_layers = len(shape_container)
        self._numel_per_layer = [
            shape_container[index].numel() for index in range(self._num_layers)
        ]

        # One learnable embedding row per layer.
        layer_table = torch.nn.Parameter(
            torch.Tensor(self._num_layers, input_embeding)
        )
        self.register_parameter("_super_layer_embed", layer_table)
        trunc_normal_(self._super_layer_embed, std=0.02)

        # The generator emits as many values as the largest layer holds.
        self._generator = get_model(
            dict(model_type="simple_mlp"),
            input_dim=input_embeding,
            output_dim=max(self._numel_per_layer),
            hidden_dim=input_embeding * 4,
            act_cls="sigmoid",
            norm_cls="identity",
        )
        self._return_container = return_container
        print("generator: {:}".format(self._generator))

    def forward_raw(self, input):
        """Generate weights from the layer embeddings; `input` is not used.

        Returns the raw generator output, or (when `return_container` was set)
        the result of `shape_container.translate` on its per-layer rows.
        """
        weights = self._generator(self._super_layer_embed)
        if not self._return_container:
            return weights
        return self._shape_container.translate(torch.split(weights, 1))

    def forward_candidate(self, input):
        """Candidate-mode forward pass; not implemented."""
        raise NotImplementedError()

    def extra_repr(self) -> str:
        """Describe the shape of the layer-embedding table."""
        layer_shape = list(self._super_layer_embed.shape)
        return "(_super_layer_embed): {:}".format(layer_shape)

64
exps/GeMOSA/lfna_utils.py Normal file
View File

@@ -0,0 +1,64 @@
#####################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2021.04 #
#####################################################
import copy
import torch
from tqdm import tqdm
from xautodl.procedures import prepare_seed, prepare_logger
from xautodl.datasets.synthetic_core import get_synthetic_env
def lfna_setup(args):
    """Prepare the seed and logger and build the base-model keyword arguments.

    Args:
        args: parsed command-line namespace; `rand_seed` and `hidden_dim` are
            read here, and the whole namespace is passed to `prepare_logger`.

    Returns:
        `(logger, model_kwargs)` where `model_kwargs` configures a "norm_mlp"
        with scalar input/output and two hidden layers of `args.hidden_dim`.
    """
    prepare_seed(args.rand_seed)
    logger = prepare_logger(args)
    model_kwargs = {
        "config": {"model_type": "norm_mlp"},
        "input_dim": 1,
        "output_dim": 1,
        "hidden_dims": [args.hidden_dim, args.hidden_dim],
        "act_cls": "gelu",
        "norm_cls": "layer_norm_1d",
    }
    return logger, model_kwargs
def train_model(model, dataset, lr, epochs):
    """Fit `model` to `dataset` with full-batch Adam and restore the best weights.

    Args:
        model: module providing `parameters()`, `state_dict()` and
            `load_state_dict()`; it is updated in place.
        dataset: object exposing the inputs as `.x` and the targets as `.y`.
        lr: learning rate for Adam (AMSGrad variant).
        epochs: number of full-batch optimisation steps.

    Returns:
        The lowest MSE loss observed; the model is left holding exactly the
        weights that produced it. Returns None (model untouched) when
        `epochs` is not positive.
    """
    criterion = torch.nn.MSELoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=lr, amsgrad=True)
    best_loss, best_param = None, None
    for _iepoch in range(epochs):
        optimizer.zero_grad()
        loss = criterion(model(dataset.x), dataset.y)
        loss.backward()
        # Snapshot BEFORE the optimizer step: `loss` was measured with the
        # current weights, so these are the weights that achieve it.
        loss_value = loss.item()
        if best_loss is None or best_loss > loss_value:
            best_loss = loss_value
            best_param = copy.deepcopy(model.state_dict())
        optimizer.step()
    # With zero epochs nothing was recorded; keep the model as it is.
    if best_param is not None:
        model.load_state_dict(best_param)
    return best_loss
class TimeData:
    """Read-only bundle of samples `(xs, ys)` and the timestamp they belong to."""

    def __init__(self, timestamp, xs, ys):
        self._timestamp, self._xs, self._ys = timestamp, xs, ys

    @property
    def x(self):
        """The stored inputs."""
        return self._xs

    @property
    def y(self):
        """The stored targets."""
        return self._ys

    @property
    def timestamp(self):
        """The timestamp given at construction."""
        return self._timestamp

    def __repr__(self):
        template = "{name}(timestamp={timestamp}, with {num} samples)"
        return template.format(
            name=type(self).__name__,
            timestamp=self._timestamp,
            num=len(self._xs),
        )

343
exps/GeMOSA/main.py Normal file
View File

@@ -0,0 +1,343 @@
#####################################################
# Learning to Generate Model One Step Ahead #
#####################################################
# python exps/GeMOSA/main.py --env_version v1 --workers 0
# python exps/GeMOSA/main.py --env_version v1 --device cuda --lr 0.001
# python exps/GeMOSA/main.py --env_version v1 --device cuda --lr 0.002 --seq_length 16 --meta_batch 128
# python exps/GeMOSA/main.py --env_version v1 --device cuda --lr 0.002 --seq_length 24 --time_dim 32 --meta_batch 128
#####################################################
import sys, time, copy, torch, random, argparse
from tqdm import tqdm
from copy import deepcopy
from pathlib import Path
from torch.nn import functional as F
# Put the directory two levels above this file at the front of `sys.path`
# (if it is not already listed) before the `xautodl` imports below run.
lib_dir = (Path(__file__).parent / ".." / "..").resolve()
print("LIB-DIR: {:}".format(lib_dir))
if str(lib_dir) not in sys.path:
    sys.path.insert(0, str(lib_dir))
from xautodl.procedures import (
prepare_seed,
prepare_logger,
save_checkpoint,
copy_checkpoint,
)
from xautodl.log_utils import time_string
from xautodl.log_utils import AverageMeter, convert_secs2time
from xautodl.utils import split_str2indexes
from xautodl.procedures.advanced_main import basic_train_fn, basic_eval_fn
from xautodl.procedures.metric_utils import SaveMetric, MSEMetric, ComposeMetric
from xautodl.datasets.synthetic_core import get_synthetic_env
from xautodl.models.xcore import get_model
from xautodl.xlayers import super_core, trunc_normal_
from lfna_utils import lfna_setup, train_model, TimeData
from lfna_meta_model import MetaModelV1
def online_evaluate(env, meta_model, base_model, criterion, args, logger, save=False):
    """Evaluate the meta-model timestamp by timestamp, adapting after each one.

    For every `(time, (x, y))` item of `env`, the meta-model generates a weight
    container for that time, the base model is scored with it, and then
    `meta_model.adapt` is called on the same samples.

    Args:
        env: sized iterable yielding `(future_time, (future_x, future_y))`.
        meta_model: generator of weight containers.
        base_model: model providing `forward_with_container`.
        criterion: loss callable.
        args: namespace providing `device`, `refine_lr` and `refine_epochs`.
        logger: object with a `log` method.
        save: when True, keep a `no_grad_clone()` of each generated container.

    Returns:
        `(w_containers, loss_meter)`: the saved containers keyed by position in
        `env` (empty unless `save`), and the meter of the pre-adaptation losses.
    """
    logger.log("Online evaluate: {:}".format(env))
    loss_meter, w_containers = AverageMeter(), dict()
    for idx, (future_time, (future_x, future_y)) in enumerate(env):
        with torch.no_grad():
            # `adapt` switches both models to train mode, so reset every time.
            meta_model.eval()
            base_model.eval()
            query_time = future_time.to(args.device).view(1, 1)
            _, [future_container], time_embeds = meta_model(query_time, None, False)
            if save:
                w_containers[idx] = future_container.no_grad_clone()
            future_x = future_x.to(args.device)
            future_y = future_y.to(args.device)
            future_y_hat = base_model.forward_with_container(future_x, future_container)
            future_loss = criterion(future_y_hat, future_y)
            current_loss = future_loss.item()
            loss_meter.update(current_loss)
        # Start the refinement from the embedding and loss just obtained.
        init_info = {"param": time_embeds, "loss": current_loss}
        refine, post_refine_loss = meta_model.adapt(
            base_model,
            criterion,
            future_time.item(),
            future_x,
            future_y,
            args.refine_lr,
            args.refine_epochs,
            init_info,
        )
        shown_post_loss = post_refine_loss if refine else -1
        logger.log(
            "[ONLINE] [{:03d}/{:03d}] loss={:.4f}".format(idx, len(env), current_loss)
            + ", post-loss={:.4f}".format(shown_post_loss)
        )
    meta_model.clear_fixed()
    meta_model.clear_learnt()
    return w_containers, loss_meter
def meta_train_procedure(base_model, meta_model, criterion, xenv, args, logger):
    """Pre-train the meta-model, with early stopping and best-checkpoint reuse.

    Args:
        base_model: model providing `forward_with_container`.
        meta_model: the meta-model being trained; also owns the checkpoint
            helpers (`set_best_dir`, `has_best`, `save_best`, `load_best`, ...).
        criterion: loss callable applied to (predictions, targets).
        xenv: callable mapping a timestamp to `(_, (inputs, targets))`.
        args: namespace providing `lr`, `weight_decay`, `rand_seed`, `epochs`,
            `meta_batch`, `pretrain_early_stop_thresh` and `device`.
        logger: object with `log` and `path` methods.

    Returns:
        None. Returns early when a final checkpoint for this seed already
        exists and could be loaded.
    """
    base_model.train()
    meta_model.train()
    optimizer = torch.optim.Adam(
        meta_model.get_parameters(True, True, True),
        lr=args.lr,
        weight_decay=args.weight_decay,
        amsgrad=True,
    )
    logger.log("Pre-train the meta-model")
    logger.log("Using the optimizer: {:}".format(optimizer))

    meta_model.set_best_dir(logger.path(None) / "ckps-pretrain-v2")
    final_best_name = "final-pretrain-{:}.pth".format(args.rand_seed)
    # Reuse a finished pre-training run for this seed when one is available.
    if meta_model.has_best(final_best_name):
        meta_model.load_best(final_best_name)
        logger.log("Directly load the best model from {:}".format(final_best_name))
        return

    total_indexes = list(range(meta_model.meta_length))
    meta_model.set_best_name("pretrain-{:}.pth".format(args.rand_seed))
    last_success_epoch, early_stop_thresh = 0, args.pretrain_early_stop_thresh
    per_epoch_time, start_time = AverageMeter(), time.time()
    device = args.device
    for iepoch in range(args.epochs):
        left_time = "Time Left: {:}".format(
            convert_secs2time(per_epoch_time.avg * (args.epochs - iepoch), True)
        )
        optimizer.zero_grad()

        # With the third argument set to True the call yields a single value,
        # used below as the embeddings generated for every meta timestamp.
        generated_time_embeds = meta_model(meta_model.meta_timestamps, None, True)

        # Sample (with replacement) the timestamps scored in this step.
        batch_indexes = random.choices(total_indexes, k=args.meta_batch)

        raw_time_steps = meta_model.meta_timestamps[batch_indexes]

        # Keep the generated embeddings close to the stored meta embeddings.
        regularization_loss = F.l1_loss(
            generated_time_embeds, meta_model.super_meta_embed, reduction="mean"
        )
        # future loss
        total_future_losses, total_present_losses = [], []
        # Containers built from the generated embeddings ("future") and from
        # the stored meta embeddings ("present") of the sampled timestamps.
        _, future_containers, _ = meta_model(
            None, generated_time_embeds[batch_indexes], False
        )
        _, present_containers, _ = meta_model(
            None, meta_model.super_meta_embed[batch_indexes], False
        )
        for ibatch, time_step in enumerate(raw_time_steps.cpu().tolist()):
            _, (inputs, targets) = xenv(time_step)
            inputs, targets = inputs.to(device), targets.to(device)

            predictions = base_model.forward_with_container(
                inputs, future_containers[ibatch]
            )
            total_future_losses.append(criterion(predictions, targets))

            predictions = base_model.forward_with_container(
                inputs, present_containers[ibatch]
            )
            total_present_losses.append(criterion(predictions, targets))

        # The spread of the future losses is only logged, never optimised.
        with torch.no_grad():
            meta_std = torch.stack(total_future_losses).std().item()
        loss_future = torch.stack(total_future_losses).mean()
        loss_present = torch.stack(total_present_losses).mean()
        total_loss = loss_future + loss_present + regularization_loss
        total_loss.backward()
        optimizer.step()
        # success
        # The score is the negated loss, i.e. higher is better.
        success, best_score = meta_model.save_best(-total_loss.item())
        logger.log(
            "{:} [META {:04d}/{:}] loss : {:.4f} +- {:.4f} = {:.4f} + {:.4f} + {:.4f}".format(
                time_string(),
                iepoch,
                args.epochs,
                total_loss.item(),
                meta_std,
                loss_future.item(),
                loss_present.item(),
                regularization_loss.item(),
            )
            + ", batch={:}".format(len(total_future_losses))
            + ", success={:}, best={:.4f}".format(success, -best_score)
            + ", LS={:}/{:}".format(iepoch - last_success_epoch, early_stop_thresh)
            + ", {:}".format(left_time)
        )
        if success:
            last_success_epoch = iepoch
        # Stop once no improvement was seen for `early_stop_thresh` epochs.
        if iepoch - last_success_epoch >= early_stop_thresh:
            logger.log("Early stop the pre-training at {:}".format(iepoch))
            break
        per_epoch_time.update(time.time() - start_time)
        start_time = time.time()
    meta_model.load_best()
    # save to the final model
    meta_model.set_best_name(final_best_name)
    # NOTE(review): `best_score` is only bound inside the loop, so this line
    # raises NameError when `args.epochs` is 0.
    # The score is nudged upwards by 1e-6, presumably so that `save_best`
    # accepts it as an improvement -- TODO confirm against `save_best`.
    success, _ = meta_model.save_best(best_score + 1e-6)
    assert success
    logger.log("Save the best model into {:}".format(final_best_name))
def main(args):
    """Pre-train the meta-model, evaluate it online, and save the containers.

    Builds the synthetic environments and the base model, pre-trains a
    `MetaModelV1` on the trainval environment, runs `online_evaluate` over the
    full environment, and writes the generated weight containers to
    `final-ckp.pth` under the logger's directory.
    """
    logger, model_kwargs = lfna_setup(args)
    env_version = args.env_version
    train_env = get_synthetic_env(mode="train", version=env_version)
    valid_env = get_synthetic_env(mode="valid", version=env_version)
    trainval_env = get_synthetic_env(mode="trainval", version=env_version)
    all_env = get_synthetic_env(mode=None, version=env_version)
    logger.log("The training enviornment: {:}".format(train_env))
    logger.log("The validation enviornment: {:}".format(valid_env))
    logger.log("The trainval enviornment: {:}".format(trainval_env))
    logger.log("The total enviornment: {:}".format(all_env))

    base_model = get_model(**model_kwargs).to(args.device)
    criterion = torch.nn.MSELoss()
    shape_container = base_model.get_w_container().to_shape_container()

    # pre-train the hypernetwork
    timestamps = trainval_env.get_timestamp(None)
    meta_model = MetaModelV1(
        shape_container,
        args.layer_dim,
        args.time_dim,
        timestamps,
        seq_length=args.seq_length,
        interval=trainval_env.time_interval,
    ).to(args.device)

    logger.log("The base-model has {:} weights.".format(base_model.numel()))
    logger.log("The meta-model has {:} weights.".format(meta_model.numel()))
    logger.log("The base-model is\n{:}".format(base_model))
    logger.log("The meta-model is\n{:}".format(meta_model))

    meta_train_procedure(base_model, meta_model, criterion, trainval_env, args, logger)

    # try to evaluate once
    # online_evaluate(train_env, meta_model, base_model, criterion, args, logger)
    # online_evaluate(valid_env, meta_model, base_model, criterion, args, logger)
    w_containers, loss_meter = online_evaluate(
        all_env, meta_model, base_model, criterion, args, logger, True
    )
    logger.log("In this enviornment, the loss-meter is {:}".format(loss_meter))

    checkpoint_path = logger.path(None) / "final-ckp.pth"
    save_checkpoint({"w_containers": w_containers}, checkpoint_path, logger)

    logger.log("-" * 200 + "\n")
    logger.close()
if __name__ == "__main__":
    # Command-line entry point: parse the options, derive the run directory
    # from the main hyper-parameters, then start the experiment.
    parser = argparse.ArgumentParser(".")
    # --- output location and environment ---
    parser.add_argument(
        "--save_dir",
        type=str,
        default="./outputs/lfna-synthetic/lfna-battle",
        help="The checkpoint directory.",
    )
    parser.add_argument(
        "--env_version",
        type=str,
        required=True,
        help="The synthetic enviornment version.",
    )
    # --- model sizes ---
    parser.add_argument(
        "--hidden_dim", type=int, default=16, help="The hidden dimension."
    )
    parser.add_argument(
        "--layer_dim", type=int, default=16, help="The layer chunk dimension."
    )
    parser.add_argument(
        "--time_dim", type=int, default=16, help="The timestamp dimension."
    )
    # --- optimisation ---
    parser.add_argument(
        "--lr",
        type=float,
        default=0.002,
        help="The initial learning rate for the optimizer (default is Adam)",
    )
    parser.add_argument(
        "--weight_decay",
        type=float,
        default=0.00001,
        help="The weight decay for the optimizer (default is Adam)",
    )
    parser.add_argument(
        "--meta_batch", type=int, default=64, help="The batch size for the meta-model"
    )
    parser.add_argument(
        "--sampler_enlarge",
        type=int,
        default=5,
        help="Enlarge the #iterations for an epoch",
    )
    parser.add_argument("--epochs", type=int, default=10000, help="The total #epochs.")
    parser.add_argument(
        "--refine_lr",
        type=float,
        default=0.001,
        help="The learning rate for the optimizer, during refine",
    )
    parser.add_argument(
        "--refine_epochs", type=int, default=150, help="The final refine #epochs."
    )
    parser.add_argument(
        "--early_stop_thresh", type=int, default=20, help="The #epochs for early stop."
    )
    parser.add_argument(
        "--pretrain_early_stop_thresh",
        type=int,
        default=300,
        help="The #epochs for early stop.",
    )
    parser.add_argument(
        "--seq_length", type=int, default=10, help="The sequence length."
    )
    # --- runtime ---
    parser.add_argument(
        "--workers", type=int, default=4, help="The number of workers in parallel."
    )
    parser.add_argument("--device", type=str, default="cpu", help="")
    # Random Seed
    parser.add_argument("--rand_seed", type=int, default=-1, help="manual seed")
    args = parser.parse_args()

    # A missing or negative seed means "draw one at random".
    if args.rand_seed is None or args.rand_seed < 0:
        args.rand_seed = random.randint(1, 100000)
    assert args.save_dir is not None, "The save dir argument can not be None"

    # Encode the main hyper-parameters into the output directory name.
    save_dir_template = "{:}-bs{:}-d{:}_{:}_{:}-s{:}-lr{:}-wd{:}-e{:}-env{:}"
    args.save_dir = save_dir_template.format(
        args.save_dir,
        args.meta_batch,
        args.hidden_dim,
        args.layer_dim,
        args.time_dim,
        args.seq_length,
        args.lr,
        args.weight_decay,
        args.epochs,
        args.env_version,
    )
    main(args)

View File

@@ -0,0 +1,373 @@
#####################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2021.02 #
############################################################################
# python exps/GMOA/vis-synthetic.py --env_version v1 #
# python exps/GMOA/vis-synthetic.py --env_version v2 #
############################################################################
import os, sys, copy, random
import torch
import numpy as np
import argparse
from collections import OrderedDict, defaultdict
from pathlib import Path
from tqdm import tqdm
from pprint import pprint
import matplotlib
from matplotlib import cm
matplotlib.use("agg")
import matplotlib.pyplot as plt
import matplotlib.ticker as ticker
# Put the directory two levels above this file at the front of `sys.path`
# (if it is not already listed) before the `xautodl` imports below run.
lib_dir = (Path(__file__).parent / ".." / "..").resolve()
if str(lib_dir) not in sys.path:
    sys.path.insert(0, str(lib_dir))
from xautodl.models.xcore import get_model
from xautodl.datasets.synthetic_core import get_synthetic_env
from xautodl.procedures.metric_utils import MSEMetric
def plot_scatter(cur_ax, xs, ys, color, alpha, linewidths, label=None):
    """Scatter `(xs, ys)` on `cur_ax` together with a separate legend entry.

    Two scatter calls are issued: a single point at (-100, -100) that carries
    `label` and `linewidths` (no alpha), and the data itself drawn with
    `alpha`, a fixed line width of 1.5 and no label.
    """
    # Legend proxy; presumably placed outside the visible axis range.
    legend_proxy = dict(color=color, linewidths=linewidths, label=label)
    cur_ax.scatter([-100], [-100], **legend_proxy)
    # The actual data points.
    data_style = dict(color=color, alpha=alpha, linewidths=1.5, label=None)
    cur_ax.scatter(xs, ys, **data_style)
def draw_multi_fig(save_dir, timestamp, scatter_list, wh, fig_title=None):
    """Draw one vertically stacked scatter subplot per entry and save the figure.

    Args:
        save_dir: `pathlib.Path` of the output directory.
        timestamp: integer used (zero-padded to 4 digits) as the file stem.
        scatter_list: list of dicts with the keys "xaxis", "yaxis", "color",
            "alpha", "linewidths", "label", "xlim" and "ylim".
        wh: `(width, height)` of the figure in pixels at the internal dpi of 40.
        fig_title: optional figure title.

    Writes `<save_dir>/<timestamp>.pdf` and `.png`, then closes all figures.
    """
    save_path = save_dir / "{:04d}".format(timestamp)
    # print('Plot the figure at timestamp-{:} into {:}'.format(timestamp, save_path))
    dpi, width, height = 40, wh[0], wh[1]
    figsize = width / float(dpi), height / float(dpi)
    LabelSize, LegendFontsize, font_gap = 80, 80, 5
    tick_size = LabelSize - font_gap

    fig = plt.figure(figsize=figsize)
    if fig_title is not None:
        fig.suptitle(
            fig_title, fontsize=LegendFontsize, fontweight="bold", x=0.5, y=0.92
        )

    num_rows = len(scatter_list)
    for idx, scatter_dict in enumerate(scatter_list):
        cur_ax = fig.add_subplot(num_rows, 1, idx + 1)
        plot_scatter(
            cur_ax,
            scatter_dict["xaxis"],
            scatter_dict["yaxis"],
            scatter_dict["color"],
            scatter_dict["alpha"],
            scatter_dict["linewidths"],
            scatter_dict["label"],
        )
        cur_ax.set_xlabel("X", fontsize=LabelSize)
        cur_ax.set_ylabel("Y", rotation=0, fontsize=LabelSize)
        cur_ax.set_xlim(scatter_dict["xlim"][0], scatter_dict["xlim"][1])
        cur_ax.set_ylim(scatter_dict["ylim"][0], scatter_dict["ylim"][1])
        # `Tick.label` was removed in Matplotlib 3.8; `tick_params` styles the
        # major tick labels without touching individual Tick objects.
        cur_ax.tick_params(
            axis="x", which="major", labelsize=tick_size, labelrotation=10
        )
        cur_ax.tick_params(axis="y", which="major", labelsize=tick_size)
        cur_ax.legend(loc=1, fontsize=LegendFontsize)
    fig.savefig(str(save_path) + ".pdf", dpi=dpi, bbox_inches="tight", format="pdf")
    fig.savefig(str(save_path) + ".png", dpi=dpi, bbox_inches="tight", format="png")
    plt.close("all")
def find_min(cur, others):
    """Return the running minimum after seeing `others`, as a Python float.

    Args:
        cur: the minimum so far, or None when nothing has been seen yet.
        others: a scalar or an array-like; arrays are reduced with `np.min`.
    """
    smallest = float(np.min(others))
    if cur is None:
        return smallest
    return float(min(cur, smallest))
def find_max(cur, others):
    """Return the running maximum after seeing `others`, as a Python float.

    Args:
        cur: the maximum so far, or None when nothing has been seen yet.
        others: a scalar or an array-like; arrays are reduced with `np.max`.
    """
    largest = float(np.max(others))
    if cur is None:
        return largest
    return float(max(cur, largest))
def compare_cl(save_dir):
    """Render per-timestamp figures contrasting LFNA data with a CL curve.

    For every timestamp a two-panel figure is drawn through `draw_multi_fig`
    (the environment's samples on top, a progressively longer slice of the
    continual-learning function below), and the PNGs are then turned into
    `compare-cl.mp4` / `compare-cl.webm` by invoking ffmpeg via `os.system`.

    Args:
        save_dir: output directory; created if missing.
    """
    save_dir = Path(str(save_dir))
    save_dir.mkdir(parents=True, exist_ok=True)
    # NOTE(review): `create_example_v1` is neither defined nor imported in this
    # file, so this call raises NameError as written.
    dynamic_env, cl_function = create_example_v1(
        # timestamp_config=dict(num=200, min_timestamp=-1, max_timestamp=1.0),
        timestamp_config=dict(num=200),
        num_per_task=1000,
    )

    # NOTE(review): `models` is never used below.
    models = dict()

    cl_function.set_timestamp(0)
    cl_xaxis_min = None
    cl_xaxis_max = None

    all_data = OrderedDict()

    # First pass: cache the first column of every dataset and track the
    # overall range of (mean - std, mean + std) of the x values.
    for idx, (timestamp, dataset) in enumerate(tqdm(dynamic_env, ncols=50)):
        xaxis_all = dataset[0][:, 0].numpy()
        yaxis_all = dataset[1][:, 0].numpy()
        current_data = dict()
        current_data["lfna_xaxis_all"] = xaxis_all
        current_data["lfna_yaxis_all"] = yaxis_all

        # compute cl-min
        cl_xaxis_min = find_min(cl_xaxis_min, xaxis_all.mean() - xaxis_all.std())
        cl_xaxis_max = find_max(cl_xaxis_max, xaxis_all.mean() + xaxis_all.std())
        all_data[timestamp] = current_data

    # NOTE(review): these two values are computed but not used afterwards.
    global_cl_xaxis_all = np.arange(cl_xaxis_min, cl_xaxis_max, step=0.1)
    global_cl_yaxis_all = cl_function.noise_call(global_cl_xaxis_all)

    # Second pass: one figure per timestamp.
    for idx, (timestamp, xdata) in enumerate(tqdm(all_data.items(), ncols=50)):
        scatter_list = []
        scatter_list.append(
            {
                "xaxis": xdata["lfna_xaxis_all"],
                "yaxis": xdata["lfna_yaxis_all"],
                "color": "k",
                "linewidths": 15,
                "alpha": 0.99,
                "xlim": (-6, 6),
                "ylim": (-40, 40),
                "label": "LFNA",
            }
        )

        # The CL panel reveals a fraction (idx + 1) / len(all_data) of the
        # x range, so the curve grows from left to right across figures.
        cur_cl_xaxis_min = cl_xaxis_min
        cur_cl_xaxis_max = cl_xaxis_min + (cl_xaxis_max - cl_xaxis_min) * (
            idx + 1
        ) / len(all_data)
        cl_xaxis_all = np.arange(cur_cl_xaxis_min, cur_cl_xaxis_max, step=0.01)
        cl_yaxis_all = cl_function.noise_call(cl_xaxis_all, std=0.2)

        scatter_list.append(
            {
                "xaxis": cl_xaxis_all,
                "yaxis": cl_yaxis_all,
                "color": "k",
                "linewidths": 15,
                "xlim": (round(cl_xaxis_min, 1), round(cl_xaxis_max, 1)),
                "ylim": (-20, 6),
                "alpha": 0.99,
                "label": "Continual Learning",
            }
        )

        draw_multi_fig(
            save_dir,
            idx,
            scatter_list,
            wh=(2200, 1800),
            fig_title="Timestamp={:03d}".format(idx),
        )
    print("Save all figures into {:}".format(save_dir))
    save_dir = save_dir.resolve()
    # NOTE(review): `-vf` is given twice; ffmpeg presumably honours only the
    # last occurrence, which would drop `fps=1` -- confirm, and chain the
    # filters in a single `-vf` if both are wanted.
    base_cmd = (
        "ffmpeg -y -i {xdir}/%04d.png -vf fps=1 -vf scale=2200:1800 -vb 5000k".format(
            xdir=save_dir
        )
    )
    video_cmd = "{:} -pix_fmt yuv420p {xdir}/compare-cl.mp4".format(
        base_cmd, xdir=save_dir
    )
    print(video_cmd + "\n")
    os.system(video_cmd)
    os.system(
        "{:} -pix_fmt yuv420p {xdir}/compare-cl.webm".format(base_cmd, xdir=save_dir)
    )
def visualize_env(save_dir, version):
    """Render every timestamp of the synthetic environment and build videos.

    One scatter figure per timestamp is written to `<save_dir>/pdf` and
    `<save_dir>/png`; all figures share axis limits taken from the full data
    range. The PNGs are then converted into `env-<version>.mp4` / `.webm` by
    invoking ffmpeg via `os.system`.

    Args:
        save_dir: output directory; its `pdf` and `png` sub-directories are
            created if missing.
        version: environment version passed to `get_synthetic_env`.
    """
    save_dir = Path(str(save_dir))
    for substr in ("pdf", "png"):
        sub_save_dir = save_dir / substr
        sub_save_dir.mkdir(parents=True, exist_ok=True)

    dynamic_env = get_synthetic_env(version=version)
    # First pass: gather all samples to obtain the global axis limits.
    allxs, allys = [], []
    for idx, (timestamp, (allx, ally)) in enumerate(tqdm(dynamic_env, ncols=50)):
        allxs.append(allx)
        allys.append(ally)
    allxs, allys = torch.cat(allxs).view(-1), torch.cat(allys).view(-1)

    # Figure settings and limits are identical for every timestamp.
    dpi, width, height = 30, 1800, 1400
    figsize = width / float(dpi), height / float(dpi)
    LabelSize, LegendFontsize, font_gap = 80, 80, 5
    tick_size = LabelSize - font_gap
    xlim = (round(allxs.min().item(), 1), round(allxs.max().item(), 1))
    ylim = (round(allys.min().item(), 1), round(allys.max().item(), 1))

    # Second pass: one figure per timestamp.
    for idx, (timestamp, (allx, ally)) in enumerate(tqdm(dynamic_env, ncols=50)):
        fig = plt.figure(figsize=figsize)

        cur_ax = fig.add_subplot(1, 1, 1)
        allx, ally = allx[:, 0].numpy(), ally[:, 0].numpy()
        plot_scatter(cur_ax, allx, ally, "k", 0.99, 15, "timestamp={:05d}".format(idx))
        cur_ax.set_xlabel("X", fontsize=LabelSize)
        cur_ax.set_ylabel("Y", rotation=0, fontsize=LabelSize)
        # `Tick.label` was removed in Matplotlib 3.8; `tick_params` styles the
        # major tick labels without touching individual Tick objects.
        cur_ax.tick_params(
            axis="x", which="major", labelsize=tick_size, labelrotation=10
        )
        cur_ax.tick_params(axis="y", which="major", labelsize=tick_size)
        cur_ax.set_xlim(xlim[0], xlim[1])
        cur_ax.set_ylim(ylim[0], ylim[1])
        cur_ax.legend(loc=1, fontsize=LegendFontsize)

        pdf_save_path = save_dir / "pdf" / "v{:}-{:05d}.pdf".format(version, idx)
        fig.savefig(str(pdf_save_path), dpi=dpi, bbox_inches="tight", format="pdf")
        png_save_path = save_dir / "png" / "v{:}-{:05d}.png".format(version, idx)
        fig.savefig(str(png_save_path), dpi=dpi, bbox_inches="tight", format="png")
        plt.close("all")
    save_dir = save_dir.resolve()
    base_cmd = "ffmpeg -y -i {xdir}/v{version}-%05d.png -vf scale=1800:1400 -pix_fmt yuv420p -vb 5000k".format(
        xdir=save_dir / "png", version=version
    )
    print(base_cmd)
    os.system("{:} {xdir}/env-{ver}.mp4".format(base_cmd, xdir=save_dir, ver=version))
    os.system("{:} {xdir}/env-{ver}.webm".format(base_cmd, xdir=save_dir, ver=version))
def compare_algs(save_dir, version, alg_dir="./outputs/lfna-synthetic"):
    """Plot stored per-timestamp predictions against the data and build videos.

    For each timestamp after the first `skip + 1`, a two-panel figure is
    written to `<save_dir>/pdf` and `<save_dir>/png`: the raw samples with
    every algorithm's predictions on top, and the MSE trajectory so far below.
    The PNGs are then converted into `com-alg-<version>.mp4` / `.webm` by
    invoking ffmpeg via `os.system`.

    Args:
        save_dir: output directory; its `pdf` and `png` sub-directories are
            created if missing.
        version: environment version; only "v1" is configured.
        alg_dir: directory holding one sub-directory per algorithm, each with
            a `final-ckp.pth` that stores the key "w_containers".

    Raises:
        ValueError: if `version` is not "v1".
    """
    save_dir = Path(str(save_dir))
    for substr in ("pdf", "png"):
        sub_save_dir = save_dir / substr
        sub_save_dir.mkdir(parents=True, exist_ok=True)

    dpi, width, height = 30, 3200, 2000
    figsize = width / float(dpi), height / float(dpi)
    LabelSize, LegendFontsize, font_gap = 80, 80, 5
    tick_size = LabelSize - font_gap

    dynamic_env = get_synthetic_env(mode=None, version=version)
    # First pass: gather all samples to obtain the global axis limits.
    allxs, allys = [], []
    for idx, (timestamp, (allx, ally)) in enumerate(tqdm(dynamic_env, ncols=50)):
        allxs.append(allx)
        allys.append(ally)
    allxs, allys = torch.cat(allxs).view(-1), torch.cat(allys).view(-1)
    xlim = (round(allxs.min().item(), 1), round(allxs.max().item(), 1))
    ylim = (round(allys.min().item(), 1), round(allys.max().item(), 1))

    # Display name -> result directory (relative to `alg_dir`).
    alg_name2dir = OrderedDict()
    # alg_name2dir["Supervised Learning (History Data)"] = "use-all-past-data"
    # alg_name2dir["MAML"] = "use-maml-s1"
    # alg_name2dir["LFNA (fix init)"] = "lfna-fix-init"
    if version == "v1":
        # alg_name2dir["Optimal"] = "use-same-timestamp"
        alg_name2dir[
            "GMOA"
        ] = "lfna-battle-bs128-d16_16_16-s16-lr0.002-wd1e-05-e10000-envv1"
    else:
        raise ValueError("Invalid version: {:}".format(version))

    # Load the weight containers saved by each algorithm.
    # NOTE(review): `torch.load` unpickles the file; only use trusted checkpoints.
    alg_name2all_containers = OrderedDict()
    for alg, xdir in alg_name2dir.items():
        ckp_path = Path(alg_dir) / str(xdir) / "final-ckp.pth"
        xdata = torch.load(ckp_path, map_location="cpu")
        alg_name2all_containers[alg] = xdata["w_containers"]
    # load the basic model
    model = get_model(
        dict(model_type="norm_mlp"),
        input_dim=1,
        output_dim=1,
        hidden_dims=[16] * 2,
        act_cls="gelu",
        norm_cls="layer_norm_1d",
    )

    # Per-algorithm MSE trajectory: x = timestamp index, y = MSE.
    alg2xs, alg2ys = defaultdict(list), defaultdict(list)
    colors = ["r", "g", "b", "m", "y"]

    linewidths, skip = 10, 5
    for idx, (timestamp, (ori_allx, ori_ally)) in enumerate(
        tqdm(dynamic_env, ncols=50)
    ):
        if idx <= skip:
            continue
        fig = plt.figure(figsize=figsize)
        cur_ax = fig.add_subplot(2, 1, 1)

        # the data
        allx, ally = ori_allx[:, 0].numpy(), ori_ally[:, 0].numpy()
        plot_scatter(cur_ax, allx, ally, "k", 0.99, linewidths, "Raw Data")

        for idx_alg, alg in enumerate(alg_name2dir):
            with torch.no_grad():
                predicts = model.forward_with_container(
                    ori_allx, alg_name2all_containers[alg][idx]
                )
                predicts = predicts.cpu()
                # keep data
                metric = MSEMetric()
                metric(predicts, ori_ally)
                predicts = predicts.view(-1).numpy()
                alg2xs[alg].append(idx)
                alg2ys[alg].append(metric.get_info()["mse"])
            plot_scatter(cur_ax, allx, predicts, colors[idx_alg], 0.99, linewidths, alg)

        cur_ax.set_xlabel("X", fontsize=LabelSize)
        cur_ax.set_ylabel("Y", rotation=0, fontsize=LabelSize)
        # `Tick.label` was removed in Matplotlib 3.8; `tick_params` styles the
        # major tick labels without touching individual Tick objects.
        cur_ax.tick_params(
            axis="x", which="major", labelsize=tick_size, labelrotation=10
        )
        cur_ax.tick_params(axis="y", which="major", labelsize=tick_size)
        cur_ax.set_xlim(xlim[0], xlim[1])
        cur_ax.set_ylim(ylim[0], ylim[1])
        cur_ax.legend(loc=1, fontsize=LegendFontsize)

        # the trajectory data
        cur_ax = fig.add_subplot(2, 1, 2)
        for idx_alg, alg in enumerate(alg_name2dir):
            cur_ax.plot(
                alg2xs[alg],
                alg2ys[alg],
                color=colors[idx_alg],
                linestyle="-",
                linewidth=5,
                label=alg,
            )

        cur_ax.set_xlabel("Timestamp", fontsize=LabelSize)
        cur_ax.set_ylabel("MSE", fontsize=LabelSize)
        cur_ax.tick_params(
            axis="x", which="major", labelsize=tick_size, labelrotation=10
        )
        cur_ax.tick_params(axis="y", which="major", labelsize=tick_size)
        cur_ax.set_xlim(1, len(dynamic_env))
        cur_ax.set_ylim(0, 10)
        # A single legend call is enough; a repeated call only replaces it.
        cur_ax.legend(loc=1, fontsize=LegendFontsize)

        # Output frames are numbered from 1 (`idx - skip`).
        pdf_save_path = save_dir / "pdf" / "v{:}-{:05d}.pdf".format(version, idx - skip)
        fig.savefig(str(pdf_save_path), dpi=dpi, bbox_inches="tight", format="pdf")
        png_save_path = save_dir / "png" / "v{:}-{:05d}.png".format(version, idx - skip)
        fig.savefig(str(png_save_path), dpi=dpi, bbox_inches="tight", format="png")
        plt.close("all")
    save_dir = save_dir.resolve()
    base_cmd = "ffmpeg -y -i {xdir}/v{ver}-%05d.png -vf scale={w}:{h} -pix_fmt yuv420p -vb 5000k".format(
        xdir=save_dir / "png", w=width, h=height, ver=version
    )
    os.system(
        "{:} {xdir}/com-alg-{ver}.mp4".format(base_cmd, xdir=save_dir, ver=version)
    )
    os.system(
        "{:} {xdir}/com-alg-{ver}.webm".format(base_cmd, xdir=save_dir, ver=version)
    )
if __name__ == "__main__":
    # Command-line entry point: only the algorithm comparison is enabled; the
    # other visualisations are kept below as commented-out alternatives.
    parser = argparse.ArgumentParser("Visualize synthetic data.")
    parser.add_argument(
        "--save_dir",
        type=str,
        default="./outputs/vis-synthetic",
        help="The save directory.",
    )
    parser.add_argument(
        "--env_version",
        type=str,
        required=True,
        help="The synthetic enviornment version.",
    )
    args = parser.parse_args()

    # visualize_env(os.path.join(args.save_dir, "vis-env"), "v1")
    # visualize_env(os.path.join(args.save_dir, "vis-env"), "v2")
    compare_dir = os.path.join(args.save_dir, "compare-alg")
    compare_algs(compare_dir, args.env_version)
    # compare_cl(os.path.join(args.save_dir, "compare-cl"))