Updates
@@ -3,6 +3,7 @@
 #####################################################
 from .super_module import SuperRunMode
 from .super_module import IntSpaceType
+from .super_module import LayerOrder

 from .super_module import SuperModule
 from .super_container import SuperSequential
@@ -37,8 +37,7 @@ class SuperTransformerEncoderLayer(SuperModule):

     def __init__(
         self,
-        input_dim: IntSpaceType,
-        output_dim: IntSpaceType,
+        d_model: IntSpaceType,
         num_heads: IntSpaceType,
         qkv_bias: BoolSpaceType = False,
         mlp_hidden_multiplier: IntSpaceType = 4,
@@ -48,40 +47,37 @@ class SuperTransformerEncoderLayer(SuperModule):
     ):
         super(SuperTransformerEncoderLayer, self).__init__()
         mha = SuperAttention(
-            input_dim,
-            input_dim,
+            d_model,
+            d_model,
             num_heads=num_heads,
             qkv_bias=qkv_bias,
             attn_drop=drop,
             proj_drop=drop,
         )
-        drop1 = nn.Dropout(drop or 0.0)
-        norm1 = SuperLayerNorm1D(input_dim)
         mlp = SuperMLPv2(
-            input_dim,
+            d_model,
             hidden_multiplier=mlp_hidden_multiplier,
-            out_features=output_dim,
+            out_features=d_model,
             act_layer=act_layer,
             drop=drop,
         )
-        drop2 = nn.Dropout(drop or 0.0)
-        norm2 = SuperLayerNorm1D(output_dim)
         if order is LayerOrder.PreNorm:
-            self.norm1 = norm1
+            self.norm1 = SuperLayerNorm1D(d_model)
             self.mha = mha
-            self.drop1 = drop1
-            self.norm2 = norm2
+            self.drop1 = nn.Dropout(drop or 0.0)
+            self.norm2 = SuperLayerNorm1D(d_model)
             self.mlp = mlp
-            self.drop2 = drop2
-        elif order is LayerOrder.PostNoem:
+            self.drop2 = nn.Dropout(drop or 0.0)
+        elif order is LayerOrder.PostNorm:
             self.mha = mha
-            self.drop1 = drop1
-            self.norm1 = norm1
+            self.drop1 = nn.Dropout(drop or 0.0)
+            self.norm1 = SuperLayerNorm1D(d_model)
             self.mlp = mlp
-            self.drop2 = drop2
-            self.norm2 = norm2
+            self.drop2 = nn.Dropout(drop or 0.0)
+            self.norm2 = SuperLayerNorm1D(d_model)
         else:
             raise ValueError("Unknown order: {:}".format(order))
+        self._order = order

     @property
     def abstract_search_space(self):
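For orientation, a minimal usage sketch of the refactored constructor (not part of this commit): a single d_model argument now replaces the old input_dim/output_dim pair. The import path, keyword defaults, and values below are illustrative assumptions, not something this diff confirms.

# Illustrative only: assumes SuperTransformerEncoderLayer and LayerOrder are
# importable from the repo's xlayers package; values are placeholders.
import torch
from xlayers.super_core import SuperTransformerEncoderLayer, LayerOrder  # assumed path

layer = SuperTransformerEncoderLayer(
    d_model=256,               # single embedding size, replacing input_dim/output_dim
    num_heads=4,
    qkv_bias=True,
    mlp_hidden_multiplier=4,
    order=LayerOrder.PreNorm,  # or LayerOrder.PostNorm
)
tokens = torch.rand(2, 32, 256)  # (batch, sequence, d_model)
output = layer(tokens)           # embedding size is preserved after the refactor
print(output.shape)              # expected: torch.Size([2, 32, 256])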
@@ -108,18 +104,19 @@ class SuperTransformerEncoderLayer(SuperModule):
         return self.forward_raw(input)

     def forward_raw(self, input: torch.Tensor) -> torch.Tensor:
-        if order is LayerOrder.PreNorm:
+        if self._order is LayerOrder.PreNorm:
             x = self.norm1(input)
             x = x + self.drop1(self.mha(x))
             x = self.norm2(x)
             x = x + self.drop2(self.mlp(x))
-        elif order is LayerOrder.PostNoem:
+        elif self._order is LayerOrder.PostNorm:
             # multi-head attention
-            x = x + self.drop1(self.mha(input))
+            x = self.mha(input)
+            x = x + self.drop1(x)
             x = self.norm1(x)
             # feed-forward layer
             x = x + self.drop2(self.mlp(x))
             x = self.norm2(x)
         else:
-            raise ValueError("Unknown order: {:}".format(order))
+            raise ValueError("Unknown order: {:}".format(self._order))
         return x
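To compare the two branches of forward_raw outside the SuperModule machinery, here is a standalone sketch (not part of the commit) that mirrors the same PreNorm/PostNorm ordering with plain torch.nn modules; the class name, dimensions, and MLP shape are illustrative assumptions.

# Standalone sketch mirroring the two forward_raw branches after this commit.
import torch
import torch.nn as nn


class EncoderLayerSketch(nn.Module):
    def __init__(self, d_model=256, num_heads=4, pre_norm=True, drop=0.0):
        super().__init__()
        self.pre_norm = pre_norm
        self.mha = nn.MultiheadAttention(d_model, num_heads, batch_first=True)
        self.mlp = nn.Sequential(
            nn.Linear(d_model, 4 * d_model), nn.GELU(), nn.Linear(4 * d_model, d_model)
        )
        self.norm1, self.norm2 = nn.LayerNorm(d_model), nn.LayerNorm(d_model)
        self.drop1, self.drop2 = nn.Dropout(drop), nn.Dropout(drop)

    def _attn(self, x):
        # self-attention over the sequence, discarding the attention weights
        return self.mha(x, x, x, need_weights=False)[0]

    def forward(self, inputs):
        if self.pre_norm:
            # PreNorm branch: normalize first, then attention / MLP with residuals
            x = self.norm1(inputs)
            x = x + self.drop1(self._attn(x))
            x = self.norm2(x)
            x = x + self.drop2(self.mlp(x))
        else:
            # PostNorm branch: attention / MLP with residuals, then normalize
            x = self._attn(inputs)
            x = x + self.drop1(x)
            x = self.norm1(x)
            x = x + self.drop2(self.mlp(x))
            x = self.norm2(x)
        return x


tokens = torch.rand(2, 32, 256)  # (batch, sequence, d_model)
print(EncoderLayerSketch(pre_norm=True)(tokens).shape)   # torch.Size([2, 32, 256])
print(EncoderLayerSketch(pre_norm=False)(tokens).shape)  # torch.Size([2, 32, 256])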