Add SuperTransformerEncoder

This commit is contained in:
D-X-Y
2021-03-20 22:28:23 +08:00
parent e023a53c75
commit 32900797eb
11 changed files with 524 additions and 125 deletions

View File

@@ -29,8 +29,8 @@ class SuperAttention(SuperModule):
proj_dim: IntSpaceType,
num_heads: IntSpaceType,
qkv_bias: BoolSpaceType = False,
attn_drop: float = 0.0,
proj_drop: float = 0.0,
attn_drop: Optional[float] = None,
proj_drop: Optional[float] = None,
):
super(SuperAttention, self).__init__()
self._input_dim = input_dim
@@ -45,9 +45,9 @@ class SuperAttention(SuperModule):
self.k_fc = SuperLinear(input_dim, input_dim, bias=qkv_bias)
self.v_fc = SuperLinear(input_dim, input_dim, bias=qkv_bias)
self.attn_drop = nn.Dropout(attn_drop)
self.attn_drop = nn.Dropout(attn_drop or 0.0)
self.proj = SuperLinear(input_dim, proj_dim)
self.proj_drop = nn.Dropout(proj_drop)
self.proj_drop = nn.Dropout(proj_drop or 0.0)
@property
def num_heads(self):