Try a different model / LFNA

commit 9135667cc1 (parent 25dc78a7ce)
Author: D-X-Y
Date: 2021-05-23 23:09:14 +08:00
2 changed files with 123 additions and 74 deletions


@@ -99,7 +99,7 @@ def online_evaluate(env, meta_model, base_model, criterion, args, logger):
         future_time = timestamp.item()
         time_seqs = [
             future_time - iseq * env.timestamp_interval
-            for iseq in range(args.seq_length * 2)
+            for iseq in range(args.seq_length)
         ]
         time_seqs.reverse()
         with torch.no_grad():
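
The first change halves the query window handed to the hypernetwork: args.seq_length timestamps instead of args.seq_length * 2. A minimal, stand-alone sketch of the sequence being built; the constant values below are invented for illustration and stand in for timestamp.item(), env.timestamp_interval, and args.seq_length:

# Stand-alone sketch; all constants are made up for illustration.
future_time = 10.0        # stand-in for timestamp.item()
timestamp_interval = 0.5  # stand-in for env.timestamp_interval
seq_length = 4            # stand-in for args.seq_length

# Walk backwards from the query time, then reverse so the sequence runs
# oldest -> newest and ends exactly at future_time.
time_seqs = [future_time - iseq * timestamp_interval for iseq in range(seq_length)]
time_seqs.reverse()
print(time_seqs)  # [8.5, 9.0, 9.5, 10.0]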
@@ -107,30 +107,26 @@ def online_evaluate(env, meta_model, base_model, criterion, args, logger):
             base_model.eval()
             time_seqs = torch.Tensor(time_seqs).view(1, -1).to(args.device)
             [seq_containers], _ = meta_model(time_seqs, None)
-            # For Debug
-            for idx in range(time_seqs.numel()):
-                future_container = seq_containers[idx]
-                _, (future_x, future_y) = env(time_seqs[0, idx].item())
-                future_x, future_y = future_x.to(args.device), future_y.to(args.device)
-                future_y_hat = base_model.forward_with_container(
-                    future_x, future_container
-                )
-                future_loss = criterion(future_y_hat, future_y)
-                logger.log(
-                    "--> time={:.4f} -> loss={:.4f}".format(
-                        time_seqs[0, idx].item(), future_loss.item()
-                    )
-                )
+            future_container = seq_containers[-1]
+            future_x, future_y = future_x.to(args.device), future_y.to(args.device)
+            future_y_hat = base_model.forward_with_container(future_x, future_container)
+            future_loss = criterion(future_y_hat, future_y)
+            logger.log(
+                "[ONLINE] [{:03d}/{:03d}] loss={:.4f}".format(
+                    idx, len(env), future_loss.item()
+                )
+            )
         meta_model.adapt(
             future_time,
             future_x,
             future_y,
             env.timestamp_interval,
             args.refine_lr,
             args.refine_epochs,
         )
-        import pdb
-        pdb.set_trace()
-        for iseq in range(args.seq_length):
-            time_seqs.append(future_time - iseq * eval_env.timestamp_interval)
-        print("-")
@@ -156,6 +152,7 @@ def pretrain_v2(base_model, meta_model, criterion, xenv, args, logger):
     meta_model.set_best_name("pretrain-{:}.pth".format(args.rand_seed))
     last_success_epoch, early_stop_thresh = 0, args.pretrain_early_stop_thresh
     per_epoch_time, start_time = AverageMeter(), time.time()
+    device = args.device
     for iepoch in range(args.epochs):
         left_time = "Time Left: {:}".format(
             convert_secs2time(per_epoch_time.avg * (args.epochs - iepoch), True)
@@ -163,32 +160,38 @@ def pretrain_v2(base_model, meta_model, criterion, xenv, args, logger):
         total_meta_v1_losses, total_meta_v2_losses, total_match_losses = [], [], []
         optimizer.zero_grad()
         for ibatch in range(args.meta_batch):
-            rand_index = random.randint(0, meta_model.meta_length - 1)
-            timestamp = meta_model.meta_timestamps[rand_index]
-            meta_embed = meta_model.super_meta_embed[rand_index]
-            timestamps, [container], time_embeds = meta_model(
-                torch.unsqueeze(timestamp, dim=0), None, True
-            )
-            _, (inputs, targets) = xenv(timestamp.item())
-            inputs, targets = inputs.to(device), targets.to(device)
-            # generate models one step ahead
-            predictions = base_model.forward_with_container(inputs, container)
-            total_meta_v1_losses.append(criterion(predictions, targets))
-            # the matching loss
-            match_loss = criterion(torch.squeeze(time_embeds, dim=0), meta_embed)
-            total_match_losses.append(match_loss)
-            # generate models via memory
             rand_index = random.randint(0, meta_model.meta_length - xenv.seq_length - 1)
-            _, [seq_containers], _ = meta_model(
-                None,
-                torch.unsqueeze(
-                    meta_model.super_meta_embed[
-                        rand_index : rand_index + xenv.seq_length
-                    ],
-                    dim=0,
-                ),
-                False,
-            )
+            timestamps = meta_model.meta_timestamps[
+                rand_index : rand_index + xenv.seq_length
+            ]
+            meta_embeds = meta_model.super_meta_embed[
+                rand_index : rand_index + xenv.seq_length
+            ]
+            _, (seq_inputs, seq_targets) = xenv.seq_call(timestamps)
+            seq_inputs, seq_targets = seq_inputs.to(args.device), seq_targets.to(
+                args.device
+            )
+            # generate models one step ahead
+            [seq_containers], time_embeds = meta_model(
+                torch.unsqueeze(timestamps, dim=0), None
+            )
+            for container, inputs, targets in zip(
+                seq_containers, seq_inputs, seq_targets
+            ):
+                predictions = base_model.forward_with_container(inputs, container)
+                total_meta_v1_losses.append(criterion(predictions, targets))
+            # the matching loss
+            match_loss = criterion(torch.squeeze(time_embeds, dim=0), meta_embeds)
+            total_match_losses.append(match_loss)
+            # generate models via memory
+            [seq_containers], _ = meta_model(None, torch.unsqueeze(meta_embeds, dim=0))
+            seq_inputs, seq_targets = seq_inputs.to(device), seq_targets.to(device)
+            for container, inputs, targets in zip(
+                seq_containers, seq_inputs, seq_targets
+            ):
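
Pretraining now samples a contiguous window of xenv.seq_length stored timestamps per batch instead of a single one, generates one weight container per step in the window, and accumulates a prediction loss per step, plus a matching loss between the time embeddings and the stored per-timestamp embeddings. A hedged sketch of the window-sampling idea; decoder, env_seq_call, and the y = w * x + b base model are all invented for this sketch, and the matching loss is omitted for brevity:

# Hedged sketch of one pretraining step under the new windowed scheme.
import random

import torch
import torch.nn as nn

torch.manual_seed(0)
seq_length, embed_dim, num_steps = 4, 8, 20

meta_timestamps = torch.linspace(0.0, 1.0, num_steps)
meta_embeds = nn.Parameter(torch.randn(num_steps, embed_dim))
decoder = nn.Linear(embed_dim, 2)  # embedding -> (w, b) of a linear model
criterion = nn.MSELoss()


def env_seq_call(timestamps):
    # Stand-in for xenv.seq_call: one (x, y) batch per timestamp.
    xs = torch.randn(len(timestamps), 16, 1)
    ys = torch.stack([torch.sin(t) * x for t, x in zip(timestamps, xs)])
    return xs, ys


# Sample a contiguous window of stored timestamps, as the new code does.
rand_index = random.randint(0, num_steps - seq_length - 1)
timestamps = meta_timestamps[rand_index : rand_index + seq_length]
embeds = meta_embeds[rand_index : rand_index + seq_length]

seq_inputs, seq_targets = env_seq_call(timestamps)
losses = []
for embed, inputs, targets in zip(embeds, seq_inputs, seq_targets):
    w, b = decoder(embed)  # "generate" one model per timestamp
    predictions = w * inputs + b
    losses.append(criterion(predictions, targets))
torch.stack(losses).mean().backward()  # trains decoder and the embeddings

Sampling a window rather than an isolated timestamp lets every pretraining step exercise the same sequence geometry that online_evaluate queries at test time.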
@@ -250,7 +253,14 @@ def main(args):
     # pre-train the hypernetwork
     timestamps = train_env.get_timestamp(None)
-    meta_model = LFNA_Meta(shape_container, args.layer_dim, args.time_dim, timestamps)
+    meta_model = LFNA_Meta(
+        shape_container,
+        args.layer_dim,
+        args.time_dim,
+        timestamps,
+        seq_length=args.seq_length,
+        interval=train_env.timestamp_interval,
+    )
     meta_model = meta_model.to(args.device)
     logger.log("The base-model has {:} weights.".format(base_model.numel()))
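
Finally, LFNA_Meta now receives seq_length and the environment's timestamp_interval at construction time, so the sequence geometry lives inside the meta-model instead of being re-derived at each call site. A hedged sketch of what such a constructor plausibly stores; the attribute names below are illustrative, not the class's actual fields, and shape_container is omitted because its type is repository-specific:

import torch
import torch.nn as nn


class MetaModelSketch(nn.Module):
    def __init__(self, layer_dim, time_dim, timestamps, seq_length=None, interval=None):
        super().__init__()
        self.layer_dim = layer_dim
        self.time_dim = time_dim
        # one stored timestamp and one learnable embedding per training step
        self.register_buffer("meta_timestamps", torch.as_tensor(timestamps))
        self.super_meta_embed = nn.Parameter(torch.randn(len(timestamps), time_dim))
        # newly threaded-through sequence settings: window size and spacing
        self.seq_length = seq_length
        self.interval = interval

    @property
    def meta_length(self):
        return self.meta_timestamps.numel()


model = MetaModelSketch(16, 8, [0.0, 0.1, 0.2, 0.3], seq_length=2, interval=0.1)
print(model.meta_length)  # 4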