Update SuperViT

D-X-Y
2021-06-09 05:39:35 -07:00
parent 0ddc5c0dc4
commit d4546cfe3f
4 changed files with 119 additions and 69 deletions


@@ -13,6 +13,7 @@ from xautodl import spaces
 from .super_module import SuperModule
 from .super_module import IntSpaceType
 from .super_module import BoolSpaceType
+from .super_dropout import SuperDropout, SuperDrop
 from .super_linear import SuperLinear
@@ -22,7 +23,7 @@ class SuperSelfAttention(SuperModule):
     def __init__(
         self,
         input_dim: IntSpaceType,
-        proj_dim: IntSpaceType,
+        proj_dim: Optional[IntSpaceType],
         num_heads: IntSpaceType,
         qkv_bias: BoolSpaceType = False,
         attn_drop: Optional[float] = None,
@@ -37,13 +38,17 @@ class SuperSelfAttention(SuperModule):
         self._use_mask = use_mask
         self._infinity = 1e9
 
-        self.q_fc = SuperLinear(input_dim, input_dim, bias=qkv_bias)
-        self.k_fc = SuperLinear(input_dim, input_dim, bias=qkv_bias)
-        self.v_fc = SuperLinear(input_dim, input_dim, bias=qkv_bias)
+        mul_head_dim = (input_dim // num_heads) * num_heads
+        self.q_fc = SuperLinear(input_dim, mul_head_dim, bias=qkv_bias)
+        self.k_fc = SuperLinear(input_dim, mul_head_dim, bias=qkv_bias)
+        self.v_fc = SuperLinear(input_dim, mul_head_dim, bias=qkv_bias)
 
-        self.attn_drop = nn.Dropout(attn_drop or 0.0)
-        self.proj = SuperLinear(input_dim, proj_dim)
-        self.proj_drop = nn.Dropout(proj_drop or 0.0)
+        self.attn_drop = SuperDrop(attn_drop, [-1, -1, -1, -1], recover=True)
+        if proj_dim is not None:
+            self.proj = SuperLinear(input_dim, proj_dim)
+            self.proj_drop = SuperDropout(proj_drop or 0.0)
+        else:
+            self.proj = None
 
     @property
     def num_heads(self):
@@ -63,7 +68,6 @@ class SuperSelfAttention(SuperModule):
         space_q = self.q_fc.abstract_search_space
         space_k = self.k_fc.abstract_search_space
         space_v = self.v_fc.abstract_search_space
-        space_proj = self.proj.abstract_search_space
         if not spaces.is_determined(self._num_heads):
             root_node.append("_num_heads", self._num_heads.abstract(reuse_last=True))
         if not spaces.is_determined(space_q):
@@ -72,8 +76,10 @@ class SuperSelfAttention(SuperModule):
             root_node.append("k_fc", space_k)
         if not spaces.is_determined(space_v):
             root_node.append("v_fc", space_v)
-        if not spaces.is_determined(space_proj):
-            root_node.append("proj", space_proj)
+        if self.proj is not None:
+            space_proj = self.proj.abstract_search_space
+            if not spaces.is_determined(space_proj):
+                root_node.append("proj", space_proj)
         return root_node
 
     def apply_candidate(self, abstract_child: spaces.VirtualNode):
@@ -121,18 +127,7 @@ class SuperSelfAttention(SuperModule):
         attn_v1 = attn_v1.softmax(dim=-1)  # B * #head * N * N
         attn_v1 = self.attn_drop(attn_v1)
         feats_v1 = (attn_v1 @ v_v1).permute(0, 2, 1, 3).reshape(B, N, -1)
-        if C == head_dim * num_head:
-            feats = feats_v1
-        else:  # The channels can not be divided by num_head, the remainder forms an additional head
-            q_v2 = q[:, :, num_head * head_dim :]
-            k_v2 = k[:, :, num_head * head_dim :]
-            v_v2 = v[:, :, num_head * head_dim :]
-            attn_v2 = (q_v2 @ k_v2.transpose(-2, -1)) * math.sqrt(q_v2.shape[-1])
-            attn_v2 = attn_v2.softmax(dim=-1)
-            attn_v2 = self.attn_drop(attn_v2)
-            feats_v2 = attn_v2 @ v_v2
-            feats = torch.cat([feats_v1, feats_v2], dim=-1)
-        return feats
+        return feats_v1
 
     def forward_candidate(self, input: torch.Tensor) -> torch.Tensor:
         # check the num_heads:
@@ -141,15 +136,21 @@ class SuperSelfAttention(SuperModule):
         else:
             num_heads = spaces.get_determined_value(self._num_heads)
         feats = self.forward_qkv(input, num_heads)
-        outs = self.proj(feats)
-        outs = self.proj_drop(outs)
-        return outs
+        if self.proj is None:
+            return feats
+        else:
+            outs = self.proj(feats)
+            outs = self.proj_drop(outs)
+            return outs
 
     def forward_raw(self, input: torch.Tensor) -> torch.Tensor:
         feats = self.forward_qkv(input, self.num_heads)
-        outs = self.proj(feats)
-        outs = self.proj_drop(outs)
-        return outs
+        if self.proj is None:
+            return feats
+        else:
+            outs = self.proj(feats)
+            outs = self.proj_drop(outs)
+            return outs
 
     def extra_repr(self) -> str:
         return (
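
For reference, below is a minimal standalone sketch in plain PyTorch of the two behaviours the diff introduces: q/k/v are projected to mul_head_dim, the largest multiple of num_heads not exceeding input_dim, so the old remainder-head branch is no longer needed, and the output projection plus its dropout are skipped when proj_dim is None. The class and variable names are hypothetical and simplified; this is not part of the commit or of the xautodl API.

# Illustrative sketch only -- simplified from the diff above; not the xautodl classes.
import math
from typing import Optional

import torch
import torch.nn as nn


class TinySelfAttention(nn.Module):  # hypothetical name, for illustration
    def __init__(self, input_dim: int, proj_dim: Optional[int], num_heads: int):
        super().__init__()
        self.num_heads = num_heads
        # Project q/k/v to the largest multiple of num_heads that fits in input_dim,
        # so every head gets the same head_dim and no remainder head is needed.
        mul_head_dim = (input_dim // num_heads) * num_heads
        self.q_fc = nn.Linear(input_dim, mul_head_dim)
        self.k_fc = nn.Linear(input_dim, mul_head_dim)
        self.v_fc = nn.Linear(input_dim, mul_head_dim)
        # Optional output projection (simplified here to map from mul_head_dim).
        self.proj = None if proj_dim is None else nn.Linear(mul_head_dim, proj_dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        B, N, _ = x.shape
        head_dim = self.q_fc.out_features // self.num_heads
        # (B, N, mul_head_dim) -> (B, num_heads, N, head_dim)
        q = self.q_fc(x).reshape(B, N, self.num_heads, head_dim).permute(0, 2, 1, 3)
        k = self.k_fc(x).reshape(B, N, self.num_heads, head_dim).permute(0, 2, 1, 3)
        v = self.v_fc(x).reshape(B, N, self.num_heads, head_dim).permute(0, 2, 1, 3)
        attn = (q @ k.transpose(-2, -1)) / math.sqrt(head_dim)
        feats = (attn.softmax(dim=-1) @ v).permute(0, 2, 1, 3).reshape(B, N, -1)
        # With no projection, return the raw multi-head features directly.
        return feats if self.proj is None else self.proj(feats)


# input_dim=30 is not divisible by num_heads=4, so features are truncated to 28 channels.
out = TinySelfAttention(input_dim=30, proj_dim=None, num_heads=4)(torch.randn(2, 7, 30))
print(out.shape)  # torch.Size([2, 7, 28])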