Update LFNA

D-X-Y
2021-05-15 16:01:40 +08:00
parent b81ef2dd74
commit 72f240bf0a
12 changed files with 128 additions and 1050 deletions

exps/LFNA/basic-maml.py

@@ -1,7 +1,7 @@
 #####################################################
 # Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2021.04 #
 #####################################################
-# python exps/LFNA/basic-maml.py --env_version v1 --hidden_dim 16 --inner_step 5
+# python exps/LFNA/basic-maml.py --env_version v1 --inner_step 5
 # python exps/LFNA/basic-maml.py --env_version v2
 #####################################################
 import sys, time, copy, torch, random, argparse
@@ -20,7 +20,7 @@ from utils import split_str2indexes
 from procedures.advanced_main import basic_train_fn, basic_eval_fn
 from procedures.metric_utils import SaveMetric, MSEMetric, ComposeMetric
-from datasets.synthetic_core import get_synthetic_env
+from datasets.synthetic_core import get_synthetic_env, EnvSampler
 from models.xcore import get_model
 from xlayers import super_core
@@ -42,11 +42,10 @@ class MAML:
         self.meta_lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(
             self.meta_optimizer,
             milestones=[
-                int(epochs * 0.25),
-                int(epochs * 0.5),
-                int(epochs * 0.75),
+                int(epochs * 0.8),
+                int(epochs * 0.9),
             ],
-            gamma=0.3,
+            gamma=0.1,
         )
         self.inner_lr = inner_lr
         self.inner_step = inner_step
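
With the new milestones, the meta learning rate stays flat for the first 80% of training and then decays twice, more sharply than before (gamma 0.1 instead of 0.3). A minimal standalone sketch of the resulting schedule, using this commit's defaults (meta_lr 0.01, epochs 2000) and a dummy parameter:

```python
import torch

epochs = 2000
param = torch.nn.Parameter(torch.zeros(1))
optimizer = torch.optim.Adam([param], lr=0.01)  # meta_lr default from this commit
scheduler = torch.optim.lr_scheduler.MultiStepLR(
    optimizer,
    milestones=[int(epochs * 0.8), int(epochs * 0.9)],  # 1600 and 1800
    gamma=0.1,
)
for epoch in range(epochs):
    optimizer.step()
    scheduler.step()
# lr is 0.01 for epochs [0, 1600), 0.001 for [1600, 1800), 0.0001 afterwards
```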
@@ -85,33 +84,27 @@ class MAML:
         self.meta_optimizer.load_state_dict(state_dict["meta_optimizer"])
         self.meta_lr_scheduler.load_state_dict(state_dict["meta_lr_scheduler"])

-    def save_best(self, iepoch, score):
-        if self._best_info["score"] is None or self._best_info["score"] < score:
-            state_dict = dict(
-                criterion=self.criterion.state_dict(),
-                network=self.network.state_dict(),
-                meta_optimizer=self.meta_optimizer.state_dict(),
-                meta_lr_scheduler=self.meta_lr_scheduler.state_dict(),
-            )
-            self._best_info["state_dict"] = state_dict
-            self._best_info["score"] = score
-            self._best_info["iepoch"] = iepoch
-            is_best = True
-        else:
-            is_best = False
-        return self._best_info, is_best
+    def state_dict(self):
+        state_dict = dict()
+        state_dict["criterion"] = self.criterion.state_dict()
+        state_dict["network"] = self.network.state_dict()
+        state_dict["meta_optimizer"] = self.meta_optimizer.state_dict()
+        state_dict["meta_lr_scheduler"] = self.meta_lr_scheduler.state_dict()
+        return state_dict
+
+    def save_best(self, score):
+        success, best_score = self.network.save_best(score)
+        return success, best_score
+
+    def load_best(self):
+        self.network.load_best()


 def main(args):
     logger, env_info, model_kwargs = lfna_setup(args)
-    model = get_model(dict(model_type="simple_mlp"), **model_kwargs)
+    model = get_model(**model_kwargs)

-    total_time = env_info["total"]
-    for i in range(total_time):
-        for xkey in ("timestamp", "x", "y"):
-            nkey = "{:}-{:}".format(i, xkey)
-            assert nkey in env_info, "{:} no in {:}".format(nkey, list(env_info.keys()))
-    train_time_bar = total_time // 2
     dynamic_env = get_synthetic_env(mode="train", version=args.env_version)

     criterion = torch.nn.MSELoss()
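
The best-checkpoint bookkeeping also moves out of the MAML wrapper: score comparison and weight snapshots are now delegated to `self.network.save_best` / `load_best` from models.xcore, whose internals this diff does not show. As an assumption about what that interface does, a tracker compatible with the calls above could look roughly like this hypothetical sketch:

```python
import copy
import torch

class BestTracker:
    """Hypothetical stand-in for the network's save_best/load_best interface."""

    def __init__(self, network: torch.nn.Module):
        self._network = network
        self._best_score = None
        self._best_weights = None

    def save_best(self, score: float):
        # Snapshot the weights whenever the score improves; report success.
        if self._best_score is None or score > self._best_score:
            self._best_score = score
            self._best_weights = copy.deepcopy(self._network.state_dict())
            return True, self._best_score
        return False, self._best_score

    def load_best(self):
        assert self._best_weights is not None, "no successful save_best yet"
        self._network.load_state_dict(self._best_weights)
```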
@@ -120,83 +113,65 @@ def main(args):
     )

     # meta-training
+    last_success_epoch = 0
     per_epoch_time, start_time = AverageMeter(), time.time()
-    # for iepoch in range(args.epochs):
-    iepoch = 0
-    while iepoch < args.epochs:
+    for iepoch in range(args.epochs):
         need_time = "Time Left: {:}".format(
             convert_secs2time(per_epoch_time.avg * (args.epochs - iepoch), True)
         )
-        logger.log(
+        head_str = (
             "[{:}] [{:04d}/{:04d}] ".format(time_string(), iepoch, args.epochs)
             + need_time
         )

         maml.zero_grad()
-        batch_indexes, meta_losses = [], []
+        meta_losses = []
         for ibatch in range(args.meta_batch):
-            sampled_timestamp = random.randint(0, train_time_bar)
-            batch_indexes.append("{:5d}".format(sampled_timestamp))
-            past_dataset = TimeData(
-                sampled_timestamp,
-                env_info["{:}-x".format(sampled_timestamp)],
-                env_info["{:}-y".format(sampled_timestamp)],
+            future_timestamp = dynamic_env.random_timestamp()
+            _, (future_x, future_y) = dynamic_env(future_timestamp)
+            past_timestamp = (
+                future_timestamp - args.prev_time * dynamic_env.timestamp_interval
             )
-            future_dataset = TimeData(
-                sampled_timestamp + 1,
-                env_info["{:}-x".format(sampled_timestamp + 1)],
-                env_info["{:}-y".format(sampled_timestamp + 1)],
-            )
-            future_container = maml.adapt(past_dataset)
-            future_y_hat = maml.predict(future_dataset.x, future_container)
-            future_loss = maml.criterion(future_y_hat, future_dataset.y)
+            _, (past_x, past_y) = dynamic_env(past_timestamp)
+            future_container = maml.adapt(TimeData(past_timestamp, past_x, past_y))
+            future_y_hat = maml.predict(future_x, future_container)
+            future_loss = maml.criterion(future_y_hat, future_y)
             meta_losses.append(future_loss)
         meta_loss = torch.stack(meta_losses).mean()
         meta_loss.backward()
         maml.step()

-        logger.log(
-            "meta-loss: {:.4f} batch: {:}".format(
-                meta_loss.item(), ",".join(batch_indexes)
-            )
-        )
-        best_info, is_best = maml.save_best(iepoch, -meta_loss.item())
-        if is_best:
-            save_checkpoint(best_info, logger.path("best"), logger)
-            logger.log("Save the best into {:}".format(logger.path("best")))
-        if iepoch >= 10 and (
-            torch.isnan(meta_loss).item() or meta_loss.item() >= args.fail_thresh
-        ):
-            xdata = torch.load(logger.path("best"))
-            maml.load_state_dict(xdata["state_dict"])
-            iepoch = xdata["iepoch"]
-            logger.log(
-                "The training failed, re-use the previous best epoch [{:}]".format(
-                    iepoch
-                )
-            )
-        else:
-            iepoch = iepoch + 1
+        logger.log(head_str + " meta-loss: {:.4f}".format(meta_loss.item()))
+        success, best_score = maml.save_best(-meta_loss.item())
+        if success:
+            logger.log("Achieve the best with best_score = {:.3f}".format(best_score))
+            save_checkpoint(maml.state_dict(), logger.path("model"), logger)
+            last_success_epoch = iepoch
+        if iepoch - last_success_epoch >= args.early_stop_thresh:
+            logger.log("Early stop at {:}".format(iepoch))
+            break

         per_epoch_time.update(time.time() - start_time)
         start_time = time.time()

     # meta-test
+    maml.load_best()
+    eval_env = env_info["dynamic_env"]
+    assert eval_env.timestamp_interval == dynamic_env.timestamp_interval
     w_container_per_epoch = dict()
-    for idx in range(1, env_info["total"]):
-        past_dataset = TimeData(
-            idx - 1,
-            env_info["{:}-x".format(idx - 1)],
-            env_info["{:}-y".format(idx - 1)],
+    for idx in range(args.prev_time, len(eval_env)):
+        future_timestamp, (future_x, future_y) = eval_env[idx]
+        past_timestamp = (
+            future_timestamp.item() - args.prev_time * eval_env.timestamp_interval
         )
-        current_container = maml.adapt(past_dataset)
-        w_container_per_epoch[idx] = current_container.no_grad_clone()
+        _, (past_x, past_y) = eval_env(past_timestamp)
+        future_container = maml.adapt(TimeData(past_timestamp, past_x, past_y))
+        w_container_per_epoch[idx] = future_container.no_grad_clone()
         with torch.no_grad():
-            current_x = env_info["{:}-x".format(idx)]
-            current_y = env_info["{:}-y".format(idx)]
-            current_y_hat = maml.predict(current_x, w_container_per_epoch[idx])
-            current_loss = maml.criterion(current_y_hat, current_y)
-            logger.log(
-                "meta-test: [{:03d}] -> loss={:.4f}".format(idx, current_loss.item())
-            )
+            future_y_hat = maml.predict(future_x, w_container_per_epoch[idx])
+            future_loss = maml.criterion(future_y_hat, future_y)
+            logger.log("meta-test: [{:03d}] -> loss={:.4f}".format(idx, future_loss.item()))
     save_checkpoint(
         {"w_container_per_epoch": w_container_per_epoch},
         logger.path(None) / "final-ckp.pth",
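
The core change in this hunk: instead of indexing precomputed `env_info` arrays at adjacent timestamps, both meta-training and meta-testing now query the dynamic environment directly and adapt across a configurable gap of `args.prev_time` environment steps. Condensed into one function, a single meta-update reads as follows (a restatement of the loop above; `TimeData` and the MAML/environment helpers come from the repository and their imports are not shown in this diff):

```python
import torch

def meta_update(maml, dynamic_env, prev_time, meta_batch):
    # One outer MAML step under the new time-shifted sampling scheme.
    maml.zero_grad()
    meta_losses = []
    for _ in range(meta_batch):
        # Sample the timestamp we want to predict ...
        future_timestamp = dynamic_env.random_timestamp()
        _, (future_x, future_y) = dynamic_env(future_timestamp)
        # ... and adapt on data from `prev_time` environment steps earlier.
        past_timestamp = future_timestamp - prev_time * dynamic_env.timestamp_interval
        _, (past_x, past_y) = dynamic_env(past_timestamp)
        # TimeData is the repository's own container (import not shown here).
        container = maml.adapt(TimeData(past_timestamp, past_x, past_y))
        meta_losses.append(maml.criterion(maml.predict(future_x, container), future_y))
    torch.stack(meta_losses).mean().backward()
    maml.step()
```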
@@ -224,13 +199,13 @@ if __name__ == "__main__":
     parser.add_argument(
         "--hidden_dim",
         type=int,
-        required=True,
+        default=16,
         help="The hidden dimension.",
     )
     parser.add_argument(
         "--meta_lr",
         type=float,
-        default=0.05,
+        default=0.01,
         help="The learning rate for the MAML optimizer (default is Adam)",
     )
     parser.add_argument(
@@ -242,24 +217,36 @@ if __name__ == "__main__":
     parser.add_argument(
         "--inner_lr",
         type=float,
-        default=0.01,
+        default=0.005,
         help="The learning rate for the inner optimization",
     )
     parser.add_argument(
         "--inner_step", type=int, default=1, help="The inner loop steps for MAML."
     )
+    parser.add_argument(
+        "--prev_time",
+        type=int,
+        default=5,
+        help="The gap between prev_time and current_timestamp",
+    )
     parser.add_argument(
         "--meta_batch",
         type=int,
-        default=10,
+        default=64,
         help="The batch size for the meta-model",
     )
     parser.add_argument(
         "--epochs",
         type=int,
-        default=1000,
+        default=2000,
         help="The total number of epochs.",
     )
+    parser.add_argument(
+        "--early_stop_thresh",
+        type=int,
+        default=50,
+        help="The maximum epochs for early stop.",
+    )
     parser.add_argument(
         "--workers",
         type=int,
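
Two new flags back the logic above: `--prev_time` sets the adaptation gap in environment steps (default 5), and `--early_stop_thresh` stops training after that many epochs without a new best score (default 50). Since `--hidden_dim` now defaults to 16, the invocation from the updated header comment, `python exps/LFNA/basic-maml.py --env_version v1 --inner_step 5`, needs no further arguments; the gap and patience can be tuned per run, e.g. `--prev_time 3 --early_stop_thresh 100`.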
@@ -272,7 +259,13 @@ if __name__ == "__main__":
     if args.rand_seed is None or args.rand_seed < 0:
         args.rand_seed = random.randint(1, 100000)
     assert args.save_dir is not None, "The save dir argument can not be None"
-    args.save_dir = "{:}-s{:}-{:}-d{:}".format(
-        args.save_dir, args.inner_step, args.env_version, args.hidden_dim
+    args.save_dir = "{:}-s{:}-mlr{:}-d{:}-prev{:}-e{:}-env{:}".format(
+        args.save_dir,
+        args.inner_step,
+        args.meta_lr,
+        args.hidden_dim,
+        args.prev_time,
+        args.epochs,
+        args.env_version,
     )
     main(args)
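
The run directory now encodes every swept hyper-parameter. A quick check of the new format string with this commit's defaults (`outputs/lfna-maml` is only a placeholder; the real base comes from `--save_dir`):

```python
save_dir = "{:}-s{:}-mlr{:}-d{:}-prev{:}-e{:}-env{:}".format(
    "outputs/lfna-maml",  # placeholder --save_dir
    1,      # inner_step
    0.01,   # meta_lr
    16,     # hidden_dim
    5,      # prev_time
    2000,   # epochs
    "v1",   # env_version
)
print(save_dir)  # outputs/lfna-maml-s1-mlr0.01-d16-prev5-e2000-envv1
```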