Update xlayers
@@ -31,12 +31,15 @@ class SuperSelfAttention(SuperModule):
         qkv_bias: BoolSpaceType = False,
         attn_drop: Optional[float] = None,
         proj_drop: Optional[float] = None,
+        use_mask=False,
     ):
         super(SuperSelfAttention, self).__init__()
         self._input_dim = input_dim
         self._proj_dim = proj_dim
         self._num_heads = num_heads
         self._qkv_bias = qkv_bias
+        self._use_mask = use_mask
+        self._infinity = 1e9
 
         self.q_fc = SuperLinear(input_dim, input_dim, bias=qkv_bias)
         self.k_fc = SuperLinear(input_dim, input_dim, bias=qkv_bias)
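Note: the constructor now stores a `_use_mask` flag and a large constant `_infinity = 1e9` that later serves as the fill value for masked attention logits. A tiny standalone check (plain PyTorch, not repository code) of why a value like -1e9 behaves as "minus infinity" under softmax:

import torch

# A logit filled with -1e9 receives essentially zero probability mass.
logits = torch.tensor([2.0, 0.5, -1e9])
print(logits.softmax(dim=0))  # ~[0.8176, 0.1824, 0.0000]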
@@ -113,6 +116,12 @@ class SuperSelfAttention(SuperModule):
             .permute(0, 2, 1, 3)
         )
         attn_v1 = (q_v1 @ k_v1.transpose(-2, -1)) * math.sqrt(head_dim)
+        if self._use_mask:
+            mask = torch.triu(
+                torch.ones((N, N), dtype=torch.bool, device=input.device), 1
+            )
+            mask = torch.unsqueeze(torch.unsqueeze(mask, dim=0), dim=0)
+            attn_v1 = attn_v1.masked_fill(mask, -self._infinity)
         attn_v1 = attn_v1.softmax(dim=-1)  # B * #head * N * N
         attn_v1 = self.attn_drop(attn_v1)
         feats_v1 = (attn_v1 @ v_v1).permute(0, 2, 1, 3).reshape(B, N, -1)
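The new `use_mask` branch builds a causal mask: `torch.triu(..., 1)` marks every position above the diagonal (future keys), the two `unsqueeze` calls add batch and head dimensions for broadcasting, and `masked_fill` pushes those logits to -1e9 before the softmax. A minimal sketch of the same logic in plain PyTorch (shapes and names here are illustrative, not the module's code):

import torch

N = 4
attn = torch.randn(1, 1, N, N)  # toy logits: batch=1, heads=1, N x N

# True above the diagonal, i.e. wherever a query would look at a future key.
mask = torch.triu(torch.ones((N, N), dtype=torch.bool), 1)
mask = mask.unsqueeze(0).unsqueeze(0)  # broadcast over batch and heads

attn = attn.masked_fill(mask, -1e9).softmax(dim=-1)
print(attn[0, 0])  # row i puts (almost) no weight on columns j > i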
@@ -147,8 +156,14 @@ class SuperSelfAttention(SuperModule):
         return outs
 
     def extra_repr(self) -> str:
-        return "input_dim={:}, proj_dim={:}, num_heads={:}".format(
-            self._input_dim, self._proj_dim, self._num_heads
+        return (
+            "input_dim={:}, proj_dim={:}, num_heads={:}, mask={:}, infinity={:}".format(
+                self._input_dim,
+                self._proj_dim,
+                self._num_heads,
+                self._use_mask,
+                self._infinity,
+            )
         )
 
 
@@ -181,6 +196,7 @@ class SuperQKVAttention(SuperModule):
         self.attn_drop = nn.Dropout(attn_drop or 0.0)
         self.proj = SuperLinear(proj_dim, proj_dim)
         self.proj_drop = nn.Dropout(proj_drop or 0.0)
+        self._infinity = 1e9
 
     @property
     def num_heads(self):
@@ -232,7 +248,9 @@ class SuperQKVAttention(SuperModule):
         if "proj" in abstract_child:
             self.proj.apply_candidate(abstract_child["proj"])
 
-    def forward_qkv(self, q_tensor, k_tensor, v_tensor, num_head: int) -> torch.Tensor:
+    def forward_qkv(
+        self, q_tensor, k_tensor, v_tensor, num_head: int, mask=None
+    ) -> torch.Tensor:
         q = self.q_fc(q_tensor)
         B, N, C = q.shape
 
@@ -257,6 +275,9 @@ class SuperQKVAttention(SuperModule):
         )
         # compute the attention map
         attn_v1 = (q_v1 @ k_v1.transpose(-2, -1)) * math.sqrt(head_dim)
+        if mask is not None:
+            mask = torch.unsqueeze(mask, dim=1)
+            attn_v1 = attn_v1.masked_fill(mask, -self._infinity)
         attn_v1 = attn_v1.softmax(dim=-1)  # B * #head * N * S
         attn_v1 = self.attn_drop(attn_v1)
 
@@ -281,26 +302,29 @@ class SuperQKVAttention(SuperModule):
         feats = torch.cat([feats_v1, feats_v2], dim=-1)
         return feats
 
-    def forward_candidate(self, q_tensor, k_tensor, v_tensor) -> torch.Tensor:
+    def forward_candidate(
+        self, q_tensor, k_tensor, v_tensor, mask=None
+    ) -> torch.Tensor:
         # check the num_heads:
         if not spaces.is_determined(self._num_heads):
             num_heads = self.abstract_child["_num_heads"].value
         else:
             num_heads = spaces.get_determined_value(self._num_heads)
-        feats = self.forward_qkv(q_tensor, k_tensor, v_tensor, num_heads)
+        feats = self.forward_qkv(q_tensor, k_tensor, v_tensor, num_heads, mask)
         outs = self.proj(feats)
         outs = self.proj_drop(outs)
         return outs
 
-    def forward_raw(self, q_tensor, k_tensor, v_tensor) -> torch.Tensor:
-        feats = self.forward_qkv(q_tensor, k_tensor, v_tensor, self.num_heads)
+    def forward_raw(self, q_tensor, k_tensor, v_tensor, mask=None) -> torch.Tensor:
+        feats = self.forward_qkv(q_tensor, k_tensor, v_tensor, self.num_heads, mask)
         outs = self.proj(feats)
         outs = self.proj_drop(outs)
         return outs
 
     def extra_repr(self) -> str:
-        return "input_dim={:}, proj_dim={:}, num_heads={:}".format(
+        return "input_dim={:}, proj_dim={:}, num_heads={:}, infinity={:}".format(
             (self.in_q_dim, self.in_k_dim, self.in_v_dim),
             self._proj_dim,
             self._num_heads,
+            self._infinity,
         )
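SuperQKVAttention now threads an optional `mask` through `forward_raw` and `forward_candidate` into `forward_qkv`, where it is unsqueezed at dim=1 so one (batch, query_len, key_len) boolean mask is shared across all heads. A hedged usage sketch; the padding-mask construction and tensor shapes below are illustrative assumptions, not taken from the repository:

import torch

B, N, S, C = 2, 5, 7, 64          # batch, query length, key length, feature dim
q = torch.randn(B, N, C)
k = torch.randn(B, S, C)
v = torch.randn(B, S, C)

# Example: ignore padded key positions; True means "do not attend here".
key_lengths = torch.tensor([7, 4])
pad = torch.arange(S)[None, :] >= key_lengths[:, None]   # (B, S)
mask = pad[:, None, :].expand(B, N, S)                   # (B, N, S)

# Assuming `attn` is an already constructed SuperQKVAttention instance:
# out = attn.forward_raw(q, k, v, mask)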
@@ -117,16 +117,32 @@ class SuperModule(abc.ABC, nn.Module):
         else:
             return False, self._meta_info[BEST_SCORE_KEY]
 
-    def load_best(self):
-        if BEST_DIR_KEY not in self._meta_info or BEST_SCORE_KEY not in self._meta_info:
-            raise ValueError("Please call save_best at first")
-        best_save_path = os.path.join(
-            self._meta_info[BEST_DIR_KEY],
-            "best-{:}.pth".format(self.__class__.__name__),
-        )
+    def load_best(self, best_save_path=None):
+        if best_save_path is None:
+            if (
+                BEST_DIR_KEY not in self._meta_info
+                or BEST_SCORE_KEY not in self._meta_info
+            ):
+                raise ValueError("Please call save_best at first")
+            best_save_name = self._meta_info.get(
+                BEST_NAME_KEY, "best-{:}.pth".format(self.__class__.__name__)
+            )
+            best_save_path = os.path.join(self._meta_info[BEST_DIR_KEY], best_save_name)
         state_dict = torch.load(best_save_path)
         self.load_state_dict(state_dict)
 
+    def has_best(self, best_name=None):
+        if BEST_DIR_KEY not in self._meta_info:
+            raise ValueError("Please set BEST_DIR_KEY at first")
+        if best_name is None:
+            best_save_name = self._meta_info.get(
+                BEST_NAME_KEY, "best-{:}.pth".format(self.__class__.__name__)
+            )
+        else:
+            best_save_name = best_name
+        best_save_path = os.path.join(self._meta_info[BEST_DIR_KEY], best_save_name)
+        return os.path.exists(best_save_path)
+
     @property
     def abstract_search_space(self):
         raise NotImplementedError
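`load_best` can now be given an explicit checkpoint path, and the new `has_best` helper checks for the checkpoint without loading it. A hedged usage sketch; `model` is a placeholder for any SuperModule subclass on which `save_best()` was called earlier:

def maybe_restore_best(model, explicit_path=None):
    # has_best() only tests for the checkpoint file on disk; nothing is loaded.
    if explicit_path is not None:
        model.load_best(best_save_path=explicit_path)
    elif model.has_best():
        model.load_best()  # path resolved from the module's _meta_info
    return model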
@@ -45,6 +45,7 @@ class SuperTransformerEncoderLayer(SuperModule):
         norm_affine: bool = True,
         act_layer: Callable[[], nn.Module] = nn.GELU,
         order: LayerOrder = LayerOrder.PreNorm,
+        use_mask: bool = False,
     ):
         super(SuperTransformerEncoderLayer, self).__init__()
         mha = SuperSelfAttention(
@@ -54,6 +55,7 @@ class SuperTransformerEncoderLayer(SuperModule):
             qkv_bias=qkv_bias,
             attn_drop=drop,
             proj_drop=drop,
+            use_mask=use_mask,
         )
         mlp = SuperMLPv2(
             d_model,
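The encoder layer exposes the new `use_mask` flag and forwards it to its internal SuperSelfAttention, so a causal (autoregressive-style) block can be requested at construction time. A hedged construction sketch; the import path and the positional `d_model` / `num_heads` arguments are assumptions based on the surrounding hunks, only `use_mask` is confirmed by this diff:

import torch
from xautodl.xlayers.super_transformer import SuperTransformerEncoderLayer  # assumed path

layer = SuperTransformerEncoderLayer(256, 4, use_mask=True)  # d_model, num_heads assumed
x = torch.randn(2, 10, 256)   # (batch, sequence, d_model)
print(layer(x).shape)         # causal self-attention applied inside the block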