Update SuperViT
@@ -37,7 +37,8 @@ class SuperTransformerEncoderLayer(SuperModule):
         num_heads: IntSpaceType,
         qkv_bias: BoolSpaceType = False,
         mlp_hidden_multiplier: IntSpaceType = 4,
-        drop: Optional[float] = None,
+        dropout: Optional[float] = None,
+        att_dropout: Optional[float] = None,
         norm_affine: bool = True,
         act_layer: Callable[[], nn.Module] = nn.GELU,
         order: LayerOrder = LayerOrder.PreNorm,
@@ -49,8 +50,8 @@ class SuperTransformerEncoderLayer(SuperModule):
             d_model,
             num_heads=num_heads,
             qkv_bias=qkv_bias,
-            attn_drop=drop,
-            proj_drop=drop,
+            attn_drop=att_dropout,
+            proj_drop=None,
             use_mask=use_mask,
         )
         mlp = SuperMLPv2(
@@ -58,21 +59,20 @@ class SuperTransformerEncoderLayer(SuperModule):
             hidden_multiplier=mlp_hidden_multiplier,
             out_features=d_model,
             act_layer=act_layer,
-            drop=drop,
+            drop=dropout,
         )
         if order is LayerOrder.PreNorm:
             self.norm1 = SuperLayerNorm1D(d_model, elementwise_affine=norm_affine)
             self.mha = mha
-            self.drop1 = nn.Dropout(drop or 0.0)
+            self.drop = nn.Dropout(dropout or 0.0)
             self.norm2 = SuperLayerNorm1D(d_model, elementwise_affine=norm_affine)
             self.mlp = mlp
-            self.drop2 = nn.Dropout(drop or 0.0)
         elif order is LayerOrder.PostNorm:
             self.mha = mha
-            self.drop1 = nn.Dropout(drop or 0.0)
+            self.drop1 = nn.Dropout(dropout or 0.0)
             self.norm1 = SuperLayerNorm1D(d_model, elementwise_affine=norm_affine)
             self.mlp = mlp
-            self.drop2 = nn.Dropout(drop or 0.0)
+            self.drop2 = nn.Dropout(dropout or 0.0)
             self.norm2 = SuperLayerNorm1D(d_model, elementwise_affine=norm_affine)
         else:
             raise ValueError("Unknown order: {:}".format(order))
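Side note (commentary, not part of the commit): the old single drop argument is split above into dropout and att_dropout, one rate for the attention computation and one for the sub-layer outputs and the MLP. A minimal sketch of that wiring using stock torch.nn modules in place of the Super* classes; SketchEncoderLayer is an illustrative name, nn.MultiheadAttention and nn.LayerNorm stand in for SuperSelfAttention and SuperLayerNorm1D, and the 4x hidden width mirrors the default mlp_hidden_multiplier=4:

import torch.nn as nn


class SketchEncoderLayer(nn.Module):
    """Illustrative only: shows where `dropout` vs `att_dropout` plug in."""

    def __init__(self, d_model: int, num_heads: int,
                 dropout: float = 0.0, att_dropout: float = 0.0):
        super().__init__()
        # att_dropout: dropout applied inside the attention computation
        self.mha = nn.MultiheadAttention(
            d_model, num_heads, dropout=att_dropout, batch_first=True
        )
        # dropout: applied to sub-layer outputs before the residual adds
        self.drop = nn.Dropout(dropout)
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.mlp = nn.Sequential(
            nn.Linear(d_model, 4 * d_model),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(4 * d_model, d_model),
        )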
@@ -99,23 +99,29 @@ class SuperTransformerEncoderLayer(SuperModule):
             if key in abstract_child:
                 getattr(self, key).apply_candidate(abstract_child[key])

-    def forward_candidate(self, input: torch.Tensor) -> torch.Tensor:
-        return self.forward_raw(input)
+    def forward_candidate(self, inputs: torch.Tensor) -> torch.Tensor:
+        return self.forward_raw(inputs)

-    def forward_raw(self, input: torch.Tensor) -> torch.Tensor:
+    def forward_raw(self, inputs: torch.Tensor) -> torch.Tensor:
         if self._order is LayerOrder.PreNorm:
-            x = self.norm1(input)
-            x = x + self.drop1(self.mha(x))
-            x = self.norm2(x)
-            x = x + self.drop2(self.mlp(x))
+            # https://github.com/google-research/vision_transformer/blob/master/vit_jax/models.py#L135
+            x = self.norm1(inputs)
+            x = self.mha(x)
+            x = self.drop(x)
+            x = x + inputs
+            # feed-forward layer -- MLP
+            y = self.norm2(x)
+            outs = x + self.mlp(y)
         elif self._order is LayerOrder.PostNorm:
             # https://pytorch.org/docs/stable/_modules/torch/nn/modules/transformer.html#TransformerEncoder
-            x = self.mha(input)
-            x = x + self.drop1(x)
+            # multi-head attention
+            x = self.mha(inputs)
+            x = inputs + self.drop1(x)
             x = self.norm1(x)
-            # feed-forward layer
-            x = x + self.drop2(self.mlp(x))
-            x = self.norm2(x)
+            # feed-forward layer -- MLP
+            y = self.mlp(x)
+            y = x + self.drop2(y)
+            outs = self.norm2(y)
         else:
             raise ValueError("Unknown order: {:}".format(self._order))
-        return x
+        return outs
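Note (commentary, not part of the commit): after this change the pre-norm branch follows the ordering of the linked vision_transformer reference, normalize, attend, drop, add the residual, then normalize, MLP, add the residual, with no extra dropout after the MLP, while the post-norm branch keeps the torch.nn.TransformerEncoderLayer-style ordering. A rough sketch of the two data flows; pre_norm_forward and post_norm_forward are illustrative names, and the attention/MLP sub-modules are passed in as plain callables:

import torch
import torch.nn as nn


def pre_norm_forward(x, mha, mlp, norm1, norm2, drop):
    # normalize -> attend -> dropout -> residual add (ViT-style pre-norm)
    y = x + drop(mha(norm1(x)))
    # normalize -> MLP -> residual add (no dropout on the MLP output here)
    return y + mlp(norm2(y))


def post_norm_forward(x, mha, mlp, norm1, norm2, drop1, drop2):
    # attend -> dropout -> residual add -> normalize (post-norm)
    x = norm1(x + drop1(mha(x)))
    # MLP -> dropout -> residual add -> normalize
    return norm2(x + drop2(mlp(x)))


if __name__ == "__main__":
    d_model, x = 8, torch.randn(2, 5, 8)
    identity = lambda t: t  # stand-ins for the attention and MLP sub-modules
    out = pre_norm_forward(x, identity, identity,
                           nn.LayerNorm(d_model), nn.LayerNorm(d_model),
                           nn.Dropout(0.1))
    print(out.shape)  # torch.Size([2, 5, 8])

The practical difference between the two branches is where the LayerNorm sits relative to the residual connection; the pre-norm variant is the one the commit aligns with the ViT reference implementation.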