Complete LFNA 1.0
@@ -1,6 +1,7 @@
 #####################################################
 # Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2021.04 #
 #####################################################
 # python exps/LFNA/lfna.py --env_version v1 --workers 0
+# python exps/LFNA/lfna.py --env_version v1 --device cuda
 #####################################################
 import sys, time, copy, torch, random, argparse
@@ -156,19 +157,61 @@ def main(args):
         per_epoch_time.update(time.time() - start_time)
         start_time = time.time()
 
     # meta-training
     meta_model.load_best()
     eval_env = env_info["dynamic_env"]
     w_container_per_epoch = dict()
-    for idx in range(0, total_bar):
+    for idx in range(args.seq_length, env_info["total"]):
         # build-timestamp
         future_time = env_info["{:}-timestamp".format(idx)]
-        future_x = env_info["{:}-x".format(idx)]
-        future_y = env_info["{:}-y".format(idx)]
-        future_container = hypernet(task_embeds[idx])
-        w_container_per_epoch[idx] = future_container.no_grad_clone()
+        time_seqs = []
+        for iseq in range(args.seq_length):
+            time_seqs.append(future_time - iseq * eval_env.timestamp_interval)
+        time_seqs.reverse()
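The rewritten loop above no longer indexes a per-task embedding table (`hypernet(task_embeds[idx])`); it builds a length-`seq_length` window of timestamps that ends at the query time and hands that window to the meta model. A standalone sketch of the window construction, with hypothetical values standing in for `eval_env.timestamp_interval` and `args.seq_length`:

future_time = 20.0         # hypothetical query timestamp
timestamp_interval = 1.0   # hypothetical spacing between environment steps
seq_length = 4

# Walk backwards from the query time, then flip to oldest-first order.
time_seqs = []
for iseq in range(seq_length):
    time_seqs.append(future_time - iseq * timestamp_interval)
time_seqs.reverse()
print(time_seqs)  # [17.0, 18.0, 19.0, 20.0], the window ends at future_time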
         with torch.no_grad():
-            future_y_hat = model.forward_with_container(
+            meta_model.eval()
+            base_model.eval()
+            time_seqs = torch.Tensor(time_seqs).view(1, -1).to(args.device)
+            [seq_containers] = meta_model(time_seqs)
+            future_container = seq_containers[-1]
+            w_container_per_epoch[idx] = future_container.no_grad_clone()
+            # evaluation
+            future_x = env_info["{:}-x".format(idx)]
+            future_y = env_info["{:}-y".format(idx)]
+            future_y_hat = base_model.forward_with_container(
                 future_x, w_container_per_epoch[idx]
             )
             future_loss = criterion(future_y_hat, future_y)
-        logger.log("meta-test: [{:03d}] -> loss={:.4f}".format(idx, future_loss.item()))
+        logger.log(
+            "meta-test: [{:03d}] -> loss={:.4f}".format(idx, future_loss.item())
+        )
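In this block the base model acts as a stateless function of externally supplied weights: the meta model maps the timestamp window to a "container" of parameter tensors, and `forward_with_container` runs the base network with exactly those tensors. A toy illustration of the pattern (not the repo's actual classes; the dict-based container and the layer sizes are invented):

import torch
import torch.nn.functional as F

def forward_with_container(x, container):
    # Toy two-layer MLP whose weights come entirely from `container`.
    h = F.relu(F.linear(x, container["w1"], container["b1"]))
    return F.linear(h, container["w2"], container["b2"])

# A hypothetical container, e.g. emitted by a hyper/meta network.
container = {
    "w1": torch.randn(16, 4), "b1": torch.zeros(16),
    "w2": torch.randn(1, 16), "b2": torch.zeros(1),
}
x = torch.randn(8, 4)
with torch.no_grad():
    y_hat = forward_with_container(x, container)
print(y_hat.shape)  # torch.Size([8, 1])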
+
+        # creating the new meta-time-embedding
+        distance = meta_model.get_closest_meta_distance(future_time)
+        if distance < eval_env.timestamp_interval:
+            continue
+        #
+        new_param = meta_model.create_meta_embed()
+        optimizer = torch.optim.Adam(
+            [new_param], lr=args.init_lr, weight_decay=1e-5, amsgrad=True
+        )
+        meta_model.replace_append_learnt(torch.Tensor([future_time]), new_param)
+        meta_model.eval()
+        base_model.train()
+        for iepoch in range(args.epochs):
+            optimizer.zero_grad()
+            [seq_containers] = meta_model(time_seqs)
+            future_container = seq_containers[-1]
+            future_y_hat = base_model.forward_with_container(future_x, future_container)
+            future_loss = criterion(future_y_hat, future_y)
+            future_loss.backward()
+            optimizer.step()
+        logger.log(
+            "post-meta-test: [{:03d}] -> loss={:.4f}".format(idx, future_loss.item())
+        )
+        with torch.no_grad():
+            meta_model.replace_append_learnt(None, None)
+            meta_model.append_fixed(torch.Tensor([future_time]), new_param)
 
     save_checkpoint(
         {"w_container_per_epoch": w_container_per_epoch},
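This last block handles a query time that sits far from every stored meta embedding: the code creates one fresh embedding, fits only that tensor with Adam while the meta and base models stay otherwise unchanged, and then locks it in via `append_fixed`. A generic sketch of the "optimize a single new parameter against a frozen model" step; every name, shape, and value here is hypothetical:

import torch

frozen_model = torch.nn.Linear(3, 1)
for p in frozen_model.parameters():
    p.requires_grad_(False)  # only the new embedding will receive gradients

new_embed = torch.nn.Parameter(torch.zeros(3))  # stand-in for create_meta_embed()
optimizer = torch.optim.Adam([new_embed], lr=0.005, weight_decay=1e-5, amsgrad=True)

x, y = torch.randn(32, 3), torch.randn(32, 1)
for _ in range(25):  # a short adaptation budget, like args.epochs in the script
    optimizer.zero_grad()
    y_hat = frozen_model(x + new_embed)  # here the embedding shifts the input
    loss = torch.nn.functional.mse_loss(y_hat, y)
    loss.backward()
    optimizer.step()

new_embed.requires_grad_(False)  # analogous to append_fixed(): freeze after adapting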
@@ -216,7 +259,7 @@ if __name__ == "__main__":
     parser.add_argument(
         "--init_lr",
         type=float,
-        default=0.01,
+        default=0.005,
         help="The initial learning rate for the optimizer (default is Adam)",
     )
     parser.add_argument(
@@ -235,7 +278,7 @@ if __name__ == "__main__":
     parser.add_argument(
         "--early_stop_thresh",
         type=int,
-        default=50,
+        default=25,
         help="The maximum epochs for early stop.",
     )
     parser.add_argument(
@@ -256,7 +299,12 @@ if __name__ == "__main__":
     if args.rand_seed is None or args.rand_seed < 0:
         args.rand_seed = random.randint(1, 100000)
     assert args.save_dir is not None, "The save dir argument can not be None"
-    args.save_dir = "{:}-{:}-d{:}_{:}_{:}".format(
-        args.save_dir, args.env_version, args.hidden_dim, args.layer_dim, args.time_dim
+    args.save_dir = "{:}-{:}-d{:}_{:}_{:}-e{:}".format(
+        args.save_dir,
+        args.env_version,
+        args.hidden_dim,
+        args.layer_dim,
+        args.time_dim,
+        args.epochs,
     )
     main(args)
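The reworked checkpoint directory name now also encodes the epoch budget, so runs that differ only in `--epochs` no longer collide. With hypothetical argument values the result looks like this:

# Hypothetical argument values, for illustration only.
save_dir, env_version = "./outputs/lfna", "v1"
hidden_dim, layer_dim, time_dim, epochs = 16, 2, 16, 100

name = "{:}-{:}-d{:}_{:}_{:}-e{:}".format(
    save_dir, env_version, hidden_dim, layer_dim, time_dim, epochs
)
print(name)  # ./outputs/lfna-v1-d16_2_16-e100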