Re-organize GeMOSA
This commit is contained in:
@@ -1,10 +1,9 @@
|
||||
#####################################################
|
||||
# Learning to Generate Model One Step Ahead #
|
||||
#####################################################
|
||||
# python exps/GeMOSA/lfna.py --env_version v1 --workers 0
|
||||
# python exps/GeMOSA/lfna.py --env_version v1 --device cuda --lr 0.001
|
||||
# python exps/GeMOSA/main.py --env_version v1 --device cuda --lr 0.002 --seq_length 16 --meta_batch 128
|
||||
# python exps/GeMOSA/lfna.py --env_version v1 --device cuda --lr 0.002 --seq_length 24 --time_dim 32 --meta_batch 128
|
||||
# python exps/GeMOSA/main.py --env_version v1 --workers 0
|
||||
# python exps/GeMOSA/main.py --env_version v1 --device cuda --lr 0.002 --seq_length 8 --meta_batch 256
|
||||
# python exps/GeMOSA/main.py --env_version v1 --device cuda --lr 0.002 --seq_length 24 --time_dim 32 --meta_batch 128
|
||||
#####################################################
|
||||
import sys, time, copy, torch, random, argparse
|
||||
from tqdm import tqdm
|
||||
@@ -38,7 +37,9 @@ from lfna_utils import lfna_setup, train_model, TimeData
|
||||
from meta_model import MetaModelV1
|
||||
|
||||
|
||||
def online_evaluate(env, meta_model, base_model, criterion, args, logger, save=False):
|
||||
def online_evaluate(
|
||||
env, meta_model, base_model, criterion, args, logger, save=False, easy_adapt=False
|
||||
):
|
||||
logger.log("Online evaluate: {:}".format(env))
|
||||
loss_meter = AverageMeter()
|
||||
w_containers = dict()
|
||||
@@ -46,25 +47,30 @@ def online_evaluate(env, meta_model, base_model, criterion, args, logger, save=F
|
||||
with torch.no_grad():
|
||||
meta_model.eval()
|
||||
base_model.eval()
|
||||
[future_container], time_embeds = meta_model(
|
||||
future_time.to(args.device).view(-1), None, False
|
||||
future_time_embed = meta_model.gen_time_embed(
|
||||
future_time.to(args.device).view(-1)
|
||||
)
|
||||
[future_container] = meta_model.gen_model(future_time_embed)
|
||||
if save:
|
||||
w_containers[idx] = future_container.no_grad_clone()
|
||||
future_x, future_y = future_x.to(args.device), future_y.to(args.device)
|
||||
future_y_hat = base_model.forward_with_container(future_x, future_container)
|
||||
future_loss = criterion(future_y_hat, future_y)
|
||||
loss_meter.update(future_loss.item())
|
||||
refine, post_refine_loss = meta_model.adapt(
|
||||
base_model,
|
||||
criterion,
|
||||
future_time.item(),
|
||||
future_x,
|
||||
future_y,
|
||||
args.refine_lr,
|
||||
args.refine_epochs,
|
||||
{"param": time_embeds, "loss": future_loss.item()},
|
||||
)
|
||||
if easy_adapt:
|
||||
meta_model.easy_adapt(future_time.item(), future_time_embed)
|
||||
refine, post_refine_loss = False, -1
|
||||
else:
|
||||
refine, post_refine_loss = meta_model.adapt(
|
||||
base_model,
|
||||
criterion,
|
||||
future_time.item(),
|
||||
future_x,
|
||||
future_y,
|
||||
args.refine_lr,
|
||||
args.refine_epochs,
|
||||
{"param": future_time_embed, "loss": future_loss.item()},
|
||||
)
|
||||
logger.log(
|
||||
"[ONLINE] [{:03d}/{:03d}] loss={:.4f}".format(
|
||||
idx, len(env), future_loss.item()
|
||||
@@ -106,7 +112,7 @@ def meta_train_procedure(base_model, meta_model, criterion, xenv, args, logger):
|
||||
)
|
||||
optimizer.zero_grad()
|
||||
|
||||
generated_time_embeds = gen_time_embed(meta_model.meta_timestamps)
|
||||
generated_time_embeds = meta_model.gen_time_embed(meta_model.meta_timestamps)
|
||||
|
||||
batch_indexes = random.choices(total_indexes, k=args.meta_batch)
|
||||
|
||||
@@ -117,11 +123,9 @@ def meta_train_procedure(base_model, meta_model, criterion, xenv, args, logger):
|
||||
)
|
||||
# future loss
|
||||
total_future_losses, total_present_losses = [], []
|
||||
future_containers, _ = meta_model(
|
||||
None, generated_time_embeds[batch_indexes], False
|
||||
)
|
||||
present_containers, _ = meta_model(
|
||||
None, meta_model.super_meta_embed[batch_indexes], False
|
||||
future_containers = meta_model.gen_model(generated_time_embeds[batch_indexes])
|
||||
present_containers = meta_model.gen_model(
|
||||
meta_model.super_meta_embed[batch_indexes]
|
||||
)
|
||||
for ibatch, time_step in enumerate(raw_time_steps.cpu().tolist()):
|
||||
_, (inputs, targets) = xenv(time_step)
|
||||
@@ -216,13 +220,34 @@ def main(args):
|
||||
# try to evaluate once
|
||||
# online_evaluate(train_env, meta_model, base_model, criterion, args, logger)
|
||||
# online_evaluate(valid_env, meta_model, base_model, criterion, args, logger)
|
||||
"""
|
||||
w_containers, loss_meter = online_evaluate(
|
||||
all_env, meta_model, base_model, criterion, args, logger, True
|
||||
)
|
||||
logger.log("In this enviornment, the total loss-meter is {:}".format(loss_meter))
|
||||
"""
|
||||
_, test_loss_meter_adapt_v1 = online_evaluate(
|
||||
valid_env, meta_model, base_model, criterion, args, logger, False, False
|
||||
)
|
||||
_, test_loss_meter_adapt_v2 = online_evaluate(
|
||||
valid_env, meta_model, base_model, criterion, args, logger, False, True
|
||||
)
|
||||
logger.log(
|
||||
"In the online test enviornment, the total loss for refine-adapt is {:}".format(
|
||||
test_loss_meter_adapt_v1
|
||||
)
|
||||
)
|
||||
logger.log(
|
||||
"In the online test enviornment, the total loss for easy-adapt is {:}".format(
|
||||
test_loss_meter_adapt_v2
|
||||
)
|
||||
)
|
||||
|
||||
save_checkpoint(
|
||||
{"all_w_containers": w_containers},
|
||||
{
|
||||
"test_loss_adapt_v1": test_loss_meter_adapt_v1.avg,
|
||||
"test_loss_adapt_v2": test_loss_meter_adapt_v2.avg,
|
||||
},
|
||||
logger.path(None) / "final-ckp-{:}.pth".format(args.rand_seed),
|
||||
logger,
|
||||
)
|
||||
|
@@ -198,7 +198,7 @@ class MetaModelV1(super_core.SuperModule):
|
||||
batch_containers.append(
|
||||
self._shape_container.translate(torch.split(weights.squeeze(0), 1))
|
||||
)
|
||||
return batch_containers, time_embeds
|
||||
return batch_containers
|
||||
|
||||
def forward_raw(self, timestamps, time_embeds, tembed_only=False):
|
||||
raise NotImplementedError
|
||||
@@ -206,6 +206,12 @@ class MetaModelV1(super_core.SuperModule):
|
||||
def forward_candidate(self, input):
|
||||
raise NotImplementedError
|
||||
|
||||
def easy_adapt(self, timestamp, time_embed):
|
||||
with torch.no_grad():
|
||||
timestamp = torch.Tensor([timestamp]).to(self._meta_timestamps.device)
|
||||
self.replace_append_learnt(None, None)
|
||||
self.append_fixed(timestamp, time_embed)
|
||||
|
||||
def adapt(self, base_model, criterion, timestamp, x, y, lr, epochs, init_info):
|
||||
distance = self.get_closest_meta_distance(timestamp)
|
||||
if distance + self._interval * 1e-2 <= self._interval:
|
||||
@@ -230,23 +236,20 @@ class MetaModelV1(super_core.SuperModule):
|
||||
best_new_param = new_param.detach().clone()
|
||||
for iepoch in range(epochs):
|
||||
optimizer.zero_grad()
|
||||
_, time_embed = self(timestamp.view(1), None)
|
||||
time_embed = self.gen_time_embed(timestamp.view(1))
|
||||
match_loss = criterion(new_param, time_embed)
|
||||
|
||||
[container], time_embed = self(None, new_param.view(1, -1))
|
||||
[container] = self.gen_model(new_param.view(1, -1))
|
||||
y_hat = base_model.forward_with_container(x, container)
|
||||
meta_loss = criterion(y_hat, y)
|
||||
loss = meta_loss + match_loss
|
||||
loss.backward()
|
||||
optimizer.step()
|
||||
# print("{:03d}/{:03d} : loss : {:.4f} = {:.4f} + {:.4f}".format(iepoch, epochs, loss.item(), meta_loss.item(), match_loss.item()))
|
||||
if meta_loss.item() < best_loss:
|
||||
with torch.no_grad():
|
||||
best_loss = meta_loss.item()
|
||||
best_new_param = new_param.detach().clone()
|
||||
with torch.no_grad():
|
||||
self.replace_append_learnt(None, None)
|
||||
self.append_fixed(timestamp, best_new_param)
|
||||
self.easy_adapt(timestamp, best_new_param)
|
||||
return True, best_loss
|
||||
|
||||
def extra_repr(self) -> str:
|
||||
|
@@ -191,6 +191,8 @@ def visualize_env(save_dir, version):
|
||||
allxs.append(allx)
|
||||
allys.append(ally)
|
||||
allxs, allys = torch.cat(allxs).view(-1), torch.cat(allys).view(-1)
|
||||
print("env: {:}".format(dynamic_env))
|
||||
print("oracle_map: {:}".format(dynamic_env.oracle_map))
|
||||
print("x - min={:.3f}, max={:.3f}".format(allxs.min().item(), allxs.max().item()))
|
||||
print("y - min={:.3f}, max={:.3f}".format(allys.min().item(), allys.max().item()))
|
||||
for idx, (timestamp, (allx, ally)) in enumerate(tqdm(dynamic_env, ncols=50)):
|
||||
|
Reference in New Issue
Block a user