Update xlayers
This commit is contained in:
@@ -161,6 +161,21 @@ class SuperSimpleLearnableNorm(SuperModule):
|
||||
mean, std = torch.unsqueeze(mean, dim=0), torch.unsqueeze(std, dim=0)
|
||||
return tensor.sub_(mean).div_(std)
|
||||
|
||||
def forward_with_container(self, input, container, prefix=[]):
|
||||
if not self._inplace:
|
||||
tensor = input.clone()
|
||||
else:
|
||||
tensor = input
|
||||
mean_name = ".".join(prefix + ["_mean"])
|
||||
std_name = ".".join(prefix + ["_std"])
|
||||
mean, std = (
|
||||
container.query(mean_name).to(tensor.device),
|
||||
torch.abs(container.query(std_name).to(tensor.device)) + self._eps,
|
||||
)
|
||||
while mean.ndim < tensor.ndim:
|
||||
mean, std = torch.unsqueeze(mean, dim=0), torch.unsqueeze(std, dim=0)
|
||||
return tensor.sub_(mean).div_(std)
|
||||
|
||||
def extra_repr(self) -> str:
|
||||
return "mean={mean}, std={std}, inplace={inplace}".format(
|
||||
mean=self._mean.item(), std=self._std.item(), inplace=self._inplace
|
||||
@@ -191,3 +206,6 @@ class SuperIdentity(SuperModule):
|
||||
|
||||
def extra_repr(self) -> str:
|
||||
return "inplace={inplace}".format(inplace=self._inplace)
|
||||
|
||||
def forward_with_container(self, input, container, prefix=[]):
|
||||
return self.forward_raw(input)
|
||||
|
Reference in New Issue
Block a user