Re-organize GeMOSA

This commit is contained in:
D-X-Y
2021-05-27 15:44:01 +08:00
parent 8961215416
commit 6da60664f5
10 changed files with 354 additions and 350 deletions

View File

@@ -1,10 +1,9 @@
#####################################################
# Learning to Generate Model One Step Ahead #
#####################################################
# python exps/GeMOSA/lfna.py --env_version v1 --workers 0
# python exps/GeMOSA/lfna.py --env_version v1 --device cuda --lr 0.001
# python exps/GeMOSA/main.py --env_version v1 --device cuda --lr 0.002 --seq_length 16 --meta_batch 128
# python exps/GeMOSA/lfna.py --env_version v1 --device cuda --lr 0.002 --seq_length 24 --time_dim 32 --meta_batch 128
# python exps/GeMOSA/main.py --env_version v1 --workers 0
# python exps/GeMOSA/main.py --env_version v1 --device cuda --lr 0.002 --seq_length 8 --meta_batch 256
# python exps/GeMOSA/main.py --env_version v1 --device cuda --lr 0.002 --seq_length 24 --time_dim 32 --meta_batch 128
#####################################################
import sys, time, copy, torch, random, argparse
from tqdm import tqdm
@@ -38,7 +37,9 @@ from lfna_utils import lfna_setup, train_model, TimeData
from meta_model import MetaModelV1
def online_evaluate(env, meta_model, base_model, criterion, args, logger, save=False):
def online_evaluate(
env, meta_model, base_model, criterion, args, logger, save=False, easy_adapt=False
):
logger.log("Online evaluate: {:}".format(env))
loss_meter = AverageMeter()
w_containers = dict()
@@ -46,25 +47,30 @@ def online_evaluate(env, meta_model, base_model, criterion, args, logger, save=F
with torch.no_grad():
meta_model.eval()
base_model.eval()
[future_container], time_embeds = meta_model(
future_time.to(args.device).view(-1), None, False
future_time_embed = meta_model.gen_time_embed(
future_time.to(args.device).view(-1)
)
[future_container] = meta_model.gen_model(future_time_embed)
if save:
w_containers[idx] = future_container.no_grad_clone()
future_x, future_y = future_x.to(args.device), future_y.to(args.device)
future_y_hat = base_model.forward_with_container(future_x, future_container)
future_loss = criterion(future_y_hat, future_y)
loss_meter.update(future_loss.item())
refine, post_refine_loss = meta_model.adapt(
base_model,
criterion,
future_time.item(),
future_x,
future_y,
args.refine_lr,
args.refine_epochs,
{"param": time_embeds, "loss": future_loss.item()},
)
if easy_adapt:
meta_model.easy_adapt(future_time.item(), future_time_embed)
refine, post_refine_loss = False, -1
else:
refine, post_refine_loss = meta_model.adapt(
base_model,
criterion,
future_time.item(),
future_x,
future_y,
args.refine_lr,
args.refine_epochs,
{"param": future_time_embed, "loss": future_loss.item()},
)
logger.log(
"[ONLINE] [{:03d}/{:03d}] loss={:.4f}".format(
idx, len(env), future_loss.item()
@@ -106,7 +112,7 @@ def meta_train_procedure(base_model, meta_model, criterion, xenv, args, logger):
)
optimizer.zero_grad()
generated_time_embeds = gen_time_embed(meta_model.meta_timestamps)
generated_time_embeds = meta_model.gen_time_embed(meta_model.meta_timestamps)
batch_indexes = random.choices(total_indexes, k=args.meta_batch)
@@ -117,11 +123,9 @@ def meta_train_procedure(base_model, meta_model, criterion, xenv, args, logger):
)
# future loss
total_future_losses, total_present_losses = [], []
future_containers, _ = meta_model(
None, generated_time_embeds[batch_indexes], False
)
present_containers, _ = meta_model(
None, meta_model.super_meta_embed[batch_indexes], False
future_containers = meta_model.gen_model(generated_time_embeds[batch_indexes])
present_containers = meta_model.gen_model(
meta_model.super_meta_embed[batch_indexes]
)
for ibatch, time_step in enumerate(raw_time_steps.cpu().tolist()):
_, (inputs, targets) = xenv(time_step)
@@ -216,13 +220,34 @@ def main(args):
# try to evaluate once
# online_evaluate(train_env, meta_model, base_model, criterion, args, logger)
# online_evaluate(valid_env, meta_model, base_model, criterion, args, logger)
"""
w_containers, loss_meter = online_evaluate(
all_env, meta_model, base_model, criterion, args, logger, True
)
logger.log("In this enviornment, the total loss-meter is {:}".format(loss_meter))
"""
_, test_loss_meter_adapt_v1 = online_evaluate(
valid_env, meta_model, base_model, criterion, args, logger, False, False
)
_, test_loss_meter_adapt_v2 = online_evaluate(
valid_env, meta_model, base_model, criterion, args, logger, False, True
)
logger.log(
"In the online test enviornment, the total loss for refine-adapt is {:}".format(
test_loss_meter_adapt_v1
)
)
logger.log(
"In the online test enviornment, the total loss for easy-adapt is {:}".format(
test_loss_meter_adapt_v2
)
)
save_checkpoint(
{"all_w_containers": w_containers},
{
"test_loss_adapt_v1": test_loss_meter_adapt_v1.avg,
"test_loss_adapt_v2": test_loss_meter_adapt_v2.avg,
},
logger.path(None) / "final-ckp-{:}.pth".format(args.rand_seed),
logger,
)