Update LFNA

This commit is contained in:
D-X-Y
2021-05-12 20:32:50 +08:00
parent 06f4a1f1cf
commit 0b1ca45c44
8 changed files with 121 additions and 15 deletions

View File

@@ -74,6 +74,19 @@ class SuperLayerNorm1D(SuperModule):
def forward_raw(self, input: torch.Tensor) -> torch.Tensor:
return F.layer_norm(input, (self.in_dim,), self.weight, self.bias, self.eps)
def forward_with_container(self, input, container, prefix=[]):
super_weight_name = ".".join(prefix + ["weight"])
if container.has(super_weight_name):
weight = container.query(super_weight_name)
else:
weight = None
super_bias_name = ".".join(prefix + ["bias"])
if container.has(super_bias_name):
bias = container.query(super_bias_name)
else:
bias = None
return F.layer_norm(input, (self.in_dim,), weight, bias, self.eps)
def extra_repr(self) -> str:
return (
"shape={in_dim}, eps={eps}, elementwise_affine={elementwise_affine}".format(