Fix bugs in xlayers

This commit is contained in:
D-X-Y
2021-05-22 16:41:54 +08:00
parent 97717d826e
commit bc42ab3c08
7 changed files with 197 additions and 39 deletions

View File

@@ -35,11 +35,13 @@ class SuperDynamicPositionE(SuperModule):
return self.forward_raw(input)
def forward_raw(self, input: torch.Tensor) -> torch.Tensor:
    """Compute fixed sinusoidal position embeddings for ``input``.

    Each element of ``input`` is scaled by ``self._scale``, divided by the
    precomputed divisor terms ``self._div_term``, and mapped through ``sin``
    and ``cos``.

    Args:
        input: tensor of position values of any shape.

    Returns:
        Tensor of shape ``(*input.shape, 2 * self._div_term.numel())`` —
        the ``sin`` features followed by the ``cos`` features along the
        last axis.
    """
    # NOTE(review): removed leftover debug code (import pdb / pdb.set_trace /
    # print("---")) and the stray early `return F.linear(...)` that made the
    # embedding computation below unreachable dead code.
    positions = torch.unsqueeze(input * self._scale, dim=-1)
    # Reshape divisors to (1, ..., 1, num_terms) so broadcasting against
    # `positions` yields shape (*input.shape, num_terms).
    divisions = torch.reshape(
        self._div_term, [1] * input.ndim + [self._div_term.numel()]
    )
    values = positions / divisions
    # Interleave is not needed: sin block then cos block along the last axis.
    embeds = torch.cat((torch.sin(values), torch.cos(values)), dim=-1)
    return embeds
def extra_repr(self) -> str:
    """Summarize the layer's configuration (scale and dimension) for repr()."""
    return f"scale={self._scale:}, dim={self._dimension:}"