Refine Transformer

This commit is contained in:
D-X-Y
2021-07-04 11:59:06 +00:00
parent 9136f33684
commit 11f313288a
10 changed files with 160 additions and 28 deletions

View File

@@ -20,20 +20,6 @@ def pair(t):
return t if isinstance(t, tuple) else (t, t)
def _init_weights(m):
    """Initialise the parameters of one module in place.

    Used as the callback for ``self.apply(...)``, so it is invoked once per
    sub-module. Only three module kinds are handled; any other module is
    left untouched.

    * ``nn.Linear``: weight via ``weight_init.trunc_normal_`` (std=0.02),
      bias (when present) set to 0.
    * ``xlayers.SuperLinear``: same scheme, applied to ``_super_weight`` and
      ``_super_bias``.
    * ``xlayers.SuperLayerNorm1D``: weight set to 1.0 and bias set to 0.

    Args:
        m: the module to initialise.
    """
    if isinstance(m, nn.Linear):
        weight_init.trunc_normal_(m.weight, std=0.02)
        # A linear layer may have been built without a bias term.
        if m.bias is not None:
            nn.init.constant_(m.bias, 0)
        return

    if isinstance(m, xlayers.SuperLinear):
        weight_init.trunc_normal_(m._super_weight, std=0.02)
        if m._super_bias is not None:
            nn.init.constant_(m._super_bias, 0)
        return

    if isinstance(m, xlayers.SuperLayerNorm1D):
        nn.init.constant_(m.weight, 1.0)
        nn.init.constant_(m.bias, 0)
name2config = {
"vit-cifar10-p4-d4-h4-c32": dict(
type="vit",
@@ -155,7 +141,7 @@ class SuperViT(xlayers.SuperModule):
)
weight_init.trunc_normal_(self.cls_token, std=0.02)
self.apply(_init_weights)
self.apply(weight_init.init_transformer)
@property
def abstract_search_space(self):