Fix bugs in ViT

This commit is contained in:
D-X-Y
2021-06-09 23:08:21 +08:00
parent d4546cfe3f
commit aef5c7579b
4 changed files with 594 additions and 6 deletions

View File

@@ -38,12 +38,15 @@ class SuperSelfAttention(SuperModule):
self._use_mask = use_mask
self._infinity = 1e9
mul_head_dim = (input_dim // num_heads) * num_heads
self.q_fc = SuperLinear(input_dim, mul_head_dim, bias=qkv_bias)
self.k_fc = SuperLinear(input_dim, mul_head_dim, bias=qkv_bias)
self.v_fc = SuperLinear(input_dim, mul_head_dim, bias=qkv_bias)
mul_head_dim = (
spaces.get_max(input_dim) // spaces.get_min(num_heads)
) * spaces.get_min(num_heads)
assert mul_head_dim == spaces.get_max(input_dim)
self.q_fc = SuperLinear(input_dim, input_dim, bias=qkv_bias)
self.k_fc = SuperLinear(input_dim, input_dim, bias=qkv_bias)
self.v_fc = SuperLinear(input_dim, input_dim, bias=qkv_bias)
self.attn_drop = SuperDrop(attn_drop, [-1, -1, -1, -1], recover=True)
self.attn_drop = SuperDrop(attn_drop or 0.0, [-1, -1, -1, -1], recover=True)
if proj_dim is None:
self.proj = SuperLinear(input_dim, proj_dim)
self.proj_drop = SuperDropout(proj_drop or 0.0)
@@ -127,7 +130,18 @@ class SuperSelfAttention(SuperModule):
attn_v1 = attn_v1.softmax(dim=-1) # B * #head * N * N
attn_v1 = self.attn_drop(attn_v1)
feats_v1 = (attn_v1 @ v_v1).permute(0, 2, 1, 3).reshape(B, N, -1)
return feats_v1
if C == head_dim * num_head:
feats = feats_v1
else: # The channels can not be divided by num_head, the remainder forms an additional head
q_v2 = q[:, :, num_head * head_dim :]
k_v2 = k[:, :, num_head * head_dim :]
v_v2 = v[:, :, num_head * head_dim :]
attn_v2 = (q_v2 @ k_v2.transpose(-2, -1)) * math.sqrt(q_v2.shape[-1])
attn_v2 = attn_v2.softmax(dim=-1)
attn_v2 = self.attn_drop(attn_v2)
feats_v2 = attn_v2 @ v_v2
feats = torch.cat([feats_v1, feats_v2], dim=-1)
return feats
def forward_candidate(self, input: torch.Tensor) -> torch.Tensor:
# check the num_heads:

View File

@@ -5,3 +5,6 @@
#####################################################
from .transformers import get_transformer
def obtain_model(config):
raise NotImplementedError