Fix bugs in xlayers

D-X-Y
2021-05-22 16:41:54 +08:00
parent 97717d826e
commit bc42ab3c08
7 changed files with 197 additions and 39 deletions

View File

@@ -20,7 +20,7 @@ from .super_module import BoolSpaceType
from .super_linear import SuperLinear
-class SuperAttention(SuperModule):
+class SuperSelfAttention(SuperModule):
"""The super model for attention layer."""
def __init__(
@@ -32,7 +32,7 @@ class SuperAttention(SuperModule):
attn_drop: Optional[float] = None,
proj_drop: Optional[float] = None,
):
-        super(SuperAttention, self).__init__()
+        super(SuperSelfAttention, self).__init__()
self._input_dim = input_dim
self._proj_dim = proj_dim
self._num_heads = num_heads
@@ -150,3 +150,157 @@ class SuperAttention(SuperModule):
return "input_dim={:}, proj_dim={:}, num_heads={:}".format(
self._input_dim, self._proj_dim, self._num_heads
)
class SuperQKVAttention(SuperModule):
"""The super model for attention layer."""
def __init__(
self,
in_q_dim: IntSpaceType,
in_k_dim: IntSpaceType,
in_v_dim: IntSpaceType,
proj_dim: IntSpaceType,
num_heads: IntSpaceType,
qkv_bias: BoolSpaceType = False,
attn_drop: Optional[float] = None,
proj_drop: Optional[float] = None,
):
super(SuperQKVAttention, self).__init__()
self._in_v_dim = in_v_dim
self._in_q_dim = in_q_dim
self._in_k_dim = in_k_dim
self._proj_dim = proj_dim
self._num_heads = num_heads
self._qkv_bias = qkv_bias
self.q_fc = SuperLinear(in_q_dim, proj_dim, bias=qkv_bias)
self.k_fc = SuperLinear(in_k_dim, proj_dim, bias=qkv_bias)
self.v_fc = SuperLinear(in_v_dim, proj_dim, bias=qkv_bias)
self.attn_drop = nn.Dropout(attn_drop or 0.0)
self.proj = SuperLinear(proj_dim, proj_dim)
self.proj_drop = nn.Dropout(proj_drop or 0.0)
@property
def num_heads(self):
return spaces.get_max(self._num_heads)
@property
def in_v_dim(self):
return spaces.get_max(self._in_v_dim)
@property
def in_q_dim(self):
return spaces.get_max(self._in_q_dim)
@property
def in_k_dim(self):
return spaces.get_max(self._in_k_dim)
@property
def proj_dim(self):
return spaces.get_max(self._proj_dim)
@property
def abstract_search_space(self):
root_node = spaces.VirtualNode(id(self))
space_q = self.q_fc.abstract_search_space
space_k = self.k_fc.abstract_search_space
space_v = self.v_fc.abstract_search_space
space_proj = self.proj.abstract_search_space
if not spaces.is_determined(self._num_heads):
root_node.append("_num_heads", self._num_heads.abstract(reuse_last=True))
if not spaces.is_determined(space_q):
root_node.append("q_fc", space_q)
if not spaces.is_determined(space_k):
root_node.append("k_fc", space_k)
if not spaces.is_determined(space_v):
root_node.append("v_fc", space_v)
if not spaces.is_determined(space_proj):
root_node.append("proj", space_proj)
return root_node
def apply_candidate(self, abstract_child: spaces.VirtualNode):
        super(SuperQKVAttention, self).apply_candidate(abstract_child)
if "q_fc" in abstract_child:
self.q_fc.apply_candidate(abstract_child["q_fc"])
if "k_fc" in abstract_child:
self.k_fc.apply_candidate(abstract_child["k_fc"])
if "v_fc" in abstract_child:
self.v_fc.apply_candidate(abstract_child["v_fc"])
if "proj" in abstract_child:
self.proj.apply_candidate(abstract_child["proj"])
def forward_qkv(self, q_tensor, k_tensor, v_tensor, num_head: int) -> torch.Tensor:
q = self.q_fc(q_tensor)
B, N, C = q.shape
k = self.k_fc(k_tensor)
B0, S, _ = k.shape
v = self.v_fc(v_tensor)
assert B0 == v.shape[0] and S == v.shape[1]
head_dim = C // num_head
if num_head > C:
raise ValueError("Invalid num_head [{:}] vs C [{:}]".format(num_head, C))
q_v1 = (
q[:, :, : num_head * head_dim]
.reshape(B, N, num_head, head_dim)
.permute(0, 2, 1, 3)
)
k_v1 = (
k[:, :, : num_head * head_dim]
.reshape(B0, S, num_head, head_dim)
.permute(0, 2, 1, 3)
)
        # compute the attention map (scaled dot-product: divide the logits by sqrt(head_dim))
        attn_v1 = (q_v1 @ k_v1.transpose(-2, -1)) / math.sqrt(head_dim)
attn_v1 = attn_v1.softmax(dim=-1) # B * #head * N * S
attn_v1 = self.attn_drop(attn_v1)
v_v1 = (
v[:, :, : num_head * head_dim]
.reshape(B0, S, num_head, head_dim)
.permute(0, 2, 1, 3)
)
feats_v1 = (attn_v1 @ v_v1).permute(0, 2, 1, 3).reshape(B, N, -1)
# process the first [num_head * head_dim] part
if C == head_dim * num_head:
feats = feats_v1
        else:  # The channels cannot be divided evenly by num_head; the remainder forms an additional head
# [might have bugs, did not check yet]
q_v2 = q[:, :, num_head * head_dim :]
k_v2 = k[:, :, num_head * head_dim :]
v_v2 = v[:, :, num_head * head_dim :]
            attn_v2 = (q_v2 @ k_v2.transpose(-2, -1)) / math.sqrt(q_v2.shape[-1])
attn_v2 = attn_v2.softmax(dim=-1)
attn_v2 = self.attn_drop(attn_v2)
feats_v2 = attn_v2 @ v_v2
feats = torch.cat([feats_v1, feats_v2], dim=-1)
return feats
def forward_candidate(self, q_tensor, k_tensor, v_tensor) -> torch.Tensor:
# check the num_heads:
if not spaces.is_determined(self._num_heads):
num_heads = self.abstract_child["_num_heads"].value
else:
num_heads = spaces.get_determined_value(self._num_heads)
feats = self.forward_qkv(q_tensor, k_tensor, v_tensor, num_heads)
outs = self.proj(feats)
outs = self.proj_drop(outs)
return outs
def forward_raw(self, q_tensor, k_tensor, v_tensor) -> torch.Tensor:
feats = self.forward_qkv(q_tensor, k_tensor, v_tensor, self.num_heads)
outs = self.proj(feats)
outs = self.proj_drop(outs)
return outs
    def extra_repr(self) -> str:
        return "in_q_dim={:}, in_k_dim={:}, in_v_dim={:}, proj_dim={:}, num_heads={:}".format(
            self._in_q_dim, self._in_k_dim, self._in_v_dim, self._proj_dim, self._num_heads
        )
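The core of the new SuperQKVAttention class is the head-splitting in forward_qkv: the first num_head * head_dim projected channels are handled as regular attention heads, and any leftover channels form one extra head so that no channel is dropped when the projection width is not divisible by num_head. The snippet below is a minimal, self-contained sketch of that logic with plain nn.Linear layers standing in for SuperLinear; the sizes are illustrative, dropout and the final output projection are omitted, and the logits are scaled by 1/sqrt(head_dim) as in standard scaled dot-product attention.

import math
import torch
import torch.nn as nn

# Illustrative sizes (not from the commit); C is deliberately not divisible by num_head.
B, N, S, C, num_head = 2, 5, 7, 10, 3
head_dim = C // num_head  # three full heads of width 3, remainder width 1

q_fc, k_fc, v_fc = nn.Linear(16, C), nn.Linear(12, C), nn.Linear(12, C)
q = q_fc(torch.randn(B, N, 16))  # [B, N, C], queries
k = k_fc(torch.randn(B, S, 12))  # [B, S, C], keys
v = v_fc(torch.randn(B, S, 12))  # [B, S, C], values

# Part 1: the leading num_head * head_dim channels become regular heads.
q1 = q[:, :, : num_head * head_dim].reshape(B, N, num_head, head_dim).permute(0, 2, 1, 3)
k1 = k[:, :, : num_head * head_dim].reshape(B, S, num_head, head_dim).permute(0, 2, 1, 3)
v1 = v[:, :, : num_head * head_dim].reshape(B, S, num_head, head_dim).permute(0, 2, 1, 3)
attn1 = ((q1 @ k1.transpose(-2, -1)) / math.sqrt(head_dim)).softmax(dim=-1)  # [B, num_head, N, S]
feats1 = (attn1 @ v1).permute(0, 2, 1, 3).reshape(B, N, -1)  # [B, N, num_head * head_dim]

# Part 2: the leftover channels are treated as one extra, narrower head.
q2 = q[:, :, num_head * head_dim :]
k2 = k[:, :, num_head * head_dim :]
v2 = v[:, :, num_head * head_dim :]
attn2 = ((q2 @ k2.transpose(-2, -1)) / math.sqrt(q2.shape[-1])).softmax(dim=-1)  # [B, N, S]
feats2 = attn2 @ v2  # [B, N, C - num_head * head_dim]

feats = torch.cat([feats1, feats2], dim=-1)
print(feats.shape)  # torch.Size([2, 5, 10]) == [B, N, C]

Because the remainder head has width C - num_head * head_dim, the concatenated features always come back to C channels regardless of divisibility.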

View File

@@ -24,7 +24,8 @@ super_name2norm = {
"identity": SuperIdentity,
}
-from .super_attention import SuperAttention
+from .super_attention import SuperSelfAttention
+from .super_attention import SuperQKVAttention
from .super_transformer import SuperTransformerEncoderLayer
from .super_activations import SuperReLU

View File

@@ -35,11 +35,13 @@ class SuperDynamicPositionE(SuperModule):
return self.forward_raw(input)
def forward_raw(self, input: torch.Tensor) -> torch.Tensor:
-        import pdb
-        pdb.set_trace()
-        print("---")
-        return F.linear(input, self._super_weight, self._super_bias)
+        positions = torch.unsqueeze(input * self._scale, dim=-1)
+        divisions = torch.reshape(
+            self._div_term, [1] * input.ndim + [self._div_term.numel()]
+        )
+        values = positions / divisions
+        embeds = torch.cat((torch.sin(values), torch.cos(values)), dim=-1)
+        return embeds
def extra_repr(self) -> str:
return "scale={:}, dim={:}".format(self._scale, self._dimension)

View File

@@ -19,7 +19,7 @@ from .super_module import LayerOrder
from .super_module import SuperModule
from .super_linear import SuperMLPv2
from .super_norm import SuperLayerNorm1D
-from .super_attention import SuperAttention
+from .super_attention import SuperSelfAttention
class SuperTransformerEncoderLayer(SuperModule):
@@ -47,7 +47,7 @@ class SuperTransformerEncoderLayer(SuperModule):
order: LayerOrder = LayerOrder.PreNorm,
):
super(SuperTransformerEncoderLayer, self).__init__()
-        mha = SuperAttention(
+        mha = SuperSelfAttention(
d_model,
d_model,
num_heads=num_heads,
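The only change to SuperTransformerEncoderLayer is that its attention sub-module is now constructed as SuperSelfAttention. For context, the LayerOrder.PreNorm default above refers to the usual norm -> sublayer -> residual ordering; the sketch below shows that generic pattern with plain torch modules and is not the actual SuperTransformerEncoderLayer implementation, whose wiring lies outside this hunk.

import torch
import torch.nn as nn

class PreNormBlock(nn.Module):
    """Generic pre-norm encoder block: x = x + sublayer(norm(x))."""

    def __init__(self, d_model: int, num_heads: int, mlp_ratio: int = 4):
        super().__init__()
        self.norm1 = nn.LayerNorm(d_model)
        self.attn = nn.MultiheadAttention(d_model, num_heads, batch_first=True)
        self.norm2 = nn.LayerNorm(d_model)
        self.mlp = nn.Sequential(
            nn.Linear(d_model, mlp_ratio * d_model),
            nn.GELU(),
            nn.Linear(mlp_ratio * d_model, d_model),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        y = self.norm1(x)                                   # pre-norm before attention
        x = x + self.attn(y, y, y, need_weights=False)[0]   # self-attention + residual
        x = x + self.mlp(self.norm2(x))                     # pre-norm MLP + residual
        return x

x = torch.randn(2, 16, 64)
print(PreNormBlock(64, 8)(x).shape)  # torch.Size([2, 16, 64])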