Update LFNA
This commit is contained in:
@@ -107,11 +107,20 @@ def online_evaluate(env, meta_model, base_model, criterion, args, logger):
|
||||
base_model.eval()
|
||||
time_seqs = torch.Tensor(time_seqs).view(1, -1).to(args.device)
|
||||
[seq_containers], _ = meta_model(time_seqs, None)
|
||||
future_container = seq_containers[-2]
|
||||
_, (future_x, future_y) = env(time_seqs[0, -2].item())
|
||||
future_x, future_y = future_x.to(args.device), future_y.to(args.device)
|
||||
future_y_hat = base_model.forward_with_container(future_x, future_container)
|
||||
future_loss = criterion(future_y_hat, future_y)
|
||||
# For Debug
|
||||
for idx in range(time_seqs.numel()):
|
||||
future_container = seq_containers[idx]
|
||||
_, (future_x, future_y) = env(time_seqs[0, idx].item())
|
||||
future_x, future_y = future_x.to(args.device), future_y.to(args.device)
|
||||
future_y_hat = base_model.forward_with_container(
|
||||
future_x, future_container
|
||||
)
|
||||
future_loss = criterion(future_y_hat, future_y)
|
||||
logger.log(
|
||||
"--> time={:.4f} -> loss={:.4f}".format(
|
||||
time_seqs[0, idx].item(), future_loss.item()
|
||||
)
|
||||
)
|
||||
logger.log(
|
||||
"[ONLINE] [{:03d}/{:03d}] loss={:.4f}".format(
|
||||
idx, len(env), future_loss.item()
|
||||
|
||||
Reference in New Issue
Block a user