Add SuperTransformer

This commit is contained in:
D-X-Y
2021-03-21 20:52:22 +08:00
parent 033878becb
commit b8c173eb76
12 changed files with 355 additions and 204 deletions

View File

@@ -6,236 +6,186 @@ from __future__ import print_function
import math
from functools import partial
from typing import Optional, Text
from typing import Optional, Text, List
import torch
import torch.nn as nn
import torch.nn.functional as F
import xlayers
import spaces
from xlayers import trunc_normal_
from xlayers import super_core
DEFAULT_NET_CONFIG = dict(
__all__ = ["DefaultSearchSpace"]
def _get_mul_specs(candidates, num):
results = []
for i in range(num):
results.append(spaces.Categorical(*candidates))
return results
def _get_list_mul(num, multipler):
results = []
for i in range(1, num + 1):
results.append(i * multipler)
return results
def _assert_types(x, expected_types):
if not isinstance(x, expected_types):
raise TypeError(
"The type [{:}] is expected to be {:}.".format(type(x), expected_types)
)
_default_max_depth = 5
DefaultSearchSpace = dict(
d_feat=6,
embed_dim=64,
depth=5,
num_heads=4,
mlp_ratio=4.0,
stem_dim=spaces.Categorical(*_get_list_mul(8, 16)),
embed_dims=_get_mul_specs(_get_list_mul(8, 16), _default_max_depth),
num_heads=_get_mul_specs((1, 2, 4, 8), _default_max_depth),
mlp_hidden_multipliers=_get_mul_specs((0.5, 1, 2, 4, 8), _default_max_depth),
qkv_bias=True,
pos_drop=0.0,
mlp_drop_rate=0.0,
attn_drop_rate=0.0,
drop_path_rate=0.0,
other_drop=0.0,
)
# Real Model
class SuperTransformer(super_core.SuperModule):
"""The super model for transformer."""
class Attention(nn.Module):
def __init__(
self,
dim,
num_heads=8,
qkv_bias=False,
qk_scale=None,
attn_drop=0.0,
proj_drop=0.0,
):
super(Attention, self).__init__()
self.num_heads = num_heads
head_dim = dim // num_heads
self.scale = qk_scale or math.sqrt(head_dim)
self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
self.attn_drop = nn.Dropout(attn_drop)
self.proj = nn.Linear(dim, dim)
self.proj_drop = nn.Dropout(proj_drop)
def forward(self, x):
B, N, C = x.shape
qkv = (
self.qkv(x)
.reshape(B, N, 3, self.num_heads, C // self.num_heads)
.permute(2, 0, 3, 1, 4)
)
q, k, v = (
qkv[0],
qkv[1],
qkv[2],
) # make torchscript happy (cannot use tensor as tuple)
attn = (q @ k.transpose(-2, -1)) * self.scale
attn = attn.softmax(dim=-1)
attn = self.attn_drop(attn)
x = (attn @ v).transpose(1, 2).reshape(B, N, C)
x = self.proj(x)
x = self.proj_drop(x)
return x
class Block(nn.Module):
def __init__(
self,
dim,
num_heads,
mlp_ratio=4.0,
qkv_bias=False,
qk_scale=None,
attn_drop=0.0,
mlp_drop=0.0,
drop_path=0.0,
act_layer=nn.GELU,
norm_layer=nn.LayerNorm,
):
super(Block, self).__init__()
self.norm1 = norm_layer(dim)
self.attn = Attention(
dim,
num_heads=num_heads,
qkv_bias=qkv_bias,
qk_scale=qk_scale,
attn_drop=attn_drop,
proj_drop=mlp_drop,
)
# NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
self.drop_path = (
xlayers.DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
)
self.norm2 = norm_layer(dim)
mlp_hidden_dim = int(dim * mlp_ratio)
self.mlp = xlayers.MLP(
in_features=dim,
hidden_features=mlp_hidden_dim,
act_layer=act_layer,
drop=mlp_drop,
)
def forward(self, x):
x = x + self.drop_path(self.attn(self.norm1(x)))
x = x + self.drop_path(self.mlp(self.norm2(x)))
return x
class SimpleEmbed(nn.Module):
def __init__(self, d_feat, embed_dim):
super(SimpleEmbed, self).__init__()
self.d_feat = d_feat
self.embed_dim = embed_dim
self.proj = nn.Linear(d_feat, embed_dim)
def forward(self, x):
x = x.reshape(len(x), self.d_feat, -1) # [N, F*T] -> [N, F, T]
x = x.permute(0, 2, 1) # [N, F, T] -> [N, T, F]
out = self.proj(x) * math.sqrt(self.embed_dim)
return out
class TransformerModel(nn.Module):
def __init__(
self,
d_feat: int = 6,
embed_dim: int = 64,
depth: int = 4,
num_heads: int = 4,
mlp_ratio: float = 4.0,
qkv_bias: bool = True,
qk_scale: Optional[float] = None,
pos_drop: float = 0.0,
mlp_drop_rate: float = 0.0,
attn_drop_rate: float = 0.0,
drop_path_rate: float = 0.0,
norm_layer: Optional[nn.Module] = None,
stem_dim: super_core.IntSpaceType = DefaultSearchSpace["stem_dim"],
embed_dims: List[super_core.IntSpaceType] = DefaultSearchSpace["embed_dims"],
num_heads: List[super_core.IntSpaceType] = DefaultSearchSpace["num_heads"],
mlp_hidden_multipliers: List[super_core.IntSpaceType] = DefaultSearchSpace[
"mlp_hidden_multipliers"
],
qkv_bias: bool = DefaultSearchSpace["qkv_bias"],
pos_drop: float = DefaultSearchSpace["pos_drop"],
other_drop: float = DefaultSearchSpace["other_drop"],
max_seq_len: int = 65,
):
"""
Args:
d_feat (int, tuple): input image size
embed_dim (int): embedding dimension
depth (int): depth of transformer
num_heads (int): number of attention heads
mlp_ratio (int): ratio of mlp hidden dim to embedding dim
qkv_bias (bool): enable bias for qkv if True
qk_scale (float): override default qk scale of head_dim ** -0.5 if set
pos_drop (float): dropout rate for the positional embedding
mlp_drop_rate (float): the dropout rate for MLP layers in a block
attn_drop_rate (float): attention dropout rate
drop_path_rate (float): stochastic depth rate
norm_layer: (nn.Module): normalization layer
"""
super(TransformerModel, self).__init__()
self.embed_dim = embed_dim
self.num_features = embed_dim
norm_layer = norm_layer or partial(nn.LayerNorm, eps=1e-6)
super(SuperTransformer, self).__init__()
self._embed_dims = embed_dims
self._stem_dim = stem_dim
self._num_heads = num_heads
self._mlp_hidden_multipliers = mlp_hidden_multipliers
self.input_embed = SimpleEmbed(d_feat, embed_dim=embed_dim)
self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
self.pos_embed = xlayers.PositionalEncoder(
d_model=embed_dim, max_seq_len=max_seq_len, dropout=pos_drop
# the stem part
self.input_embed = super_core.SuperAlphaEBDv1(d_feat, stem_dim)
self.cls_token = nn.Parameter(torch.zeros(1, 1, self.stem_dim))
self.pos_embed = super_core.SuperPositionalEncoder(
d_model=stem_dim, max_seq_len=max_seq_len, dropout=pos_drop
)
dpr = [
x.item() for x in torch.linspace(0, drop_path_rate, depth)
] # stochastic depth decay rule
self.blocks = nn.ModuleList(
[
Block(
dim=embed_dim,
num_heads=num_heads,
mlp_ratio=mlp_ratio,
qkv_bias=qkv_bias,
qk_scale=qk_scale,
attn_drop=attn_drop_rate,
mlp_drop=mlp_drop_rate,
drop_path=dpr[i],
norm_layer=norm_layer,
)
for i in range(depth)
]
# build the transformer encode layers -->> check params
_assert_types(embed_dims, (tuple, list))
_assert_types(num_heads, (tuple, list))
_assert_types(mlp_hidden_multipliers, (tuple, list))
num_layers = len(embed_dims)
assert (
num_layers == len(num_heads) == len(mlp_hidden_multipliers)
), "{:} vs {:} vs {:}".format(
num_layers, len(num_heads), len(mlp_hidden_multipliers)
)
self.norm = norm_layer(embed_dim)
# build the transformer encode layers -->> backbone
layers, input_dim = [], stem_dim
for embed_dim, num_head, mlp_hidden_multiplier in zip(
embed_dims, num_heads, mlp_hidden_multipliers
):
layer = super_core.SuperTransformerEncoderLayer(
input_dim,
embed_dim,
num_head,
qkv_bias,
mlp_hidden_multiplier,
other_drop,
)
layers.append(layer)
input_dim = embed_dim
self.backbone = super_core.SuperSequential(*layers)
# regression head
self.head = nn.Linear(self.num_features, 1)
xlayers.trunc_normal_(self.cls_token, std=0.02)
# the regression head
self.head = super_core.SuperLinear(self._embed_dims[-1], 1)
trunc_normal_(self.cls_token, std=0.02)
self.apply(self._init_weights)
@property
def stem_dim(self):
return spaces.get_max(self._stem_dim)
@property
def abstract_search_space(self):
root_node = spaces.VirtualNode(id(self))
xdict = dict(
input_embed=self.input_embed.abstract_search_space,
pos_embed=self.pos_embed.abstract_search_space,
backbone=self.backbone.abstract_search_space,
head=self.head.abstract_search_space,
)
if not spaces.is_determined(self._stem_dim):
root_node.append("_stem_dim", self._stem_dim.abstract(reuse_last=True))
for key, space in xdict.items():
if not spaces.is_determined(space):
root_node.append(key, space)
return root_node
def apply_candidate(self, abstract_child: spaces.VirtualNode):
super(SuperTransformer, self).apply_candidate(abstract_child)
xkeys = ("input_embed", "pos_embed", "backbone", "head")
for key in xkeys:
if key in abstract_child:
getattr(self, key).apply_candidate(abstract_child[key])
def _init_weights(self, m):
if isinstance(m, nn.Linear):
xlayers.trunc_normal_(m.weight, std=0.02)
trunc_normal_(m.weight, std=0.02)
if isinstance(m, nn.Linear) and m.bias is not None:
nn.init.constant_(m.bias, 0)
elif isinstance(m, nn.LayerNorm):
nn.init.constant_(m.bias, 0)
elif isinstance(m, super_core.SuperLinear):
trunc_normal_(m._super_weight, std=0.02)
if m._super_bias is not None:
nn.init.constant_(m._super_bias, 0)
elif isinstance(m, super_core.SuperLayerNorm1D):
nn.init.constant_(m.weight, 1.0)
nn.init.constant_(m.bias, 0)
def forward_features(self, x):
batch, flatten_size = x.shape
feats = self.input_embed(x) # batch * 60 * 64
cls_tokens = self.cls_token.expand(
batch, -1, -1
) # stole cls_tokens impl from Phil Wang, thanks
def forward_candidate(self, input: torch.Tensor) -> torch.Tensor:
batch, flatten_size = input.shape
feats = self.input_embed(input) # batch * 60 * 64
if not spaces.is_determined(self._stem_dim):
stem_dim = self.abstract_child["_stem_dim"].value
else:
stem_dim = spaces.get_determined_value(self._stem_dim)
cls_tokens = self.cls_token.expand(batch, -1, -1)
cls_tokens = F.interpolate(cls_tokens, size=(stem_dim), mode="linear", align_corners=True)
feats_w_ct = torch.cat((cls_tokens, feats), dim=1)
feats_w_tp = self.pos_embed(feats_w_ct)
xfeats = self.backbone(feats_w_tp)
xfeats = xfeats[:, 0, :] # use the feature for the first token
predicts = self.head(xfeats).squeeze(-1)
return predicts
xfeats = feats_w_tp
for block in self.blocks:
xfeats = block(xfeats)
xfeats = self.norm(xfeats)[:, 0]
return xfeats
def forward(self, x):
feats = self.forward_features(x)
predicts = self.head(feats).squeeze(-1)
def forward_raw(self, input: torch.Tensor) -> torch.Tensor:
batch, flatten_size = input.shape
feats = self.input_embed(input) # batch * 60 * 64
cls_tokens = self.cls_token.expand(batch, -1, -1)
feats_w_ct = torch.cat((cls_tokens, feats), dim=1)
feats_w_tp = self.pos_embed(feats_w_ct)
xfeats = self.backbone(feats_w_tp)
xfeats = xfeats[:, 0, :] # use the feature for the first token
predicts = self.head(xfeats).squeeze(-1)
return predicts
def get_transformer(config):
if config is None:
return SuperTransformer(6)
if not isinstance(config, dict):
raise ValueError("Invalid Configuration: {:}".format(config))
name = config.get("name", "basic")