Update LFNA
This commit is contained in:
@@ -157,11 +157,11 @@ def main(args):
|
||||
per_epoch_time.update(time.time() - start_time)
|
||||
start_time = time.time()
|
||||
|
||||
# meta-training
|
||||
# meta-test
|
||||
meta_model.load_best()
|
||||
eval_env = env_info["dynamic_env"]
|
||||
w_container_per_epoch = dict()
|
||||
for idx in range(args.seq_length, env_info["total"]):
|
||||
for idx in range(args.seq_length, len(eval_env)):
|
||||
# build-timestamp
|
||||
future_time = env_info["{:}-timestamp".format(idx)]
|
||||
time_seqs = []
|
||||
@@ -176,8 +176,8 @@ def main(args):
|
||||
future_container = seq_containers[-1]
|
||||
w_container_per_epoch[idx] = future_container.no_grad_clone()
|
||||
# evaluation
|
||||
future_x = env_info["{:}-x".format(idx)]
|
||||
future_y = env_info["{:}-y".format(idx)]
|
||||
future_x = env_info["{:}-x".format(idx)].to(args.device)
|
||||
future_y = env_info["{:}-y".format(idx)].to(args.device)
|
||||
future_y_hat = base_model.forward_with_container(
|
||||
future_x, w_container_per_epoch[idx]
|
||||
)
|
||||
@@ -299,12 +299,12 @@ if __name__ == "__main__":
|
||||
if args.rand_seed is None or args.rand_seed < 0:
|
||||
args.rand_seed = random.randint(1, 100000)
|
||||
assert args.save_dir is not None, "The save dir argument can not be None"
|
||||
args.save_dir = "{:}-{:}-d{:}_{:}_{:}-e{:}".format(
|
||||
args.save_dir = "{:}-d{:}_{:}_{:}-e{:}-env{:}".format(
|
||||
args.save_dir,
|
||||
args.env_version,
|
||||
args.hidden_dim,
|
||||
args.layer_dim,
|
||||
args.time_dim,
|
||||
args.epochs,
|
||||
args.env_version,
|
||||
)
|
||||
main(args)
|
||||
|
Reference in New Issue
Block a user