Update LFNA with train/valid
This commit is contained in:
@@ -44,6 +44,7 @@ class LFNA_Meta(super_core.SuperModule):
|
||||
self._append_meta_embed = dict(fixed=None, learnt=None)
|
||||
self._append_meta_timestamps = dict(fixed=None, learnt=None)
|
||||
|
||||
self._time_prob_drop = super_core.SuperDrop(dropout, (-1, 1), recover=False)
|
||||
# build transformer
|
||||
layers = []
|
||||
for ilayer in range(mha_depth):
|
||||
@@ -149,10 +150,12 @@ class LFNA_Meta(super_core.SuperModule):
|
||||
meta_match = meta_match.view(batch, seq, -1)
|
||||
# create the probability
|
||||
time_probs = (1 / torch.exp(time_match_v * 10)).view(batch, seq, 1)
|
||||
if self.training:
|
||||
time_probs[:, -1, :] = 0
|
||||
|
||||
x_time_probs = self._time_prob_drop(time_probs)
|
||||
# if self.training:
|
||||
# time_probs[:, -1, :] = 0
|
||||
unknown_token = self._unknown_token.view(1, 1, -1)
|
||||
raw_meta_embed = time_probs * meta_match + (1 - time_probs) * unknown_token
|
||||
raw_meta_embed = x_time_probs * meta_match + (1 - x_time_probs) * unknown_token
|
||||
|
||||
meta_embed = self.meta_corrector(raw_meta_embed)
|
||||
# create joint embed
|
||||
|
Reference in New Issue
Block a user