Refine Transformer
This commit is contained in:
@@ -20,20 +20,6 @@ def pair(t):
|
||||
return t if isinstance(t, tuple) else (t, t)
|
||||
|
||||
|
||||
def _init_weights(m):
    """Initialize the parameters of a single module in place.

    Intended to be handed to ``Module.apply`` so that it is called once for
    every sub-module. Modules that are not one of the handled types are left
    untouched.

    Handled types:
        * ``nn.Linear``: weight via ``weight_init.trunc_normal_`` with
          ``std=0.02``; bias (when present) set to 0.
        * ``xlayers.SuperLinear``: same scheme, applied to ``_super_weight``
          and ``_super_bias``.
        * ``xlayers.SuperLayerNorm1D``: weight set to 1.0, bias set to 0.

    Args:
        m: The module to initialize.
    """
    # NOTE(review): the branch nesting was reconstructed from flattened
    # text -- confirm against the upstream file.
    if isinstance(m, nn.Linear):
        weight_init.trunc_normal_(m.weight, std=0.02)
        # Already known to be nn.Linear here; only the optional bias needs
        # checking.
        if m.bias is not None:
            nn.init.constant_(m.bias, 0)
    elif isinstance(m, xlayers.SuperLinear):
        weight_init.trunc_normal_(m._super_weight, std=0.02)
        if m._super_bias is not None:
            nn.init.constant_(m._super_bias, 0)
    elif isinstance(m, xlayers.SuperLayerNorm1D):
        nn.init.constant_(m.weight, 1.0)
        nn.init.constant_(m.bias, 0)
|
||||
|
||||
|
||||
name2config = {
|
||||
"vit-cifar10-p4-d4-h4-c32": dict(
|
||||
type="vit",
|
||||
@@ -155,7 +141,7 @@ class SuperViT(xlayers.SuperModule):
|
||||
)
|
||||
|
||||
weight_init.trunc_normal_(self.cls_token, std=0.02)
|
||||
self.apply(_init_weights)
|
||||
self.apply(weight_init.init_transformer)
|
||||
|
||||
@property
|
||||
def abstract_search_space(self):
|
||||
|
Reference in New Issue
Block a user