Update SuperViT

Author: D-X-Y
Date:   2021-06-09 05:39:35 -07:00
Parent: 0ddc5c0dc4
Commit: d4546cfe3f

4 changed files with 119 additions and 69 deletions

@@ -37,7 +37,8 @@ class SuperTransformerEncoderLayer(SuperModule):
         num_heads: IntSpaceType,
         qkv_bias: BoolSpaceType = False,
         mlp_hidden_multiplier: IntSpaceType = 4,
-        drop: Optional[float] = None,
+        dropout: Optional[float] = None,
+        att_dropout: Optional[float] = None,
         norm_affine: bool = True,
         act_layer: Callable[[], nn.Module] = nn.GELU,
         order: LayerOrder = LayerOrder.PreNorm,
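The former single `drop` argument is split in two: `dropout` feeds the MLP and the residual-path nn.Dropout modules, while `att_dropout` is forwarded to the attention weights. A hypothetical construction sketch, assuming `d_model` is the leading positional argument (it is used in the hunks below but declared outside this hunk) and an illustrative import path:

# Hypothetical import path and values; only the keyword names come from this diff.
from xlayers.super_transformer import SuperTransformerEncoderLayer

layer = SuperTransformerEncoderLayer(
    256,                      # d_model, assumed leading positional argument
    num_heads=8,
    qkv_bias=True,
    mlp_hidden_multiplier=4,
    dropout=0.1,              # MLP and residual-path dropout (was `drop`)
    att_dropout=0.1,          # attention-weight dropout (new)
)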
@@ -49,8 +50,8 @@ class SuperTransformerEncoderLayer(SuperModule):
             d_model,
             num_heads=num_heads,
             qkv_bias=qkv_bias,
-            attn_drop=drop,
-            proj_drop=drop,
+            attn_drop=att_dropout,
+            proj_drop=None,
             use_mask=use_mask,
         )
         mlp = SuperMLPv2(
@@ -58,21 +59,20 @@ class SuperTransformerEncoderLayer(SuperModule):
             hidden_multiplier=mlp_hidden_multiplier,
             out_features=d_model,
             act_layer=act_layer,
-            drop=drop,
+            drop=dropout,
         )
         if order is LayerOrder.PreNorm:
             self.norm1 = SuperLayerNorm1D(d_model, elementwise_affine=norm_affine)
             self.mha = mha
-            self.drop1 = nn.Dropout(drop or 0.0)
+            self.drop = nn.Dropout(dropout or 0.0)
             self.norm2 = SuperLayerNorm1D(d_model, elementwise_affine=norm_affine)
             self.mlp = mlp
-            self.drop2 = nn.Dropout(drop or 0.0)
         elif order is LayerOrder.PostNorm:
             self.mha = mha
-            self.drop1 = nn.Dropout(drop or 0.0)
+            self.drop1 = nn.Dropout(dropout or 0.0)
             self.norm1 = SuperLayerNorm1D(d_model, elementwise_affine=norm_affine)
             self.mlp = mlp
-            self.drop2 = nn.Dropout(drop or 0.0)
+            self.drop2 = nn.Dropout(dropout or 0.0)
             self.norm2 = SuperLayerNorm1D(d_model, elementwise_affine=norm_affine)
         else:
             raise ValueError("Unknown order: {:}".format(order))
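Note the asymmetry the constructor now encodes: the PreNorm branch keeps a single `self.drop` after attention (the MLP residual is left undropped), while PostNorm retains both `drop1` and `drop2`. The `dropout or 0.0` idiom makes an unset rate degrade to an identity module; a minimal self-contained sketch of that idiom:

import torch.nn as nn

def make_dropout(rate):
    # `rate or 0.0` maps None to 0.0, so the module becomes a no-op
    # when no dropout rate is configured.
    return nn.Dropout(rate or 0.0)

assert make_dropout(None).p == 0.0
assert make_dropout(0.1).p == 0.1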
@@ -99,23 +99,29 @@ class SuperTransformerEncoderLayer(SuperModule):
             if key in abstract_child:
                 getattr(self, key).apply_candidate(abstract_child[key])
 
-    def forward_candidate(self, input: torch.Tensor) -> torch.Tensor:
-        return self.forward_raw(input)
+    def forward_candidate(self, inputs: torch.Tensor) -> torch.Tensor:
+        return self.forward_raw(inputs)
 
-    def forward_raw(self, input: torch.Tensor) -> torch.Tensor:
+    def forward_raw(self, inputs: torch.Tensor) -> torch.Tensor:
         if self._order is LayerOrder.PreNorm:
-            x = self.norm1(input)
-            x = x + self.drop1(self.mha(x))
-            x = self.norm2(x)
-            x = x + self.drop2(self.mlp(x))
+            # https://github.com/google-research/vision_transformer/blob/master/vit_jax/models.py#L135
+            x = self.norm1(inputs)
+            x = self.mha(x)
+            x = self.drop(x)
+            x = x + inputs
+            # feed-forward layer -- MLP
+            y = self.norm2(x)
+            outs = x + self.mlp(y)
         elif self._order is LayerOrder.PostNorm:
+            # https://pytorch.org/docs/stable/_modules/torch/nn/modules/transformer.html#TransformerEncoder
             # multi-head attention
-            x = self.mha(input)
-            x = x + self.drop1(x)
+            x = self.mha(inputs)
+            x = inputs + self.drop1(x)
             x = self.norm1(x)
-            # feed-forward layer
-            x = x + self.drop2(self.mlp(x))
-            x = self.norm2(x)
+            # feed-forward layer -- MLP
+            y = self.mlp(x)
+            y = x + self.drop2(y)
+            outs = self.norm2(y)
         else:
             raise ValueError("Unknown order: {:}".format(self._order))
-        return x
+        return outs
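For reference, a self-contained sketch of the two residual orderings that the rewritten forward_raw implements, with stock nn.MultiheadAttention and nn.Linear standing in for the Super* search-space modules (class name and all sizes illustrative only):

import torch
import torch.nn as nn

class TinyEncoderLayer(nn.Module):
    def __init__(self, d_model=64, num_heads=4, dropout=0.1, pre_norm=True):
        super().__init__()
        self.pre_norm = pre_norm
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.mha = nn.MultiheadAttention(d_model, num_heads, batch_first=True)
        self.mlp = nn.Sequential(
            nn.Linear(d_model, 4 * d_model), nn.GELU(), nn.Linear(4 * d_model, d_model)
        )
        self.drop1 = nn.Dropout(dropout)
        self.drop2 = nn.Dropout(dropout)

    def forward(self, inputs):
        if self.pre_norm:
            # ViT style: normalize first, then residual-add the sub-layer output;
            # as in the diff, only the attention path is dropped.
            x = self.norm1(inputs)
            x, _ = self.mha(x, x, x)
            x = inputs + self.drop1(x)
            return x + self.mlp(self.norm2(x))
        # Post-norm (original Transformer): residual-add first, then normalize.
        y, _ = self.mha(inputs, inputs, inputs)
        y = self.norm1(inputs + self.drop1(y))
        return self.norm2(y + self.drop2(self.mlp(y)))

tokens = torch.randn(2, 16, 64)  # (batch, sequence, d_model)
print(TinyEncoderLayer(pre_norm=True)(tokens).shape)   # torch.Size([2, 16, 64])
print(TinyEncoderLayer(pre_norm=False)(tokens).shape)  # torch.Size([2, 16, 64])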