Update baseline.sh

This commit is contained in:
D-X-Y
2021-04-02 00:40:26 -07:00
parent 0e35d8b156
commit 1dd665ae06
2 changed files with 3 additions and 4 deletions

View File

@@ -113,7 +113,7 @@ class SuperAttention(SuperModule):
.permute(0, 2, 1, 3)
)
attn_v1 = (q_v1 @ k_v1.transpose(-2, -1)) * math.sqrt(head_dim)
attn_v1 = attn_v1.softmax(dim=-1)
attn_v1 = attn_v1.softmax(dim=-1) # B * #head * N * N
attn_v1 = self.attn_drop(attn_v1)
feats_v1 = (attn_v1 @ v_v1).permute(0, 2, 1, 3).reshape(B, N, -1)
if C == head_dim * num_head: