Update LFNA
This commit is contained in:
@@ -74,6 +74,19 @@ class SuperLayerNorm1D(SuperModule):
|
||||
def forward_raw(self, input: torch.Tensor) -> torch.Tensor:
    """Layer-normalize *input* over its last dimension using this module's
    own affine parameters (``self.weight`` / ``self.bias``)."""
    normalized_shape = (self.in_dim,)
    return F.layer_norm(input, normalized_shape, self.weight, self.bias, self.eps)
|
||||
|
||||
def forward_with_container(self, input, container, prefix=None):
    """Layer-normalize *input* using affine parameters looked up in *container*.

    The weight and bias are resolved by dotted name —
    ``".".join(prefix + ["weight"])`` and ``".".join(prefix + ["bias"])`` —
    and a name missing from the container falls back to ``None`` (no affine
    term), matching ``F.layer_norm`` semantics.

    Args:
        input: tensor whose last dimension has size ``self.in_dim``
            (assumed — TODO confirm against ``SuperLayerNorm1D``'s contract).
        container: parameter store exposing ``has(name)`` and ``query(name)``.
        prefix: optional list of name components prepended to ``"weight"`` /
            ``"bias"``; defaults to an empty list.

    Returns:
        The layer-normalized tensor.
    """
    # ``prefix=None`` sentinel avoids the shared-mutable-default pitfall of
    # the previous ``prefix=[]`` signature; the observable default behavior
    # (empty prefix) is unchanged.
    if prefix is None:
        prefix = []

    def _lookup(leaf):
        # Resolve one parameter by its fully-qualified dotted name,
        # returning None when the container has no such entry.
        name = ".".join(prefix + [leaf])
        return container.query(name) if container.has(name) else None

    weight = _lookup("weight")
    bias = _lookup("bias")
    return F.layer_norm(input, (self.in_dim,), weight, bias, self.eps)
|
||||
|
||||
def extra_repr(self) -> str:
|
||||
return (
|
||||
"shape={in_dim}, eps={eps}, elementwise_affine={elementwise_affine}".format(
|
||||
|
Reference in New Issue
Block a user