Try a different model / LFNA
This commit is contained in:
@@ -99,7 +99,7 @@ def online_evaluate(env, meta_model, base_model, criterion, args, logger):
|
||||
future_time = timestamp.item()
|
||||
time_seqs = [
|
||||
future_time - iseq * env.timestamp_interval
|
||||
for iseq in range(args.seq_length * 2)
|
||||
for iseq in range(args.seq_length)
|
||||
]
|
||||
time_seqs.reverse()
|
||||
with torch.no_grad():
|
||||
@@ -107,30 +107,26 @@ def online_evaluate(env, meta_model, base_model, criterion, args, logger):
|
||||
base_model.eval()
|
||||
time_seqs = torch.Tensor(time_seqs).view(1, -1).to(args.device)
|
||||
[seq_containers], _ = meta_model(time_seqs, None)
|
||||
# For Debug
|
||||
for idx in range(time_seqs.numel()):
|
||||
future_container = seq_containers[idx]
|
||||
_, (future_x, future_y) = env(time_seqs[0, idx].item())
|
||||
future_x, future_y = future_x.to(args.device), future_y.to(args.device)
|
||||
future_y_hat = base_model.forward_with_container(
|
||||
future_x, future_container
|
||||
)
|
||||
future_loss = criterion(future_y_hat, future_y)
|
||||
logger.log(
|
||||
"--> time={:.4f} -> loss={:.4f}".format(
|
||||
time_seqs[0, idx].item(), future_loss.item()
|
||||
)
|
||||
)
|
||||
future_container = seq_containers[-1]
|
||||
future_x, future_y = future_x.to(args.device), future_y.to(args.device)
|
||||
future_y_hat = base_model.forward_with_container(future_x, future_container)
|
||||
future_loss = criterion(future_y_hat, future_y)
|
||||
logger.log(
|
||||
"[ONLINE] [{:03d}/{:03d}] loss={:.4f}".format(
|
||||
idx, len(env), future_loss.item()
|
||||
)
|
||||
)
|
||||
meta_model.adapt(
|
||||
future_time,
|
||||
future_x,
|
||||
future_y,
|
||||
env.timestamp_interval,
|
||||
args.refine_lr,
|
||||
args.refine_epochs,
|
||||
)
|
||||
import pdb
|
||||
|
||||
pdb.set_trace()
|
||||
for iseq in range(args.seq_length):
|
||||
time_seqs.append(future_time - iseq * eval_env.timestamp_interval)
|
||||
print("-")
|
||||
|
||||
|
||||
@@ -156,6 +152,7 @@ def pretrain_v2(base_model, meta_model, criterion, xenv, args, logger):
|
||||
meta_model.set_best_name("pretrain-{:}.pth".format(args.rand_seed))
|
||||
last_success_epoch, early_stop_thresh = 0, args.pretrain_early_stop_thresh
|
||||
per_epoch_time, start_time = AverageMeter(), time.time()
|
||||
device = args.device
|
||||
for iepoch in range(args.epochs):
|
||||
left_time = "Time Left: {:}".format(
|
||||
convert_secs2time(per_epoch_time.avg * (args.epochs - iepoch), True)
|
||||
@@ -163,32 +160,38 @@ def pretrain_v2(base_model, meta_model, criterion, xenv, args, logger):
|
||||
total_meta_v1_losses, total_meta_v2_losses, total_match_losses = [], [], []
|
||||
optimizer.zero_grad()
|
||||
for ibatch in range(args.meta_batch):
|
||||
rand_index = random.randint(0, meta_model.meta_length - 1)
|
||||
timestamp = meta_model.meta_timestamps[rand_index]
|
||||
meta_embed = meta_model.super_meta_embed[rand_index]
|
||||
|
||||
timestamps, [container], time_embeds = meta_model(
|
||||
torch.unsqueeze(timestamp, dim=0), None, True
|
||||
)
|
||||
_, (inputs, targets) = xenv(timestamp.item())
|
||||
inputs, targets = inputs.to(device), targets.to(device)
|
||||
# generate models one step ahead
|
||||
predictions = base_model.forward_with_container(inputs, container)
|
||||
total_meta_v1_losses.append(criterion(predictions, targets))
|
||||
# the matching loss
|
||||
match_loss = criterion(torch.squeeze(time_embeds, dim=0), meta_embed)
|
||||
total_match_losses.append(match_loss)
|
||||
# generate models via memory
|
||||
rand_index = random.randint(0, meta_model.meta_length - xenv.seq_length - 1)
|
||||
_, [seq_containers], _ = meta_model(
|
||||
None,
|
||||
torch.unsqueeze(
|
||||
meta_model.super_meta_embed[
|
||||
rand_index : rand_index + xenv.seq_length
|
||||
],
|
||||
dim=0,
|
||||
),
|
||||
False,
|
||||
)
|
||||
timestamps = meta_model.meta_timestamps[
|
||||
rand_index : rand_index + xenv.seq_length
|
||||
]
|
||||
meta_embeds = meta_model.super_meta_embed[
|
||||
rand_index : rand_index + xenv.seq_length
|
||||
]
|
||||
|
||||
_, (seq_inputs, seq_targets) = xenv.seq_call(timestamps)
|
||||
seq_inputs, seq_targets = seq_inputs.to(args.device), seq_targets.to(
|
||||
args.device
|
||||
)
|
||||
# generate models one step ahead
|
||||
[seq_containers], time_embeds = meta_model(
|
||||
torch.unsqueeze(timestamps, dim=0), None
|
||||
)
|
||||
for container, inputs, targets in zip(
|
||||
seq_containers, seq_inputs, seq_targets
|
||||
):
|
||||
predictions = base_model.forward_with_container(inputs, container)
|
||||
total_meta_v1_losses.append(criterion(predictions, targets))
|
||||
# the matching loss
|
||||
match_loss = criterion(torch.squeeze(time_embeds, dim=0), meta_embeds)
|
||||
total_match_losses.append(match_loss)
|
||||
# generate models via memory
|
||||
[seq_containers], _ = meta_model(None, torch.unsqueeze(meta_embeds, dim=0))
|
||||
seq_inputs, seq_targets = seq_inputs.to(device), seq_targets.to(device)
|
||||
for container, inputs, targets in zip(
|
||||
seq_containers, seq_inputs, seq_targets
|
||||
):
|
||||
@@ -250,7 +253,14 @@ def main(args):
|
||||
|
||||
# pre-train the hypernetwork
|
||||
timestamps = train_env.get_timestamp(None)
|
||||
meta_model = LFNA_Meta(shape_container, args.layer_dim, args.time_dim, timestamps)
|
||||
meta_model = LFNA_Meta(
|
||||
shape_container,
|
||||
args.layer_dim,
|
||||
args.time_dim,
|
||||
timestamps,
|
||||
seq_length=args.seq_length,
|
||||
interval=train_env.timestamp_interval,
|
||||
)
|
||||
meta_model = meta_model.to(args.device)
|
||||
|
||||
logger.log("The base-model has {:} weights.".format(base_model.numel()))
|
||||
|
||||
Reference in New Issue
Block a user