Update LFNA version 1.0

This commit is contained in:
D-X-Y
2021-05-07 14:27:15 +08:00
parent 80aaac4dfa
commit 34560ad8d1
5 changed files with 120 additions and 40 deletions

View File

@@ -58,6 +58,8 @@ def main(args):
)
)
w_container_per_epoch = dict()
per_timestamp_time, start_time = AverageMeter(), time.time()
for i, idx in enumerate(to_evaluate_indexes):
@@ -73,7 +75,6 @@ def main(args):
+ need_time
)
# train the same data
assert idx != 0
historical_x = env_info["{:}-x".format(idx)]
historical_y = env_info["{:}-y".format(idx)]
# build model
@@ -82,9 +83,10 @@ def main(args):
input_dim=1,
output_dim=1,
act_cls="leaky_relu",
norm_cls="simple_norm",
mean=mean,
std=std,
norm_cls="identity",
# norm_cls="simple_norm",
# mean=mean,
# std=std,
)
model = get_model(dict(model_type="simple_mlp"), **model_kwargs)
# build optimizer
@@ -137,6 +139,7 @@ def main(args):
save_path = logger.path(None) / "{:04d}-{:04d}.pth".format(
idx, env_info["total"]
)
w_container_per_epoch[idx] = model.get_w_container().no_grad_clone()
save_checkpoint(
{
"model_state_dict": model.state_dict(),
@@ -151,6 +154,11 @@ def main(args):
per_timestamp_time.update(time.time() - start_time)
start_time = time.time()
save_checkpoint(
{"w_container_per_epoch": w_container_per_epoch},
logger.path(None) / "final-ckp.pth",
logger,
)
logger.log("-" * 200 + "\n")
logger.close()