This commit is contained in:
D-X-Y
2021-05-26 04:47:38 +00:00
parent 5eab0de53e
commit d557c328a8
4 changed files with 23 additions and 27 deletions

View File

@@ -46,8 +46,8 @@ def online_evaluate(env, meta_model, base_model, criterion, args, logger, save=F
with torch.no_grad():
meta_model.eval()
base_model.eval()
_, [future_container], time_embeds = meta_model(
future_time.to(args.device).view(1, 1), None, False
[future_container], time_embeds = meta_model(
future_time.to(args.device).view(-1), None, False
)
if save:
w_containers[idx] = future_container.no_grad_clone()
@@ -117,10 +117,10 @@ def meta_train_procedure(base_model, meta_model, criterion, xenv, args, logger):
)
# future loss
total_future_losses, total_present_losses = [], []
_, future_containers, _ = meta_model(
future_containers, _ = meta_model(
None, generated_time_embeds[batch_indexes], False
)
_, present_containers, _ = meta_model(
present_containers, _ = meta_model(
None, meta_model.super_meta_embed[batch_indexes], False
)
for ibatch, time_step in enumerate(raw_time_steps.cpu().tolist()):