Update LFNA version 1.0
This commit is contained in:
@@ -58,6 +58,8 @@ def main(args):
|
||||
)
|
||||
)
|
||||
|
||||
w_container_per_epoch = dict()
|
||||
|
||||
per_timestamp_time, start_time = AverageMeter(), time.time()
|
||||
for i, idx in enumerate(to_evaluate_indexes):
|
||||
|
||||
@@ -73,7 +75,6 @@ def main(args):
|
||||
+ need_time
|
||||
)
|
||||
# train the same data
|
||||
assert idx != 0
|
||||
historical_x = env_info["{:}-x".format(idx)]
|
||||
historical_y = env_info["{:}-y".format(idx)]
|
||||
# build model
|
||||
@@ -82,9 +83,10 @@ def main(args):
|
||||
input_dim=1,
|
||||
output_dim=1,
|
||||
act_cls="leaky_relu",
|
||||
norm_cls="simple_norm",
|
||||
mean=mean,
|
||||
std=std,
|
||||
norm_cls="identity",
|
||||
# norm_cls="simple_norm",
|
||||
# mean=mean,
|
||||
# std=std,
|
||||
)
|
||||
model = get_model(dict(model_type="simple_mlp"), **model_kwargs)
|
||||
# build optimizer
|
||||
@@ -137,6 +139,7 @@ def main(args):
|
||||
save_path = logger.path(None) / "{:04d}-{:04d}.pth".format(
|
||||
idx, env_info["total"]
|
||||
)
|
||||
w_container_per_epoch[idx] = model.get_w_container().no_grad_clone()
|
||||
save_checkpoint(
|
||||
{
|
||||
"model_state_dict": model.state_dict(),
|
||||
@@ -151,6 +154,11 @@ def main(args):
|
||||
|
||||
per_timestamp_time.update(time.time() - start_time)
|
||||
start_time = time.time()
|
||||
save_checkpoint(
|
||||
{"w_container_per_epoch": w_container_per_epoch},
|
||||
logger.path(None) / "final-ckp.pth",
|
||||
logger,
|
||||
)
|
||||
|
||||
logger.log("-" * 200 + "\n")
|
||||
logger.close()
|
||||
|
Reference in New Issue
Block a user