Fix bugs in xlayers

D-X-Y
2021-05-22 16:41:54 +08:00
parent 97717d826e
commit bc42ab3c08
7 changed files with 197 additions and 39 deletions


@@ -10,7 +10,8 @@ from tqdm import tqdm
 from copy import deepcopy
 from pathlib import Path
-lib_dir = (Path(__file__).parent / "..").resolve()
+lib_dir = (Path(__file__).parent / ".." / "..").resolve()
+print("LIB-DIR: {:}".format(lib_dir))
 if str(lib_dir) not in sys.path:
     sys.path.insert(0, str(lib_dir))
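The fix resolves lib_dir two levels above the script instead of one, which matters once the script sits a directory deeper than the importable packages. A minimal sketch of how the two expressions resolve, assuming a hypothetical layout /repo/exps/LFNA/train.py:

from pathlib import Path
import sys

script = Path("/repo/exps/LFNA/train.py")          # hypothetical script location
old_dir = (script.parent / "..").resolve()         # -> /repo/exps  (the buggy path)
new_dir = (script.parent / ".." / "..").resolve()  # -> /repo      (the fixed path)
if str(new_dir) not in sys.path:
    sys.path.insert(0, str(new_dir))               # repo-root packages become importable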


@@ -20,7 +20,7 @@ class LFNA_Meta(super_core.SuperModule):
         layer_embedding,
         time_embedding,
         meta_timestamps,
-        mha_depth: int = 2,
+        mha_depth: int = 1,
         dropout: float = 0.1,
     ):
         super(LFNA_Meta, self).__init__()
@@ -44,8 +44,21 @@ class LFNA_Meta(super_core.SuperModule):
         self._append_meta_embed = dict(fixed=None, learnt=None)
         self._append_meta_timestamps = dict(fixed=None, learnt=None)
-        self._time_prob_drop = super_core.SuperDrop(dropout, (-1, 1), recover=False)
+        self._tscalar_embed = super_core.SuperDynamicPositionE(
+            time_embedding, scale=100
+        )
+        # build transformer
+        self._trans_att = super_core.SuperQKVAttention(
+            time_embedding,
+            time_embedding,
+            time_embedding,
+            time_embedding,
+            4,
+            True,
+            attn_drop=None,
+            proj_drop=dropout,
+        )
         layers = []
         for ilayer in range(mha_depth):
             layers.append(
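This hunk swaps the hard nearest-timestamp lookup for a learned one: scalar timestamps are lifted to vectors by SuperDynamicPositionE, and a SuperQKVAttention lets each query timestamp attend over all stored meta-timestamps. A minimal stand-alone sketch of that idea in plain PyTorch; the sinusoidal embedding and MultiheadAttention here are illustrative stand-ins, not the repo's super_core API:

import math
import torch

def scalar_position_embed(x, dim, scale=100.0):
    # Sinusoidal features for scalar timestamps, in the spirit of
    # SuperDynamicPositionE; the exact repo semantics may differ.
    freqs = torch.exp(
        torch.arange(0, dim, 2, dtype=torch.float32) * (-math.log(10000.0) / dim)
    )
    angles = x.unsqueeze(-1) * scale * freqs
    return torch.cat([torch.sin(angles), torch.cos(angles)], dim=-1)

dim = 32                               # hypothetical time_embedding size
query_ts = torch.rand(1, 5)            # 5 query timestamps
meta_ts = torch.rand(1, 100)           # 100 stored meta-timestamps
meta_embed = torch.rand(1, 100, dim)   # stands in for super_meta_embed

q = scalar_position_embed(query_ts, dim)
k = scalar_position_embed(meta_ts, dim)
attn = torch.nn.MultiheadAttention(dim, num_heads=4, batch_first=True)
out, _ = attn(q, k, meta_embed)        # (1, 5, dim): attention-weighted meta knowledge

Unlike the removed torch.min lookup, the soft attention is differentiable with respect to every meta embedding, not just the single nearest one.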
@@ -74,15 +87,9 @@ class LFNA_Meta(super_core.SuperModule):
         self._generator = get_model(**model_kwargs)
         # print("generator: {:}".format(self._generator))
-        # unknown token
-        self.register_parameter(
-            "_unknown_token",
-            torch.nn.Parameter(torch.Tensor(1, time_embedding)),
-        )
         # initialization
         trunc_normal_(
-            [self._super_layer_embed, self._super_meta_embed, self._unknown_token],
+            [self._super_layer_embed, self._super_meta_embed],
             std=0.02,
         )
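With _unknown_token gone, only the layer and meta embeddings remain to be initialized. The repo's trunc_normal_ apparently accepts a list of tensors; with the stock PyTorch initializer one would loop, as in this small sketch:

import torch
from torch.nn.init import trunc_normal_

embeds = [torch.nn.Parameter(torch.empty(10, 32)) for _ in range(2)]
for t in embeds:
    trunc_normal_(t, std=0.02)   # in-place truncated-normal init, std 0.02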
@@ -136,28 +143,21 @@ class LFNA_Meta(super_core.SuperModule):
             (self._append_meta_embed["fixed"], meta_embed), dim=0
         )

-    def forward_raw(self, timestamps):
+    def _obtain_time_embed(self, timestamps):
         # timestamps is a batch of sequence of timestamps
         batch, seq = timestamps.shape
-        timestamps = timestamps.unsqueeze(dim=-1)
-        meta_timestamps = self.meta_timestamps.view(1, 1, -1)
-        time_diffs = timestamps - meta_timestamps
-        time_match_v, time_match_i = torch.min(torch.abs(time_diffs), dim=-1)
-        # select corresponding meta-knowledge
-        meta_match = torch.index_select(
-            self.super_meta_embed, dim=0, index=time_match_i.view(-1)
-        )
-        meta_match = meta_match.view(batch, seq, -1)
-        # create the probability
-        time_probs = (1 / torch.exp(time_match_v * 10)).view(batch, seq, 1)
-        x_time_probs = self._time_prob_drop(time_probs)
-        # if self.training:
-        #     time_probs[:, -1, :] = 0
-        unknown_token = self._unknown_token.view(1, 1, -1)
-        raw_meta_embed = x_time_probs * meta_match + (1 - x_time_probs) * unknown_token
-        meta_embed = self.meta_corrector(raw_meta_embed)
+        timestamp_q_embed = self._tscalar_embed(timestamps)
+        timestamp_k_embed = self._tscalar_embed(self.meta_timestamps.view(1, -1))
+        timestamp_v_embed = self.super_meta_embed.unsqueeze(dim=0)
+        timestamp_embeds = self._trans_att(
+            timestamp_q_embed, timestamp_k_embed, timestamp_v_embed
+        )
+        corrected_embeds = self.meta_corrector(timestamp_embeds)
+        return corrected_embeds
+
+    def forward_raw(self, timestamps):
+        batch, seq = timestamps.shape
+        meta_embed = self._obtain_time_embed(timestamps)
         # create joint embed
         num_layer, _ = self._super_layer_embed.shape
         meta_embed = meta_embed.view(batch, seq, 1, -1).expand(-1, -1, num_layer, -1)
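The refactored forward_raw now gets one embedding per (batch, seq) timestamp from _obtain_time_embed and broadcasts it across layers. A shape walk-through with hypothetical sizes; the final concatenation with the per-layer embedding is a plausible continuation, since the hunk ends at the expand:

import torch

batch, seq, num_layer, dim = 2, 5, 12, 32
meta_embed = torch.rand(batch, seq, dim)    # output of _obtain_time_embed
layer_embed = torch.rand(num_layer, dim)    # stands in for _super_layer_embed

meta_embed = meta_embed.view(batch, seq, 1, -1).expand(-1, -1, num_layer, -1)
layer_embed = layer_embed.view(1, 1, num_layer, -1).expand(batch, seq, -1, -1)
joint_embed = torch.cat([meta_embed, layer_embed], dim=-1)
print(joint_embed.shape)                    # torch.Size([2, 5, 12, 64])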