Update LFNA
This commit is contained in:
@@ -107,11 +107,20 @@ def online_evaluate(env, meta_model, base_model, criterion, args, logger):
|
||||
base_model.eval()
|
||||
time_seqs = torch.Tensor(time_seqs).view(1, -1).to(args.device)
|
||||
[seq_containers], _ = meta_model(time_seqs, None)
|
||||
future_container = seq_containers[-2]
|
||||
_, (future_x, future_y) = env(time_seqs[0, -2].item())
|
||||
future_x, future_y = future_x.to(args.device), future_y.to(args.device)
|
||||
future_y_hat = base_model.forward_with_container(future_x, future_container)
|
||||
future_loss = criterion(future_y_hat, future_y)
|
||||
# For Debug
|
||||
for idx in range(time_seqs.numel()):
|
||||
future_container = seq_containers[idx]
|
||||
_, (future_x, future_y) = env(time_seqs[0, idx].item())
|
||||
future_x, future_y = future_x.to(args.device), future_y.to(args.device)
|
||||
future_y_hat = base_model.forward_with_container(
|
||||
future_x, future_container
|
||||
)
|
||||
future_loss = criterion(future_y_hat, future_y)
|
||||
logger.log(
|
||||
"--> time={:.4f} -> loss={:.4f}".format(
|
||||
time_seqs[0, idx].item(), future_loss.item()
|
||||
)
|
||||
)
|
||||
logger.log(
|
||||
"[ONLINE] [{:03d}/{:03d}] loss={:.4f}".format(
|
||||
idx, len(env), future_loss.item()
|
||||
|
@@ -47,17 +47,17 @@ class LFNA_Meta(super_core.SuperModule):
|
||||
self._append_meta_timestamps = dict(fixed=None, learnt=None)
|
||||
|
||||
self._tscalar_embed = super_core.SuperDynamicPositionE(
|
||||
time_embedding, scale=100
|
||||
time_embedding, scale=500
|
||||
)
|
||||
|
||||
# build transformer
|
||||
self._trans_att = super_core.SuperQKVAttention(
|
||||
time_embedding,
|
||||
time_embedding,
|
||||
time_embedding,
|
||||
time_embedding,
|
||||
4,
|
||||
True,
|
||||
self._trans_att = super_core.SuperQKVAttentionV2(
|
||||
qk_att_dim=time_embedding,
|
||||
in_v_dim=time_embedding,
|
||||
hidden_dim=time_embedding,
|
||||
num_heads=4,
|
||||
proj_dim=time_embedding,
|
||||
qkv_bias=True,
|
||||
attn_drop=None,
|
||||
proj_drop=dropout,
|
||||
)
|
||||
@@ -166,9 +166,12 @@ class LFNA_Meta(super_core.SuperModule):
|
||||
# timestamps is a batch of sequence of timestamps
|
||||
batch, seq = timestamps.shape
|
||||
meta_timestamps, meta_embeds = self.meta_timestamps, self.super_meta_embed
|
||||
timestamp_q_embed = self._tscalar_embed(timestamps)
|
||||
timestamp_k_embed = self._tscalar_embed(meta_timestamps.view(1, -1))
|
||||
# timestamp_q_embed = self._tscalar_embed(timestamps)
|
||||
# timestamp_k_embed = self._tscalar_embed(meta_timestamps.view(1, -1))
|
||||
timestamp_v_embed = meta_embeds.unsqueeze(dim=0)
|
||||
timestamp_qk_att_embed = self._tscalar_embed(
|
||||
torch.unsqueeze(timestamps, dim=-1) - meta_timestamps
|
||||
)
|
||||
# create the mask
|
||||
mask = (
|
||||
torch.unsqueeze(timestamps, dim=-1) <= meta_timestamps.view(1, 1, -1)
|
||||
@@ -179,11 +182,13 @@ class LFNA_Meta(super_core.SuperModule):
|
||||
> self._thresh
|
||||
)
|
||||
timestamp_embeds = self._trans_att(
|
||||
timestamp_q_embed, timestamp_k_embed, timestamp_v_embed, mask
|
||||
timestamp_qk_att_embed, timestamp_v_embed, mask
|
||||
)
|
||||
relative_timestamps = timestamps - timestamps[:, :1]
|
||||
relative_pos_embeds = self._tscalar_embed(relative_timestamps)
|
||||
init_timestamp_embeds = torch.cat(
|
||||
(timestamp_embeds, relative_pos_embeds), dim=-1
|
||||
)
|
||||
# relative_timestamps = timestamps - timestamps[:, :1]
|
||||
# relative_pos_embeds = self._tscalar_embed(relative_timestamps)
|
||||
init_timestamp_embeds = torch.cat((timestamp_q_embed, timestamp_embeds), dim=-1)
|
||||
corrected_embeds = self._meta_corrector(init_timestamp_embeds)
|
||||
return corrected_embeds
|
||||
|
||||
|
Reference in New Issue
Block a user