LFNA ok on the valid data
This commit is contained in:
@@ -99,18 +99,13 @@ def online_evaluate(env, meta_model, base_model, criterion, args, logger):
|
||||
with torch.no_grad():
|
||||
meta_model.eval()
|
||||
base_model.eval()
|
||||
_, [future_container], _ = meta_model(
|
||||
_, [future_container], time_embeds = meta_model(
|
||||
future_time.to(args.device).view(1, 1), None, True
|
||||
)
|
||||
future_x, future_y = future_x.to(args.device), future_y.to(args.device)
|
||||
future_y_hat = base_model.forward_with_container(future_x, future_container)
|
||||
future_loss = criterion(future_y_hat, future_y)
|
||||
logger.log(
|
||||
"[ONLINE] [{:03d}/{:03d}] loss={:.4f}".format(
|
||||
idx, len(env), future_loss.item()
|
||||
)
|
||||
)
|
||||
refine = meta_model.adapt(
|
||||
refine, post_refine_loss = meta_model.adapt(
|
||||
base_model,
|
||||
criterion,
|
||||
future_time.item(),
|
||||
@@ -118,6 +113,13 @@ def online_evaluate(env, meta_model, base_model, criterion, args, logger):
|
||||
future_y,
|
||||
args.refine_lr,
|
||||
args.refine_epochs,
|
||||
{"param": time_embeds, "loss": future_loss.item()},
|
||||
)
|
||||
logger.log(
|
||||
"[ONLINE] [{:03d}/{:03d}] loss={:.4f}".format(
|
||||
idx, len(env), future_loss.item()
|
||||
)
|
||||
+ ", post-loss={:.4f}".format(post_refine_loss if refine else -1)
|
||||
)
|
||||
meta_model.clear_fixed()
|
||||
meta_model.clear_learnt()
|
||||
@@ -244,21 +246,6 @@ def main(args):
|
||||
logger.log("The meta-model is\n{:}".format(meta_model))
|
||||
|
||||
batch_sampler = EnvSampler(train_env, args.meta_batch, args.sampler_enlarge)
|
||||
# train_env.reset_max_seq_length(args.seq_length)
|
||||
# valid_env.reset_max_seq_length(args.seq_length)
|
||||
valid_env_loader = torch.utils.data.DataLoader(
|
||||
valid_env,
|
||||
batch_size=args.meta_batch,
|
||||
shuffle=True,
|
||||
num_workers=args.workers,
|
||||
pin_memory=True,
|
||||
)
|
||||
train_env_loader = torch.utils.data.DataLoader(
|
||||
train_env,
|
||||
batch_sampler=batch_sampler,
|
||||
num_workers=args.workers,
|
||||
pin_memory=True,
|
||||
)
|
||||
pretrain_v2(base_model, meta_model, criterion, train_env, args, logger)
|
||||
|
||||
# try to evaluate once
|
||||
@@ -507,7 +494,7 @@ if __name__ == "__main__":
|
||||
help="The learning rate for the optimizer, during refine",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--refine_epochs", type=int, default=50, help="The final refine #epochs."
|
||||
"--refine_epochs", type=int, default=40, help="The final refine #epochs."
|
||||
)
|
||||
parser.add_argument(
|
||||
"--early_stop_thresh",
|
||||
|
||||
Reference in New Issue
Block a user