Complete LFNA 1.0

This commit is contained in:
D-X-Y
2021-05-14 00:36:37 +08:00
parent c2fa181bc5
commit b81ef2dd74
4 changed files with 311 additions and 19 deletions

View File

@@ -17,7 +17,7 @@ class LFNA_Meta(super_core.SuperModule):
def __init__(
self,
shape_container,
layer_embeding,
layer_embedding,
time_embedding,
meta_timestamps,
mha_depth: int = 2,
@@ -33,13 +33,16 @@ class LFNA_Meta(super_core.SuperModule):
self.register_parameter(
"_super_layer_embed",
torch.nn.Parameter(torch.Tensor(self._num_layers, layer_embeding)),
torch.nn.Parameter(torch.Tensor(self._num_layers, layer_embedding)),
)
self.register_parameter(
"_super_meta_embed",
torch.nn.Parameter(torch.Tensor(len(meta_timestamps), time_embedding)),
)
self.register_buffer("_meta_timestamps", torch.Tensor(meta_timestamps))
self._time_embed_dim = time_embedding
self._append_meta_embed = dict(fixed=None, learnt=None)
self._append_meta_timestamps = dict(fixed=None, learnt=None)
# build transformer
layers = []
@@ -60,9 +63,9 @@ class LFNA_Meta(super_core.SuperModule):
model_kwargs = dict(
config=dict(model_type="dual_norm_mlp"),
input_dim=layer_embeding + time_embedding,
input_dim=layer_embedding + time_embedding,
output_dim=max(self._numel_per_layer),
hidden_dims=[(layer_embeding + time_embedding) * 2] * 3,
hidden_dims=[(layer_embedding + time_embedding) * 2] * 3,
act_cls="gelu",
norm_cls="layer_norm_1d",
dropout=dropout,
@@ -82,21 +85,68 @@ class LFNA_Meta(super_core.SuperModule):
std=0.02,
)
@property
def meta_timestamps(self):
meta_timestamps = [self._meta_timestamps]
for key in ("fixed", "learnt"):
if self._append_meta_timestamps[key] is not None:
meta_timestamps.append(self._append_meta_timestamps[key])
return torch.cat(meta_timestamps)
@property
def super_meta_embed(self):
meta_embed = [self._super_meta_embed]
for key in ("fixed", "learnt"):
if self._append_meta_embed[key] is not None:
meta_embed.append(self._append_meta_embed[key])
return torch.cat(meta_embed)
def create_meta_embed(self):
param = torch.nn.Parameter(torch.Tensor(1, self._time_embed_dim))
trunc_normal_(param, std=0.02)
return param.to(self._super_meta_embed.device)
def get_closest_meta_distance(self, timestamp):
with torch.no_grad():
distances = torch.abs(self.meta_timestamps - timestamp)
return torch.min(distances).item()
def replace_append_learnt(self, timestamp, meta_embed):
self._append_meta_embed["learnt"] = meta_embed
self._append_meta_timestamps["learnt"] = timestamp
def append_fixed(self, timestamp, meta_embed):
with torch.no_grad():
timestamp, meta_embed = timestamp.clone(), meta_embed.clone()
if self._append_meta_timestamps["fixed"] is None:
self._append_meta_timestamps["fixed"] = timestamp
else:
self._append_meta_timestamps["fixed"] = torch.cat(
(self._append_meta_timestamps["fixed"], timestamp), dim=0
)
if self._append_meta_embed["fixed"] is None:
self._append_meta_embed["fixed"] = meta_embed
else:
self._append_meta_embed["fixed"] = torch.cat(
(self._append_meta_embed["fixed"], meta_embed), dim=0
)
def forward_raw(self, timestamps):
# timestamps is a batch of sequence of timestamps
batch, seq = timestamps.shape
timestamps = timestamps.unsqueeze(dim=-1)
meta_timestamps = self._meta_timestamps.view(1, 1, -1)
meta_timestamps = self.meta_timestamps.view(1, 1, -1)
time_diffs = timestamps - meta_timestamps
time_match_v, time_match_i = torch.min(torch.abs(time_diffs), dim=-1)
# select corresponding meta-knowledge
meta_match = torch.index_select(
self._super_meta_embed, dim=0, index=time_match_i.view(-1)
self.super_meta_embed, dim=0, index=time_match_i.view(-1)
)
meta_match = meta_match.view(batch, seq, -1)
# create the probability
time_probs = (1 / torch.exp(time_match_v * 10)).view(batch, seq, 1)
time_probs[:, -1, :] = 0
if self.training:
time_probs[:, -1, :] = 0
unknown_token = self._unknown_token.view(1, 1, -1)
raw_meta_embed = time_probs * meta_match + (1 - time_probs) * unknown_token