Update LFNA

This commit is contained in:
D-X-Y
2021-05-15 16:01:40 +08:00
parent b81ef2dd74
commit 72f240bf0a
12 changed files with 128 additions and 1050 deletions

View File

@@ -157,11 +157,11 @@ def main(args):
per_epoch_time.update(time.time() - start_time)
start_time = time.time()
# meta-training
# meta-test
meta_model.load_best()
eval_env = env_info["dynamic_env"]
w_container_per_epoch = dict()
for idx in range(args.seq_length, env_info["total"]):
for idx in range(args.seq_length, len(eval_env)):
# build-timestamp
future_time = env_info["{:}-timestamp".format(idx)]
time_seqs = []
@@ -176,8 +176,8 @@ def main(args):
future_container = seq_containers[-1]
w_container_per_epoch[idx] = future_container.no_grad_clone()
# evaluation
future_x = env_info["{:}-x".format(idx)]
future_y = env_info["{:}-y".format(idx)]
future_x = env_info["{:}-x".format(idx)].to(args.device)
future_y = env_info["{:}-y".format(idx)].to(args.device)
future_y_hat = base_model.forward_with_container(
future_x, w_container_per_epoch[idx]
)
@@ -299,12 +299,12 @@ if __name__ == "__main__":
if args.rand_seed is None or args.rand_seed < 0:
args.rand_seed = random.randint(1, 100000)
assert args.save_dir is not None, "The save dir argument can not be None"
args.save_dir = "{:}-{:}-d{:}_{:}_{:}-e{:}".format(
args.save_dir = "{:}-d{:}_{:}_{:}-e{:}-env{:}".format(
args.save_dir,
args.env_version,
args.hidden_dim,
args.layer_dim,
args.time_dim,
args.epochs,
args.env_version,
)
main(args)