Fix bugs in xlayers
@@ -10,7 +10,8 @@ from tqdm import tqdm
 from copy import deepcopy
 from pathlib import Path
 
-lib_dir = (Path(__file__).parent / "..").resolve()
+lib_dir = (Path(__file__).parent / ".." / "..").resolve()
 print("LIB-DIR: {:}".format(lib_dir))
 if str(lib_dir) not in sys.path:
     sys.path.insert(0, str(lib_dir))
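
The corrected lib_dir climbs two directory levels instead of one, so it only resolves to the intended root if this script sits two levels below it. A minimal sketch of the same bootstrap pattern, assuming a hypothetical layout where the script lives at <root>/exps/LFNA/ and the importable packages (including xlayers, which this commit touches) sit directly under <root>:

    import sys
    from pathlib import Path

    # Hypothetical layout: <root>/exps/LFNA/this_script.py, with the
    # importable packages (e.g. xlayers) directly under <root>.
    lib_dir = (Path(__file__).parent / ".." / "..").resolve()
    if str(lib_dir) not in sys.path:
        # Prepend rather than append so the in-repo code shadows any
        # installed package of the same name.
        sys.path.insert(0, str(lib_dir))

    from xlayers import super_core  # works only if the layout above holds
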
@@ -20,7 +20,7 @@ class LFNA_Meta(super_core.SuperModule):
         layer_embedding,
         time_embedding,
         meta_timestamps,
-        mha_depth: int = 2,
+        mha_depth: int = 1,
         dropout: float = 0.1,
     ):
         super(LFNA_Meta, self).__init__()
@@ -44,8 +44,21 @@ class LFNA_Meta(super_core.SuperModule):
         self._append_meta_embed = dict(fixed=None, learnt=None)
         self._append_meta_timestamps = dict(fixed=None, learnt=None)
 
-        self._time_prob_drop = super_core.SuperDrop(dropout, (-1, 1), recover=False)
-
+        self._tscalar_embed = super_core.SuperDynamicPositionE(
+            time_embedding, scale=100
+        )
+
+        # build transformer
+        self._trans_att = super_core.SuperQKVAttention(
+            time_embedding,
+            time_embedding,
+            time_embedding,
+            time_embedding,
+            4,
+            True,
+            attn_drop=None,
+            proj_drop=dropout,
+        )
         layers = []
         for ilayer in range(mha_depth):
             layers.append(
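
SuperDynamicPositionE(time_embedding, scale=100) evidently maps a real-valued timestamp to a fixed sinusoidal vector, in the spirit of transformer positional encodings but evaluated at arbitrary (scaled) scalar positions; SuperQKVAttention then consumes those vectors as queries and keys. A standalone sketch of such a scalar-time embedding; the frequency schedule and the way scale enters are assumptions, not the project's exact formula:

    import math
    import torch

    def scalar_position_embed(ts, dim, scale=100.0):
        # Sinusoidal embedding of real-valued timestamps.
        # ts: (...,) float tensor; returns (..., dim); dim must be even.
        half = dim // 2
        freqs = torch.exp(
            -math.log(10000.0) * torch.arange(half, dtype=torch.float32) / half
        )
        angles = ts.unsqueeze(-1) * scale * freqs  # broadcast over channels
        return torch.cat([torch.sin(angles), torch.cos(angles)], dim=-1)

    emb = scalar_position_embed(torch.rand(2, 5), dim=64)
    print(emb.shape)  # torch.Size([2, 5, 64])

In the SuperQKVAttention call, the positional 4 and True presumably set the number of heads and the qkv-projection bias, though those names are a guess from the call site.
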
@@ -74,15 +87,9 @@ class LFNA_Meta(super_core.SuperModule):
         self._generator = get_model(**model_kwargs)
         # print("generator: {:}".format(self._generator))
 
-        # unknown token
-        self.register_parameter(
-            "_unknown_token",
-            torch.nn.Parameter(torch.Tensor(1, time_embedding)),
-        )
-
         # initialization
         trunc_normal_(
-            [self._super_layer_embed, self._super_meta_embed, self._unknown_token],
+            [self._super_layer_embed, self._super_meta_embed],
             std=0.02,
         )
 
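
With the unknown-token fallback deleted, only the layer and meta embeddings still need initialization. The project's trunc_normal_ apparently accepts a list of tensors; with the stock torch.nn.init.trunc_normal_, which takes a single tensor, the equivalent is a short loop (a sketch with made-up shapes):

    import torch
    from torch.nn.init import trunc_normal_  # available in torch >= 1.7

    def trunc_normal_all(tensors, std=0.02):
        # Apply truncated-normal init to each parameter in turn.
        for t in tensors:
            trunc_normal_(t, std=std)

    layer_embed = torch.nn.Parameter(torch.empty(12, 64))
    meta_embed = torch.nn.Parameter(torch.empty(100, 64))
    trunc_normal_all([layer_embed, meta_embed], std=0.02)
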
@@ -136,28 +143,21 @@ class LFNA_Meta(super_core.SuperModule):
             (self._append_meta_embed["fixed"], meta_embed), dim=0
         )
 
-    def forward_raw(self, timestamps):
+    def _obtain_time_embed(self, timestamps):
         # timestamps is a batch of sequence of timestamps
         batch, seq = timestamps.shape
-        timestamps = timestamps.unsqueeze(dim=-1)
-        meta_timestamps = self.meta_timestamps.view(1, 1, -1)
-        time_diffs = timestamps - meta_timestamps
-        time_match_v, time_match_i = torch.min(torch.abs(time_diffs), dim=-1)
-        # select corresponding meta-knowledge
-        meta_match = torch.index_select(
-            self.super_meta_embed, dim=0, index=time_match_i.view(-1)
+        timestamp_q_embed = self._tscalar_embed(timestamps)
+        timestamp_k_embed = self._tscalar_embed(self.meta_timestamps.view(1, -1))
+        timestamp_v_embed = self.super_meta_embed.unsqueeze(dim=0)
+        timestamp_embeds = self._trans_att(
+            timestamp_q_embed, timestamp_k_embed, timestamp_v_embed
         )
-        meta_match = meta_match.view(batch, seq, -1)
-        # create the probability
-        time_probs = (1 / torch.exp(time_match_v * 10)).view(batch, seq, 1)
+        corrected_embeds = self.meta_corrector(timestamp_embeds)
+        return corrected_embeds
 
-        x_time_probs = self._time_prob_drop(time_probs)
-        # if self.training:
-        #     time_probs[:, -1, :] = 0
-        unknown_token = self._unknown_token.view(1, 1, -1)
-        raw_meta_embed = x_time_probs * meta_match + (1 - x_time_probs) * unknown_token
-
-        meta_embed = self.meta_corrector(raw_meta_embed)
+    def forward_raw(self, timestamps):
+        batch, seq = timestamps.shape
+        meta_embed = self._obtain_time_embed(timestamps)
         # create joint embed
         num_layer, _ = self._super_layer_embed.shape
         meta_embed = meta_embed.view(batch, seq, 1, -1).expand(-1, -1, num_layer, -1)
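
The rewritten method replaces hard nearest-timestamp matching (torch.min over |t - t_meta|, an exponential time probability, and an unknown-token fallback) with soft retrieval: query embeddings of the input timestamps attend over key embeddings of the stored meta_timestamps, with the learned meta embeddings as values. A self-contained sketch of that retrieval step in plain PyTorch; SuperQKVAttention's internals are project-specific, so single-head dot-product attention stands in here:

    import torch

    def attend_over_meta(q_embed, k_embed, v_embed):
        # Soft nearest-timestamp lookup via dot-product attention.
        # q_embed: (batch, seq, dim)   embeddings of query timestamps
        # k_embed: (1, num_meta, dim)  embeddings of stored meta timestamps
        # v_embed: (1, num_meta, dim)  learned meta embeddings (the values)
        scale = q_embed.shape[-1] ** -0.5
        attn = torch.softmax(q_embed @ k_embed.transpose(-2, -1) * scale, dim=-1)
        return attn @ v_embed  # (batch, seq, dim)

    batch, seq, num_meta, dim = 2, 5, 100, 64
    out = attend_over_meta(
        torch.randn(batch, seq, dim),
        torch.randn(1, num_meta, dim),
        torch.randn(1, num_meta, dim),
    )
    print(out.shape)  # torch.Size([2, 5, 64])

Unlike the deleted hard match, every stored timestamp contributes with a learned weight, which is why the unknown-token parameter and the _time_prob_drop regularizer could be removed in the earlier hunks.
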