commit 15dda79e3b
parent 379b904203
Author: D-X-Y
Date:   2021-03-24 05:33:52 -07:00

6 changed files with 60 additions and 58 deletions

[File 1 of 6; path not shown]

@@ -3,6 +3,7 @@
 #####################################################
 from .super_module import SuperRunMode
 from .super_module import IntSpaceType
+from .super_module import LayerOrder
 from .super_module import SuperModule
 from .super_container import SuperSequential
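For context, a minimal sketch of what the newly exported LayerOrder likely is. The member names PreNorm and PostNorm are taken from the diffs below; the Enum base class and the member values are assumptions (the code compares members with "is", which suggests an enum or similar singletons).

    from enum import Enum

    class LayerOrder(Enum):
        """Order of layer normalization relative to each sublayer (assumed definition)."""
        PreNorm = "pre-norm"    # norm -> sublayer -> residual add
        PostNorm = "post-norm"  # sublayer -> residual add -> norm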

[File 2 of 6; path not shown]

@@ -37,8 +37,7 @@ class SuperTransformerEncoderLayer(SuperModule):
     def __init__(
         self,
-        input_dim: IntSpaceType,
-        output_dim: IntSpaceType,
+        d_model: IntSpaceType,
         num_heads: IntSpaceType,
         qkv_bias: BoolSpaceType = False,
         mlp_hidden_multiplier: IntSpaceType = 4,
@@ -48,40 +47,37 @@ class SuperTransformerEncoderLayer(SuperModule):
     ):
         super(SuperTransformerEncoderLayer, self).__init__()
         mha = SuperAttention(
-            input_dim,
-            input_dim,
+            d_model,
+            d_model,
             num_heads=num_heads,
             qkv_bias=qkv_bias,
             attn_drop=drop,
             proj_drop=drop,
         )
-        drop1 = nn.Dropout(drop or 0.0)
-        norm1 = SuperLayerNorm1D(input_dim)
         mlp = SuperMLPv2(
-            input_dim,
+            d_model,
             hidden_multiplier=mlp_hidden_multiplier,
-            out_features=output_dim,
+            out_features=d_model,
             act_layer=act_layer,
             drop=drop,
         )
-        drop2 = nn.Dropout(drop or 0.0)
-        norm2 = SuperLayerNorm1D(output_dim)
         if order is LayerOrder.PreNorm:
-            self.norm1 = norm1
+            self.norm1 = SuperLayerNorm1D(d_model)
             self.mha = mha
-            self.drop1 = drop1
-            self.norm2 = norm2
+            self.drop1 = nn.Dropout(drop or 0.0)
+            self.norm2 = SuperLayerNorm1D(d_model)
             self.mlp = mlp
-            self.drop2 = drop2
-        elif order is LayerOrder.PostNoem:
+            self.drop2 = nn.Dropout(drop or 0.0)
+        elif order is LayerOrder.PostNorm:
             self.mha = mha
-            self.drop1 = drop1
-            self.norm1 = norm1
+            self.drop1 = nn.Dropout(drop or 0.0)
+            self.norm1 = SuperLayerNorm1D(d_model)
             self.mlp = mlp
-            self.drop2 = drop2
-            self.norm2 = norm2
+            self.drop2 = nn.Dropout(drop or 0.0)
+            self.norm2 = SuperLayerNorm1D(d_model)
         else:
             raise ValueError("Unknown order: {:}".format(order))
         self._order = order

     @property
     def abstract_search_space(self):
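A hedged usage sketch of the refactored constructor: after this change the layer takes a single d_model instead of separate input_dim/output_dim. The import path is an assumption, as are the drop and order keyword arguments (both are used in the __init__ body but their declarations fall in the elided part of the signature); the remaining parameter names come from the diff.

    import torch

    # assumed import path within this repository
    from xlayers.super_core import SuperTransformerEncoderLayer, LayerOrder

    layer = SuperTransformerEncoderLayer(
        d_model=64,                # single embedding width, replacing input_dim/output_dim
        num_heads=4,
        qkv_bias=True,
        mlp_hidden_multiplier=4,
        order=LayerOrder.PreNorm,  # or LayerOrder.PostNorm
    )
    # assuming the module is callable like a standard nn.Module:
    out = layer(torch.rand(2, 16, 64))  # (batch, seq, d_model) -> same shape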
@@ -108,18 +104,19 @@ class SuperTransformerEncoderLayer(SuperModule):
         return self.forward_raw(input)

     def forward_raw(self, input: torch.Tensor) -> torch.Tensor:
-        if order is LayerOrder.PreNorm:
+        if self._order is LayerOrder.PreNorm:
             x = self.norm1(input)
             x = x + self.drop1(self.mha(x))
             x = self.norm2(x)
             x = x + self.drop2(self.mlp(x))
-        elif order is LayerOrder.PostNoem:
+        elif self._order is LayerOrder.PostNorm:
             # multi-head attention
-            x = x + self.drop1(self.mha(input))
+            x = self.mha(input)
+            x = x + self.drop1(x)
             x = self.norm1(x)
             # feed-forward layer
             x = x + self.drop2(self.mlp(x))
             x = self.norm2(x)
         else:
-            raise ValueError("Unknown order: {:}".format(order))
+            raise ValueError("Unknown order: {:}".format(self._order))
         return x
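For clarity, a plain-PyTorch restatement of the fixed forward_raw, with the submodules passed in explicitly. The function names are illustrative; the bodies mirror the two branches in the diff above, including the PostNorm branch's residual add over the attention output.

    def pre_norm_forward(x, norm1, mha, drop1, norm2, mlp, drop2):
        x = norm1(x)
        x = x + drop1(mha(x))   # attention sublayer with residual add
        x = norm2(x)
        x = x + drop2(mlp(x))   # feed-forward sublayer with residual add
        return x

    def post_norm_forward(inp, norm1, mha, drop1, norm2, mlp, drop2):
        x = mha(inp)            # multi-head attention
        x = x + drop1(x)        # residual add over the attention output, as in the diff
        x = norm1(x)
        x = x + drop2(mlp(x))   # feed-forward sublayer with residual add
        x = norm2(x)
        return x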