Update LFNA with resume

This commit is contained in:
D-X-Y
2021-05-17 04:33:40 +00:00
parent b11cfe263d
commit de8cf677d9
3 changed files with 62 additions and 18 deletions

View File

@@ -102,9 +102,11 @@ class LFNA_Meta(super_core.SuperModule):
return torch.cat(meta_embed)
def create_meta_embed(self):
param = torch.nn.Parameter(torch.Tensor(1, self._time_embed_dim))
param = torch.Tensor(1, self._time_embed_dim)
trunc_normal_(param, std=0.02)
return param.to(self._super_meta_embed.device)
param = param.to(self._super_meta_embed.device)
param = torch.nn.Parameter(param, True)
return param
def get_closest_meta_distance(self, timestamp):
with torch.no_grad():
@@ -112,12 +114,14 @@ class LFNA_Meta(super_core.SuperModule):
return torch.min(distances).item()
def replace_append_learnt(self, timestamp, meta_embed):
self._append_meta_embed["learnt"] = meta_embed
self._append_meta_timestamps["learnt"] = timestamp
self._append_meta_embed["learnt"] = meta_embed
def append_fixed(self, timestamp, meta_embed):
with torch.no_grad():
timestamp, meta_embed = timestamp.clone(), meta_embed.clone()
device = self._super_meta_embed.device
timestamp = timestamp.detach().clone().to(device)
meta_embed = meta_embed.detach().clone().to(device)
if self._append_meta_timestamps["fixed"] is None:
self._append_meta_timestamps["fixed"] = timestamp
else: