Update LFNA with train/valid

This commit is contained in:
D-X-Y
2021-05-17 07:39:24 +00:00
parent de8cf677d9
commit 5c851ac25a
5 changed files with 123 additions and 26 deletions

View File

@@ -44,6 +44,7 @@ class LFNA_Meta(super_core.SuperModule):
self._append_meta_embed = dict(fixed=None, learnt=None)
self._append_meta_timestamps = dict(fixed=None, learnt=None)
self._time_prob_drop = super_core.SuperDrop(dropout, (-1, 1), recover=False)
# build transformer
layers = []
for ilayer in range(mha_depth):
@@ -149,10 +150,12 @@ class LFNA_Meta(super_core.SuperModule):
meta_match = meta_match.view(batch, seq, -1)
# create the probability
time_probs = (1 / torch.exp(time_match_v * 10)).view(batch, seq, 1)
if self.training:
time_probs[:, -1, :] = 0
x_time_probs = self._time_prob_drop(time_probs)
# if self.training:
# time_probs[:, -1, :] = 0
unknown_token = self._unknown_token.view(1, 1, -1)
raw_meta_embed = time_probs * meta_match + (1 - time_probs) * unknown_token
raw_meta_embed = x_time_probs * meta_match + (1 - x_time_probs) * unknown_token
meta_embed = self.meta_corrector(raw_meta_embed)
# create joint embed