Update LFNA

This commit is contained in:
D-X-Y
2021-05-23 08:21:31 +00:00
parent 2a864ae705
commit 25dc78a7ce
5 changed files with 152 additions and 20 deletions

View File

@@ -107,11 +107,20 @@ def online_evaluate(env, meta_model, base_model, criterion, args, logger):
base_model.eval()
time_seqs = torch.Tensor(time_seqs).view(1, -1).to(args.device)
[seq_containers], _ = meta_model(time_seqs, None)
future_container = seq_containers[-2]
_, (future_x, future_y) = env(time_seqs[0, -2].item())
future_x, future_y = future_x.to(args.device), future_y.to(args.device)
future_y_hat = base_model.forward_with_container(future_x, future_container)
future_loss = criterion(future_y_hat, future_y)
# For Debug
for idx in range(time_seqs.numel()):
future_container = seq_containers[idx]
_, (future_x, future_y) = env(time_seqs[0, idx].item())
future_x, future_y = future_x.to(args.device), future_y.to(args.device)
future_y_hat = base_model.forward_with_container(
future_x, future_container
)
future_loss = criterion(future_y_hat, future_y)
logger.log(
"--> time={:.4f} -> loss={:.4f}".format(
time_seqs[0, idx].item(), future_loss.item()
)
)
logger.log(
"[ONLINE] [{:03d}/{:03d}] loss={:.4f}".format(
idx, len(env), future_loss.item()

View File

@@ -47,17 +47,17 @@ class LFNA_Meta(super_core.SuperModule):
self._append_meta_timestamps = dict(fixed=None, learnt=None)
self._tscalar_embed = super_core.SuperDynamicPositionE(
time_embedding, scale=100
time_embedding, scale=500
)
# build transformer
self._trans_att = super_core.SuperQKVAttention(
time_embedding,
time_embedding,
time_embedding,
time_embedding,
4,
True,
self._trans_att = super_core.SuperQKVAttentionV2(
qk_att_dim=time_embedding,
in_v_dim=time_embedding,
hidden_dim=time_embedding,
num_heads=4,
proj_dim=time_embedding,
qkv_bias=True,
attn_drop=None,
proj_drop=dropout,
)
@@ -166,9 +166,12 @@ class LFNA_Meta(super_core.SuperModule):
# timestamps is a batch of sequence of timestamps
batch, seq = timestamps.shape
meta_timestamps, meta_embeds = self.meta_timestamps, self.super_meta_embed
timestamp_q_embed = self._tscalar_embed(timestamps)
timestamp_k_embed = self._tscalar_embed(meta_timestamps.view(1, -1))
# timestamp_q_embed = self._tscalar_embed(timestamps)
# timestamp_k_embed = self._tscalar_embed(meta_timestamps.view(1, -1))
timestamp_v_embed = meta_embeds.unsqueeze(dim=0)
timestamp_qk_att_embed = self._tscalar_embed(
torch.unsqueeze(timestamps, dim=-1) - meta_timestamps
)
# create the mask
mask = (
torch.unsqueeze(timestamps, dim=-1) <= meta_timestamps.view(1, 1, -1)
@@ -179,11 +182,13 @@ class LFNA_Meta(super_core.SuperModule):
> self._thresh
)
timestamp_embeds = self._trans_att(
timestamp_q_embed, timestamp_k_embed, timestamp_v_embed, mask
timestamp_qk_att_embed, timestamp_v_embed, mask
)
relative_timestamps = timestamps - timestamps[:, :1]
relative_pos_embeds = self._tscalar_embed(relative_timestamps)
init_timestamp_embeds = torch.cat(
(timestamp_embeds, relative_pos_embeds), dim=-1
)
# relative_timestamps = timestamps - timestamps[:, :1]
# relative_pos_embeds = self._tscalar_embed(relative_timestamps)
init_timestamp_embeds = torch.cat((timestamp_q_embed, timestamp_embeds), dim=-1)
corrected_embeds = self._meta_corrector(init_timestamp_embeds)
return corrected_embeds