Re-organize GeMOSA
This commit is contained in:
@@ -1,10 +1,9 @@
|
||||
#####################################################
|
||||
# Learning to Generate Model One Step Ahead #
|
||||
#####################################################
|
||||
# python exps/GeMOSA/lfna.py --env_version v1 --workers 0
|
||||
# python exps/GeMOSA/lfna.py --env_version v1 --device cuda --lr 0.001
|
||||
# python exps/GeMOSA/main.py --env_version v1 --device cuda --lr 0.002 --seq_length 16 --meta_batch 128
|
||||
# python exps/GeMOSA/lfna.py --env_version v1 --device cuda --lr 0.002 --seq_length 24 --time_dim 32 --meta_batch 128
|
||||
# python exps/GeMOSA/main.py --env_version v1 --workers 0
|
||||
# python exps/GeMOSA/main.py --env_version v1 --device cuda --lr 0.002 --seq_length 8 --meta_batch 256
|
||||
# python exps/GeMOSA/main.py --env_version v1 --device cuda --lr 0.002 --seq_length 24 --time_dim 32 --meta_batch 128
|
||||
#####################################################
|
||||
import sys, time, copy, torch, random, argparse
|
||||
from tqdm import tqdm
|
||||
@@ -38,7 +37,9 @@ from lfna_utils import lfna_setup, train_model, TimeData
|
||||
from meta_model import MetaModelV1
|
||||
|
||||
|
||||
def online_evaluate(env, meta_model, base_model, criterion, args, logger, save=False):
|
||||
def online_evaluate(
|
||||
env, meta_model, base_model, criterion, args, logger, save=False, easy_adapt=False
|
||||
):
|
||||
logger.log("Online evaluate: {:}".format(env))
|
||||
loss_meter = AverageMeter()
|
||||
w_containers = dict()
|
||||
@@ -46,25 +47,30 @@ def online_evaluate(env, meta_model, base_model, criterion, args, logger, save=F
|
||||
with torch.no_grad():
|
||||
meta_model.eval()
|
||||
base_model.eval()
|
||||
[future_container], time_embeds = meta_model(
|
||||
future_time.to(args.device).view(-1), None, False
|
||||
future_time_embed = meta_model.gen_time_embed(
|
||||
future_time.to(args.device).view(-1)
|
||||
)
|
||||
[future_container] = meta_model.gen_model(future_time_embed)
|
||||
if save:
|
||||
w_containers[idx] = future_container.no_grad_clone()
|
||||
future_x, future_y = future_x.to(args.device), future_y.to(args.device)
|
||||
future_y_hat = base_model.forward_with_container(future_x, future_container)
|
||||
future_loss = criterion(future_y_hat, future_y)
|
||||
loss_meter.update(future_loss.item())
|
||||
refine, post_refine_loss = meta_model.adapt(
|
||||
base_model,
|
||||
criterion,
|
||||
future_time.item(),
|
||||
future_x,
|
||||
future_y,
|
||||
args.refine_lr,
|
||||
args.refine_epochs,
|
||||
{"param": time_embeds, "loss": future_loss.item()},
|
||||
)
|
||||
if easy_adapt:
|
||||
meta_model.easy_adapt(future_time.item(), future_time_embed)
|
||||
refine, post_refine_loss = False, -1
|
||||
else:
|
||||
refine, post_refine_loss = meta_model.adapt(
|
||||
base_model,
|
||||
criterion,
|
||||
future_time.item(),
|
||||
future_x,
|
||||
future_y,
|
||||
args.refine_lr,
|
||||
args.refine_epochs,
|
||||
{"param": future_time_embed, "loss": future_loss.item()},
|
||||
)
|
||||
logger.log(
|
||||
"[ONLINE] [{:03d}/{:03d}] loss={:.4f}".format(
|
||||
idx, len(env), future_loss.item()
|
||||
@@ -106,7 +112,7 @@ def meta_train_procedure(base_model, meta_model, criterion, xenv, args, logger):
|
||||
)
|
||||
optimizer.zero_grad()
|
||||
|
||||
generated_time_embeds = gen_time_embed(meta_model.meta_timestamps)
|
||||
generated_time_embeds = meta_model.gen_time_embed(meta_model.meta_timestamps)
|
||||
|
||||
batch_indexes = random.choices(total_indexes, k=args.meta_batch)
|
||||
|
||||
@@ -117,11 +123,9 @@ def meta_train_procedure(base_model, meta_model, criterion, xenv, args, logger):
|
||||
)
|
||||
# future loss
|
||||
total_future_losses, total_present_losses = [], []
|
||||
future_containers, _ = meta_model(
|
||||
None, generated_time_embeds[batch_indexes], False
|
||||
)
|
||||
present_containers, _ = meta_model(
|
||||
None, meta_model.super_meta_embed[batch_indexes], False
|
||||
future_containers = meta_model.gen_model(generated_time_embeds[batch_indexes])
|
||||
present_containers = meta_model.gen_model(
|
||||
meta_model.super_meta_embed[batch_indexes]
|
||||
)
|
||||
for ibatch, time_step in enumerate(raw_time_steps.cpu().tolist()):
|
||||
_, (inputs, targets) = xenv(time_step)
|
||||
@@ -216,13 +220,34 @@ def main(args):
|
||||
# try to evaluate once
|
||||
# online_evaluate(train_env, meta_model, base_model, criterion, args, logger)
|
||||
# online_evaluate(valid_env, meta_model, base_model, criterion, args, logger)
|
||||
"""
|
||||
w_containers, loss_meter = online_evaluate(
|
||||
all_env, meta_model, base_model, criterion, args, logger, True
|
||||
)
|
||||
logger.log("In this enviornment, the total loss-meter is {:}".format(loss_meter))
|
||||
"""
|
||||
_, test_loss_meter_adapt_v1 = online_evaluate(
|
||||
valid_env, meta_model, base_model, criterion, args, logger, False, False
|
||||
)
|
||||
_, test_loss_meter_adapt_v2 = online_evaluate(
|
||||
valid_env, meta_model, base_model, criterion, args, logger, False, True
|
||||
)
|
||||
logger.log(
|
||||
"In the online test enviornment, the total loss for refine-adapt is {:}".format(
|
||||
test_loss_meter_adapt_v1
|
||||
)
|
||||
)
|
||||
logger.log(
|
||||
"In the online test enviornment, the total loss for easy-adapt is {:}".format(
|
||||
test_loss_meter_adapt_v2
|
||||
)
|
||||
)
|
||||
|
||||
save_checkpoint(
|
||||
{"all_w_containers": w_containers},
|
||||
{
|
||||
"test_loss_adapt_v1": test_loss_meter_adapt_v1.avg,
|
||||
"test_loss_adapt_v2": test_loss_meter_adapt_v2.avg,
|
||||
},
|
||||
logger.path(None) / "final-ckp-{:}.pth".format(args.rand_seed),
|
||||
logger,
|
||||
)
|
||||
|
Reference in New Issue
Block a user