Try a different model / LFNA V3

This commit is contained in:
D-X-Y
2021-05-24 01:06:22 +08:00
parent be274e0b6c
commit 63a0361152
2 changed files with 73 additions and 29 deletions

View File

@@ -5,7 +5,7 @@
# python exps/LFNA/lfna.py --env_version v1 --device cuda --lr 0.001
# python exps/LFNA/lfna.py --env_version v1 --device cuda --lr 0.002 --meta_batch 128
#####################################################
import sys, time, copy, torch, random, argparse
import pdb, sys, time, copy, torch, random, argparse
from tqdm import tqdm
from copy import deepcopy
from pathlib import Path
@@ -95,19 +95,13 @@ def epoch_evaluate(loader, meta_model, base_model, criterion, device, logger):
def online_evaluate(env, meta_model, base_model, criterion, args, logger):
logger.log("Online evaluate: {:}".format(env))
for idx, (timestamp, (future_x, future_y)) in enumerate(env):
future_time = timestamp.item()
time_seqs = [
future_time - iseq * env.timestamp_interval
for iseq in range(args.seq_length)
]
time_seqs.reverse()
for idx, (future_time, (future_x, future_y)) in enumerate(env):
with torch.no_grad():
meta_model.eval()
base_model.eval()
time_seqs = torch.Tensor(time_seqs).view(1, -1).to(args.device)
[seq_containers], _ = meta_model(time_seqs, None)
future_container = seq_containers[-1]
_, [future_container], _ = meta_model(
future_time.to(args.device).view(1, 1), None, True
)
future_x, future_y = future_x.to(args.device), future_y.to(args.device)
future_y_hat = base_model.forward_with_container(future_x, future_container)
future_loss = criterion(future_y_hat, future_y)
@@ -116,18 +110,17 @@ def online_evaluate(env, meta_model, base_model, criterion, args, logger):
idx, len(env), future_loss.item()
)
)
meta_model.adapt(
future_time,
refine = meta_model.adapt(
base_model,
criterion,
future_time.item(),
future_x,
future_y,
env.timestamp_interval,
args.refine_lr,
args.refine_epochs,
)
import pdb
pdb.set_trace()
print("-")
meta_model.clear_fixed()
meta_model.clear_learnt()
def pretrain_v2(base_model, meta_model, criterion, xenv, args, logger):
@@ -251,7 +244,7 @@ def main(args):
logger.log("The meta-model is\n{:}".format(meta_model))
batch_sampler = EnvSampler(train_env, args.meta_batch, args.sampler_enlarge)
train_env.reset_max_seq_length(args.seq_length)
# train_env.reset_max_seq_length(args.seq_length)
# valid_env.reset_max_seq_length(args.seq_length)
valid_env_loader = torch.utils.data.DataLoader(
valid_env,
@@ -269,8 +262,8 @@ def main(args):
pretrain_v2(base_model, meta_model, criterion, train_env, args, logger)
# try to evaluate once
online_evaluate(train_env, meta_model, base_model, criterion, args, logger)
online_evaluate(valid_env, meta_model, base_model, criterion, args, logger)
import pdb
pdb.set_trace()
optimizer = torch.optim.Adam(
@@ -510,11 +503,11 @@ if __name__ == "__main__":
parser.add_argument(
"--refine_lr",
type=float,
default=0.001,
default=0.002,
help="The learning rate for the optimizer, during refine",
)
parser.add_argument(
"--refine_epochs", type=int, default=1000, help="The final refine #epochs."
"--refine_epochs", type=int, default=50, help="The final refine #epochs."
)
parser.add_argument(
"--early_stop_thresh",