Update LFNA

This commit is contained in:
D-X-Y
2021-05-23 08:21:31 +00:00
parent 2a864ae705
commit 25dc78a7ce
5 changed files with 152 additions and 20 deletions

View File

@@ -107,11 +107,20 @@ def online_evaluate(env, meta_model, base_model, criterion, args, logger):
base_model.eval()
time_seqs = torch.Tensor(time_seqs).view(1, -1).to(args.device)
[seq_containers], _ = meta_model(time_seqs, None)
future_container = seq_containers[-2]
_, (future_x, future_y) = env(time_seqs[0, -2].item())
future_x, future_y = future_x.to(args.device), future_y.to(args.device)
future_y_hat = base_model.forward_with_container(future_x, future_container)
future_loss = criterion(future_y_hat, future_y)
# For Debug
for idx in range(time_seqs.numel()):
future_container = seq_containers[idx]
_, (future_x, future_y) = env(time_seqs[0, idx].item())
future_x, future_y = future_x.to(args.device), future_y.to(args.device)
future_y_hat = base_model.forward_with_container(
future_x, future_container
)
future_loss = criterion(future_y_hat, future_y)
logger.log(
"--> time={:.4f} -> loss={:.4f}".format(
time_seqs[0, idx].item(), future_loss.item()
)
)
logger.log(
"[ONLINE] [{:03d}/{:03d}] loss={:.4f}".format(
idx, len(env), future_loss.item()