Update xlayers

This commit is contained in:
D-X-Y
2021-05-07 10:26:35 +08:00
parent f6a024a6ff
commit 80aaac4dfa
9 changed files with 333 additions and 83 deletions

View File

@@ -161,6 +161,21 @@ class SuperSimpleLearnableNorm(SuperModule):
mean, std = torch.unsqueeze(mean, dim=0), torch.unsqueeze(std, dim=0)
return tensor.sub_(mean).div_(std)
def forward_with_container(self, input, container, prefix=[]):
if not self._inplace:
tensor = input.clone()
else:
tensor = input
mean_name = ".".join(prefix + ["_mean"])
std_name = ".".join(prefix + ["_std"])
mean, std = (
container.query(mean_name).to(tensor.device),
torch.abs(container.query(std_name).to(tensor.device)) + self._eps,
)
while mean.ndim < tensor.ndim:
mean, std = torch.unsqueeze(mean, dim=0), torch.unsqueeze(std, dim=0)
return tensor.sub_(mean).div_(std)
def extra_repr(self) -> str:
return "mean={mean}, std={std}, inplace={inplace}".format(
mean=self._mean.item(), std=self._std.item(), inplace=self._inplace
@@ -191,3 +206,6 @@ class SuperIdentity(SuperModule):
def extra_repr(self) -> str:
return "inplace={inplace}".format(inplace=self._inplace)
def forward_with_container(self, input, container, prefix=[]):
return self.forward_raw(input)