Update LFNA with resume
This commit is contained in:
@@ -102,9 +102,11 @@ class LFNA_Meta(super_core.SuperModule):
|
||||
return torch.cat(meta_embed)
|
||||
|
||||
def create_meta_embed(self):
|
||||
param = torch.nn.Parameter(torch.Tensor(1, self._time_embed_dim))
|
||||
param = torch.Tensor(1, self._time_embed_dim)
|
||||
trunc_normal_(param, std=0.02)
|
||||
return param.to(self._super_meta_embed.device)
|
||||
param = param.to(self._super_meta_embed.device)
|
||||
param = torch.nn.Parameter(param, True)
|
||||
return param
|
||||
|
||||
def get_closest_meta_distance(self, timestamp):
|
||||
with torch.no_grad():
|
||||
@@ -112,12 +114,14 @@ class LFNA_Meta(super_core.SuperModule):
|
||||
return torch.min(distances).item()
|
||||
|
||||
def replace_append_learnt(self, timestamp, meta_embed):
|
||||
self._append_meta_embed["learnt"] = meta_embed
|
||||
self._append_meta_timestamps["learnt"] = timestamp
|
||||
self._append_meta_embed["learnt"] = meta_embed
|
||||
|
||||
def append_fixed(self, timestamp, meta_embed):
|
||||
with torch.no_grad():
|
||||
timestamp, meta_embed = timestamp.clone(), meta_embed.clone()
|
||||
device = self._super_meta_embed.device
|
||||
timestamp = timestamp.detach().clone().to(device)
|
||||
meta_embed = meta_embed.detach().clone().to(device)
|
||||
if self._append_meta_timestamps["fixed"] is None:
|
||||
self._append_meta_timestamps["fixed"] = timestamp
|
||||
else:
|
||||
|
Reference in New Issue
Block a user