Re-organize GeMOSA

D-X-Y
2021-05-27 15:44:01 +08:00
parent 8961215416
commit 6da60664f5
10 changed files with 354 additions and 350 deletions


@@ -1,10 +1,9 @@
#####################################################
# Learning to Generate Model One Step Ahead #
#####################################################
# python exps/GeMOSA/lfna.py --env_version v1 --workers 0
# python exps/GeMOSA/lfna.py --env_version v1 --device cuda --lr 0.001
# python exps/GeMOSA/main.py --env_version v1 --device cuda --lr 0.002 --seq_length 16 --meta_batch 128
# python exps/GeMOSA/lfna.py --env_version v1 --device cuda --lr 0.002 --seq_length 24 --time_dim 32 --meta_batch 128
# python exps/GeMOSA/main.py --env_version v1 --workers 0
# python exps/GeMOSA/main.py --env_version v1 --device cuda --lr 0.002 --seq_length 8 --meta_batch 256
# python exps/GeMOSA/main.py --env_version v1 --device cuda --lr 0.002 --seq_length 24 --time_dim 32 --meta_batch 128
#####################################################
import sys, time, copy, torch, random, argparse
from tqdm import tqdm
@@ -38,7 +37,9 @@ from lfna_utils import lfna_setup, train_model, TimeData
from meta_model import MetaModelV1
def online_evaluate(env, meta_model, base_model, criterion, args, logger, save=False):
def online_evaluate(
env, meta_model, base_model, criterion, args, logger, save=False, easy_adapt=False
):
logger.log("Online evaluate: {:}".format(env))
loss_meter = AverageMeter()
w_containers = dict()
@@ -46,25 +47,30 @@ def online_evaluate(env, meta_model, base_model, criterion, args, logger, save=F
with torch.no_grad():
meta_model.eval()
base_model.eval()
[future_container], time_embeds = meta_model(
future_time.to(args.device).view(-1), None, False
future_time_embed = meta_model.gen_time_embed(
future_time.to(args.device).view(-1)
)
[future_container] = meta_model.gen_model(future_time_embed)
if save:
w_containers[idx] = future_container.no_grad_clone()
future_x, future_y = future_x.to(args.device), future_y.to(args.device)
future_y_hat = base_model.forward_with_container(future_x, future_container)
future_loss = criterion(future_y_hat, future_y)
loss_meter.update(future_loss.item())
refine, post_refine_loss = meta_model.adapt(
base_model,
criterion,
future_time.item(),
future_x,
future_y,
args.refine_lr,
args.refine_epochs,
{"param": time_embeds, "loss": future_loss.item()},
)
if easy_adapt:
meta_model.easy_adapt(future_time.item(), future_time_embed)
refine, post_refine_loss = False, -1
else:
refine, post_refine_loss = meta_model.adapt(
base_model,
criterion,
future_time.item(),
future_x,
future_y,
args.refine_lr,
args.refine_epochs,
{"param": future_time_embed, "loss": future_loss.item()},
)
logger.log(
"[ONLINE] [{:03d}/{:03d}] loss={:.4f}".format(
idx, len(env), future_loss.item()
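
A minimal sketch (not part of this commit) of the new two-step prediction path that replaces the old multi-purpose meta_model(...) call, assuming the tensors and models are set up as in the evaluation loop above; with easy_adapt=True the step that follows simply registers the embedding via meta_model.easy_adapt(...) instead of running the gradient-based meta_model.adapt(...) refinement:

    with torch.no_grad():
        # Step 1: encode the raw timestamp into a time embedding.
        time_embed = meta_model.gen_time_embed(future_time.to(args.device).view(-1))
        # Step 2: decode the embedding into a weight container for the base model.
        [container] = meta_model.gen_model(time_embed)
        # Predict with the generated weights; the loss feeds the AverageMeter.
        y_hat = base_model.forward_with_container(future_x.to(args.device), container)
        loss = criterion(y_hat, future_y.to(args.device))
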
@@ -106,7 +112,7 @@ def meta_train_procedure(base_model, meta_model, criterion, xenv, args, logger):
)
optimizer.zero_grad()
generated_time_embeds = gen_time_embed(meta_model.meta_timestamps)
generated_time_embeds = meta_model.gen_time_embed(meta_model.meta_timestamps)
batch_indexes = random.choices(total_indexes, k=args.meta_batch)
@@ -117,11 +123,9 @@ def meta_train_procedure(base_model, meta_model, criterion, xenv, args, logger):
)
# future loss
total_future_losses, total_present_losses = [], []
future_containers, _ = meta_model(
None, generated_time_embeds[batch_indexes], False
)
present_containers, _ = meta_model(
None, meta_model.super_meta_embed[batch_indexes], False
future_containers = meta_model.gen_model(generated_time_embeds[batch_indexes])
present_containers = meta_model.gen_model(
meta_model.super_meta_embed[batch_indexes]
)
for ibatch, time_step in enumerate(raw_time_steps.cpu().tolist()):
_, (inputs, targets) = xenv(time_step)
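
The meta-training step follows the same split. A rough sketch (not part of this commit) of how one meta-batch is prepared under the new API, assuming total_indexes, args.meta_batch, and the model attributes referenced above:

    # Re-encode every known meta-timestamp, then sample a meta-batch of indexes.
    generated_time_embeds = meta_model.gen_time_embed(meta_model.meta_timestamps)
    batch_indexes = random.choices(total_indexes, k=args.meta_batch)
    # Decode containers from the freshly generated embeddings ("future") and from
    # the learned per-timestamp embeddings ("present"); both feed the two loss lists.
    future_containers = meta_model.gen_model(generated_time_embeds[batch_indexes])
    present_containers = meta_model.gen_model(meta_model.super_meta_embed[batch_indexes])
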
@@ -216,13 +220,34 @@ def main(args):
# try to evaluate once
# online_evaluate(train_env, meta_model, base_model, criterion, args, logger)
# online_evaluate(valid_env, meta_model, base_model, criterion, args, logger)
"""
w_containers, loss_meter = online_evaluate(
all_env, meta_model, base_model, criterion, args, logger, True
)
logger.log("In this enviornment, the total loss-meter is {:}".format(loss_meter))
"""
_, test_loss_meter_adapt_v1 = online_evaluate(
valid_env, meta_model, base_model, criterion, args, logger, False, False
)
_, test_loss_meter_adapt_v2 = online_evaluate(
valid_env, meta_model, base_model, criterion, args, logger, False, True
)
logger.log(
"In the online test enviornment, the total loss for refine-adapt is {:}".format(
test_loss_meter_adapt_v1
)
)
logger.log(
"In the online test enviornment, the total loss for easy-adapt is {:}".format(
test_loss_meter_adapt_v2
)
)
save_checkpoint(
{"all_w_containers": w_containers},
{
"test_loss_adapt_v1": test_loss_meter_adapt_v1.avg,
"test_loss_adapt_v2": test_loss_meter_adapt_v2.avg,
},
logger.path(None) / "final-ckp-{:}.pth".format(args.rand_seed),
logger,
)
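
The final checkpoint now stores the average loss of both adaptation modes instead of the raw weight containers. A hypothetical way to inspect it afterwards, assuming save_checkpoint writes the given dict verbatim via torch.save (file name follows the pattern in the call above, with a placeholder seed):

    import torch

    ckp = torch.load("final-ckp-777.pth")  # hypothetical seed in the file name
    print("refine-adapt avg loss:", ckp["test_loss_adapt_v1"])
    print("easy-adapt   avg loss:", ckp["test_loss_adapt_v2"])
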


@@ -198,7 +198,7 @@ class MetaModelV1(super_core.SuperModule):
batch_containers.append(
self._shape_container.translate(torch.split(weights.squeeze(0), 1))
)
return batch_containers, time_embeds
return batch_containers
def forward_raw(self, timestamps, time_embeds, tembed_only=False):
raise NotImplementedError
@@ -206,6 +206,12 @@ class MetaModelV1(super_core.SuperModule):
def forward_candidate(self, input):
raise NotImplementedError
def easy_adapt(self, timestamp, time_embed):
with torch.no_grad():
timestamp = torch.Tensor([timestamp]).to(self._meta_timestamps.device)
self.replace_append_learnt(None, None)
self.append_fixed(timestamp, time_embed)
def adapt(self, base_model, criterion, timestamp, x, y, lr, epochs, init_info):
distance = self.get_closest_meta_distance(timestamp)
if distance + self._interval * 1e-2 <= self._interval:
@@ -230,23 +236,20 @@ class MetaModelV1(super_core.SuperModule):
best_new_param = new_param.detach().clone()
for iepoch in range(epochs):
optimizer.zero_grad()
_, time_embed = self(timestamp.view(1), None)
time_embed = self.gen_time_embed(timestamp.view(1))
match_loss = criterion(new_param, time_embed)
[container], time_embed = self(None, new_param.view(1, -1))
[container] = self.gen_model(new_param.view(1, -1))
y_hat = base_model.forward_with_container(x, container)
meta_loss = criterion(y_hat, y)
loss = meta_loss + match_loss
loss.backward()
optimizer.step()
# print("{:03d}/{:03d} : loss : {:.4f} = {:.4f} + {:.4f}".format(iepoch, epochs, loss.item(), meta_loss.item(), match_loss.item()))
if meta_loss.item() < best_loss:
with torch.no_grad():
best_loss = meta_loss.item()
best_new_param = new_param.detach().clone()
with torch.no_grad():
self.replace_append_learnt(None, None)
self.append_fixed(timestamp, best_new_param)
self.easy_adapt(timestamp, best_new_param)
return True, best_loss
def extra_repr(self) -> str:
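
In effect, adapt() now delegates its final bookkeeping to easy_adapt(): it refines a learnable copy of the initial time embedding against both a matching loss and the task loss, then registers the best candidate. A condensed sketch (not part of this commit) of that relationship, with new_param and best_new_param as in the loop above:

    for _ in range(epochs):
        optimizer.zero_grad()
        time_embed = self.gen_time_embed(timestamp.view(1))   # embedding predicted for this timestamp
        match_loss = criterion(new_param, time_embed)          # keep the candidate close to it
        [container] = self.gen_model(new_param.view(1, -1))    # decode the candidate into weights
        y_hat = base_model.forward_with_container(x, container)
        loss = criterion(y_hat, y) + match_loss
        loss.backward()
        optimizer.step()
        # ... track best_new_param by the lowest task loss, as above ...
    # easy_adapt() clears the learnt placeholder and appends the chosen embedding as fixed.
    self.easy_adapt(timestamp, best_new_param)
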


@@ -191,6 +191,8 @@ def visualize_env(save_dir, version):
allxs.append(allx)
allys.append(ally)
allxs, allys = torch.cat(allxs).view(-1), torch.cat(allys).view(-1)
print("env: {:}".format(dynamic_env))
print("oracle_map: {:}".format(dynamic_env.oracle_map))
print("x - min={:.3f}, max={:.3f}".format(allxs.min().item(), allxs.max().item()))
print("y - min={:.3f}, max={:.3f}".format(allys.min().item(), allys.max().item()))
for idx, (timestamp, (allx, ally)) in enumerate(tqdm(dynamic_env, ncols=50)):