Complete LFNA 1.0

Author: D-X-Y
Date:   2021-05-14 00:36:37 +08:00
parent c2fa181bc5
commit b81ef2dd74
4 changed files with 311 additions and 19 deletions

exps/LFNA/lfna.py

@@ -1,6 +1,7 @@
 #####################################################
 # Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2021.04 #
 #####################################################
 # python exps/LFNA/lfna.py --env_version v1 --workers 0
+# python exps/LFNA/lfna.py --env_version v1 --device cuda
 #####################################################
 import sys, time, copy, torch, random, argparse
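The second usage line exercises the new `--device` path (the code below moves tensors with `.to(args.device)`). The argument definition itself is not part of the hunks shown here, so the following is only a hedged sketch of how such a flag could be declared next to the other `parser.add_argument` calls; the default and help text are assumptions, not the commit's actual code.

# Hypothetical sketch -- the real --device definition is not in this diff.
parser.add_argument(
    "--device",
    type=str,
    default="cpu",
    help="The device (cpu or cuda) used for meta-testing.",
)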
@@ -156,19 +157,61 @@ def main(args):
         per_epoch_time.update(time.time() - start_time)
         start_time = time.time()
     # meta-training
+    meta_model.load_best()
+    eval_env = env_info["dynamic_env"]
     w_container_per_epoch = dict()
-    for idx in range(0, total_bar):
+    for idx in range(args.seq_length, env_info["total"]):
+        # build-timestamp
         future_time = env_info["{:}-timestamp".format(idx)]
-        future_x = env_info["{:}-x".format(idx)]
-        future_y = env_info["{:}-y".format(idx)]
-        future_container = hypernet(task_embeds[idx])
-        w_container_per_epoch[idx] = future_container.no_grad_clone()
-        with torch.no_grad():
-            future_y_hat = model.forward_with_container(
+        time_seqs = []
+        for iseq in range(args.seq_length):
+            time_seqs.append(future_time - iseq * eval_env.timestamp_interval)
+        time_seqs.reverse()
+        meta_model.eval()
+        base_model.eval()
+        time_seqs = torch.Tensor(time_seqs).view(1, -1).to(args.device)
+        [seq_containers] = meta_model(time_seqs)
+        future_container = seq_containers[-1]
+        w_container_per_epoch[idx] = future_container.no_grad_clone()
+        # evaluation
+        future_x = env_info["{:}-x".format(idx)]
+        future_y = env_info["{:}-y".format(idx)]
+        future_y_hat = base_model.forward_with_container(
             future_x, w_container_per_epoch[idx]
         )
         future_loss = criterion(future_y_hat, future_y)
-        logger.log("meta-test: [{:03d}] -> loss={:.4f}".format(idx, future_loss.item()))
+        logger.log(
+            "meta-test: [{:03d}] -> loss={:.4f}".format(idx, future_loss.item())
+        )
+        # creating the new meta-time-embedding
+        distance = meta_model.get_closest_meta_distance(future_time)
+        if distance < eval_env.timestamp_interval:
+            continue
+        #
+        new_param = meta_model.create_meta_embed()
+        optimizer = torch.optim.Adam(
+            [new_param], lr=args.init_lr, weight_decay=1e-5, amsgrad=True
+        )
+        meta_model.replace_append_learnt(torch.Tensor([future_time]), new_param)
+        meta_model.eval()
+        base_model.train()
+        for iepoch in range(args.epochs):
+            optimizer.zero_grad()
+            [seq_containers] = meta_model(time_seqs)
+            future_container = seq_containers[-1]
+            future_y_hat = base_model.forward_with_container(future_x, future_container)
+            future_loss = criterion(future_y_hat, future_y)
+            future_loss.backward()
+            optimizer.step()
+        logger.log(
+            "post-meta-test: [{:03d}] -> loss={:.4f}".format(idx, future_loss.item())
+        )
+        with torch.no_grad():
+            meta_model.replace_append_learnt(None, None)
+            meta_model.append_fixed(torch.Tensor([future_time]), new_param)
     save_checkpoint(
         {"w_container_per_epoch": w_container_per_epoch},
@@ -216,7 +259,7 @@ if __name__ == "__main__":
     parser.add_argument(
         "--init_lr",
         type=float,
-        default=0.01,
+        default=0.005,
         help="The initial learning rate for the optimizer (default is Adam)",
     )
     parser.add_argument(
@@ -235,7 +278,7 @@ if __name__ == "__main__":
     parser.add_argument(
         "--early_stop_thresh",
         type=int,
-        default=50,
+        default=25,
         help="The maximum epochs for early stop.",
     )
     parser.add_argument(
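Both tightened defaults (`--init_lr` 0.01 -> 0.005 and `--early_stop_thresh` 50 -> 25) can still be overridden on the command line, in the same style as the usage comments at the top of the file; the values below are only illustrative.

# python exps/LFNA/lfna.py --env_version v1 --device cuda --init_lr 0.01 --early_stop_thresh 50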
@@ -256,7 +299,12 @@ if __name__ == "__main__":
     if args.rand_seed is None or args.rand_seed < 0:
         args.rand_seed = random.randint(1, 100000)
     assert args.save_dir is not None, "The save dir argument can not be None"
-    args.save_dir = "{:}-{:}-d{:}_{:}_{:}".format(
-        args.save_dir, args.env_version, args.hidden_dim, args.layer_dim, args.time_dim
+    args.save_dir = "{:}-{:}-d{:}_{:}_{:}-e{:}".format(
+        args.save_dir,
+        args.env_version,
+        args.hidden_dim,
+        args.layer_dim,
+        args.time_dim,
+        args.epochs,
     )
     main(args)
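With the extra `-e{:}` suffix, runs that differ only in `--epochs` no longer collide in the same output directory. A quick check of the new naming, using made-up argument values (the directory prefix and dimensions here are not from the commit):

# made-up values, only to illustrate the directory name produced above
save_dir, env_version = "outputs/lfna", "v1"
hidden_dim, layer_dim, time_dim, epochs = 16, 16, 16, 100
print("{:}-{:}-d{:}_{:}_{:}-e{:}".format(
    save_dir, env_version, hidden_dim, layer_dim, time_dim, epochs
))  # -> outputs/lfna-v1-d16_16_16-e100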