Update models
@@ -14,6 +14,13 @@ IntSpaceType = Union[int, spaces.Integer, spaces.Categorical]
 BoolSpaceType = Union[bool, spaces.Categorical]
 
 
+class LayerOrder(Enum):
+    """This class defines the enumerations for order of operation in a residual or normalization-based layer."""
+
+    PreNorm = "pre-norm"
+    PostNorm = "post-norm"
+
+
 class SuperRunMode(Enum):
     """This class defines the enumerations for Super Model Running Mode."""
 
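As a side note, here is a minimal sketch (not part of this commit) of how an Enum such as LayerOrder behaves: members are singletons, so they can safely be compared with `is` as the code further down does, and they can be looked up from their string values.

    from enum import Enum

    class LayerOrder(Enum):
        PreNorm = "pre-norm"
        PostNorm = "post-norm"

    order = LayerOrder("pre-norm")        # look up a member by its value
    assert order is LayerOrder.PreNorm    # members are singletons, so `is` comparison works
    print(order.name, order.value)        # PreNorm pre-norm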
@@ -15,6 +15,7 @@ import torch.nn.functional as F
 import spaces
 from .super_module import IntSpaceType
 from .super_module import BoolSpaceType
+from .super_module import LayerOrder
 from .super_module import SuperModule
 from .super_linear import SuperMLPv2
 from .super_norm import SuperLayerNorm1D
@@ -30,7 +31,8 @@ class SuperTransformerEncoderLayer(SuperModule):
     - PyTorch Implementation: https://pytorch.org/docs/stable/_modules/torch/nn/modules/transformer.html#TransformerEncoderLayer
 
     Details:
-        MHA -> residual -> norm -> MLP -> residual -> norm
+        the original post-norm version: MHA -> residual -> norm -> MLP -> residual -> norm
+        the pre-norm version: norm -> MHA -> residual -> norm -> MLP -> residual
     """
 
     def __init__(
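To make the two orderings from the docstring concrete, here is a hedged, self-contained sketch using plain PyTorch modules in place of SuperAttention, SuperMLPv2 and SuperLayerNorm1D. It mirrors the forward_raw logic added further below; it is illustrative only and not code from this commit.

    import torch
    import torch.nn as nn

    dim = 8
    mha = nn.MultiheadAttention(dim, num_heads=2, batch_first=True)
    mlp = nn.Sequential(nn.Linear(dim, 4 * dim), nn.GELU(), nn.Linear(4 * dim, dim))
    norm1, norm2 = nn.LayerNorm(dim), nn.LayerNorm(dim)

    def attn(x):
        # self-attention: query, key and value are all x
        return mha(x, x, x, need_weights=False)[0]

    def pre_norm(x):
        # norm -> MHA -> residual -> norm -> MLP -> residual
        x = norm1(x)
        x = x + attn(x)
        x = norm2(x)
        return x + mlp(x)

    def post_norm(x):
        # MHA -> residual -> norm -> MLP -> residual -> norm
        x = norm1(x + attn(x))
        return norm2(x + mlp(x))

    x = torch.randn(2, 5, dim)  # (batch, sequence, feature)
    print(pre_norm(x).shape, post_norm(x).shape)  # torch.Size([2, 5, 8]) for both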
@@ -42,9 +44,10 @@ class SuperTransformerEncoderLayer(SuperModule):
         mlp_hidden_multiplier: IntSpaceType = 4,
         drop: Optional[float] = None,
         act_layer: Callable[[], nn.Module] = nn.GELU,
+        order: LayerOrder = LayerOrder.PreNorm,
     ):
         super(SuperTransformerEncoderLayer, self).__init__()
-        self.mha = SuperAttention(
+        mha = SuperAttention(
             input_dim,
             input_dim,
             num_heads=num_heads,
@@ -52,17 +55,33 @@ class SuperTransformerEncoderLayer(SuperModule):
             attn_drop=drop,
             proj_drop=drop,
         )
-        self.drop1 = nn.Dropout(drop or 0.0)
-        self.norm1 = SuperLayerNorm1D(input_dim)
-        self.mlp = SuperMLPv2(
+        drop1 = nn.Dropout(drop or 0.0)
+        norm1 = SuperLayerNorm1D(input_dim)
+        mlp = SuperMLPv2(
             input_dim,
             hidden_multiplier=mlp_hidden_multiplier,
             out_features=output_dim,
             act_layer=act_layer,
             drop=drop,
         )
-        self.drop2 = nn.Dropout(drop or 0.0)
-        self.norm2 = SuperLayerNorm1D(output_dim)
+        drop2 = nn.Dropout(drop or 0.0)
+        norm2 = SuperLayerNorm1D(output_dim)
+        self._order = order  # keep the requested order so forward_raw can branch on it
+        if order is LayerOrder.PreNorm:
+            self.norm1 = norm1
+            self.mha = mha
+            self.drop1 = drop1
+            self.norm2 = norm2
+            self.mlp = mlp
+            self.drop2 = drop2
+        elif order is LayerOrder.PostNorm:
+            self.mha = mha
+            self.drop1 = drop1
+            self.norm1 = norm1
+            self.mlp = mlp
+            self.drop2 = drop2
+            self.norm2 = norm2
+        else:
+            raise ValueError("Unknown order: {:}".format(order))
 
     @property
     def abstract_search_space(self):
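A plausible reason for branching on the order when assigning the sub-modules (rather than only storing a flag) is that nn.Module records its children in attribute-assignment order, so print(layer), .children(), and state_dict keys follow the actual computation order. A small illustration of that behaviour, not taken from this commit:

    import torch.nn as nn

    class Demo(nn.Module):
        def __init__(self, pre_norm: bool):
            super().__init__()
            if pre_norm:                     # register the norm first, as in the pre-norm branch
                self.norm = nn.LayerNorm(4)
                self.proj = nn.Linear(4, 4)
            else:                            # register the norm last, as in the post-norm branch
                self.proj = nn.Linear(4, 4)
                self.norm = nn.LayerNorm(4)

    print([name for name, _ in Demo(True).named_children()])   # ['norm', 'proj']
    print([name for name, _ in Demo(False).named_children()])  # ['proj', 'norm']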
@@ -89,12 +108,18 @@ class SuperTransformerEncoderLayer(SuperModule):
         return self.forward_raw(input)
 
     def forward_raw(self, input: torch.Tensor) -> torch.Tensor:
-        # multi-head attention
-        x = self.mha(input)
-        x = x + self.drop1(x)
-        x = self.norm1(x)
-        # feed-forward layer
-        x = self.mlp(x)
-        x = x + self.drop2(x)
-        x = self.norm2(x)
+        if self._order is LayerOrder.PreNorm:
+            x = self.norm1(input)
+            x = x + self.drop1(self.mha(x))
+            x = self.norm2(x)
+            x = x + self.drop2(self.mlp(x))
+        elif self._order is LayerOrder.PostNorm:
+            # multi-head attention
+            x = input + self.drop1(self.mha(input))
+            x = self.norm1(x)
+            # feed-forward layer
+            x = x + self.drop2(self.mlp(x))
+            x = self.norm2(x)
+        else:
+            raise ValueError("Unknown order: {:}".format(self._order))
         return x
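Finally, a hedged usage sketch of the updated layer. The keyword names follow the arguments visible in this diff (input_dim, output_dim, num_heads, order); the exact constructor signature and the import path are assumptions about the surrounding repository, not something this commit confirms.

    import torch
    # import path assumed from the package layout implied by the relative imports above
    from xautodl.xlayers.super_transformer import SuperTransformerEncoderLayer
    from xautodl.xlayers.super_module import LayerOrder

    layer = SuperTransformerEncoderLayer(
        input_dim=64, output_dim=64, num_heads=4, order=LayerOrder.PostNorm
    )
    x = torch.randn(2, 16, 64)   # (batch, sequence, feature)
    out = layer(x)               # forward falls through to forward_raw, per the context line above
    print(out.shape)             # expected: torch.Size([2, 16, 64])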