Update LFNA version 1.0

This commit is contained in:
D-X-Y
2021-05-13 21:33:34 +08:00
parent 3d3a04705f
commit cfabd05de8
11 changed files with 340 additions and 299 deletions

View File

@@ -42,6 +42,7 @@ class SuperTransformerEncoderLayer(SuperModule):
qkv_bias: BoolSpaceType = False,
mlp_hidden_multiplier: IntSpaceType = 4,
drop: Optional[float] = None,
norm_affine: bool = True,
act_layer: Callable[[], nn.Module] = nn.GELU,
order: LayerOrder = LayerOrder.PreNorm,
):
@@ -62,19 +63,19 @@ class SuperTransformerEncoderLayer(SuperModule):
drop=drop,
)
if order is LayerOrder.PreNorm:
self.norm1 = SuperLayerNorm1D(d_model)
self.norm1 = SuperLayerNorm1D(d_model, elementwise_affine=norm_affine)
self.mha = mha
self.drop1 = nn.Dropout(drop or 0.0)
self.norm2 = SuperLayerNorm1D(d_model)
self.norm2 = SuperLayerNorm1D(d_model, elementwise_affine=norm_affine)
self.mlp = mlp
self.drop2 = nn.Dropout(drop or 0.0)
elif order is LayerOrder.PostNorm:
self.mha = mha
self.drop1 = nn.Dropout(drop or 0.0)
self.norm1 = SuperLayerNorm1D(d_model)
self.norm1 = SuperLayerNorm1D(d_model, elementwise_affine=norm_affine)
self.mlp = mlp
self.drop2 = nn.Dropout(drop or 0.0)
self.norm2 = SuperLayerNorm1D(d_model)
self.norm2 = SuperLayerNorm1D(d_model, elementwise_affine=norm_affine)
else:
raise ValueError("Unknown order: {:}".format(order))
self._order = order