Refine Transformer
@@ -21,10 +21,10 @@ import torch.nn.functional as F
 import torch.optim as optim
 import torch.utils.data as th_data
 
-from log_utils import AverageMeter
-from utils import count_parameters
+from xautodl.xmisc import AverageMeter
+from xautodl.xmisc import count_parameters
 
-from xlayers import super_core
+from xautodl.xlayers import super_core
 from .transformers import DEFAULT_NET_CONFIG
 from .transformers import get_transformer
 
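This hunk folds the script's loose top-level helpers into the installed xautodl package. A minimal sketch of the resulting import style, assuming xautodl is installed as a package (e.g. via pip) rather than resolved from repository-relative modules:

# Sketch only: the same utilities, now resolved through the xautodl package.
from xautodl.xmisc import AverageMeter       # replaces: from log_utils import AverageMeter
from xautodl.xmisc import count_parameters   # replaces: from utils import count_parameters
from xautodl.xlayers import super_core       # replaces: from xlayers import super_core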
@@ -13,7 +13,7 @@ import torch.nn as nn
 import torch.nn.functional as F
 
 from xautodl import spaces
-from xautodl.xlayers import trunc_normal_
+from xautodl.xlayers import weight_init
 from xautodl.xlayers import super_core
 
 
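This hunk stops importing trunc_normal_ directly and instead imports the weight_init module, so the call sites below become weight_init.trunc_normal_(...). A hedged sketch of what a truncated-normal initializer of this kind does; the real xautodl.xlayers.weight_init.trunc_normal_ may implement the truncation differently:

import torch

def trunc_normal_sketch(tensor, mean=0.0, std=1.0, a=-2.0, b=2.0):
    # Sample from N(mean, std**2), then re-draw any values outside [a, b]
    # until the whole tensor lies inside the truncation interval.
    with torch.no_grad():
        tensor.normal_(mean, std)
        while True:
            outside = (tensor < a) | (tensor > b)
            if not outside.any():
                return tensor
            tensor[outside] = torch.empty(
                int(outside.sum()), device=tensor.device
            ).normal_(mean, std)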
@@ -104,7 +104,7 @@ class SuperTransformer(super_core.SuperModule):
         self.head = super_core.SuperSequential(
             super_core.SuperLayerNorm1D(embed_dim), super_core.SuperLinear(embed_dim, 1)
         )
-        trunc_normal_(self.cls_token, std=0.02)
+        weight_init.trunc_normal_(self.cls_token, std=0.02)
         self.apply(self._init_weights)
 
     @property
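The head maps the embed_dim-wide [CLS] representation to a single scalar per example. A plain-PyTorch sketch of the same shape contract, using nn.LayerNorm and nn.Linear as stand-ins for the Super* variants:

import torch
import torch.nn as nn

embed_dim = 32
head = nn.Sequential(nn.LayerNorm(embed_dim), nn.Linear(embed_dim, 1))
cls_feature = torch.randn(8, embed_dim)  # a batch of eight [CLS] embeddings
print(head(cls_feature).shape)           # torch.Size([8, 1])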
@@ -136,11 +136,11 @@ class SuperTransformer(super_core.SuperModule):
 
     def _init_weights(self, m):
         if isinstance(m, nn.Linear):
-            trunc_normal_(m.weight, std=0.02)
+            weight_init.trunc_normal_(m.weight, std=0.02)
             if isinstance(m, nn.Linear) and m.bias is not None:
                 nn.init.constant_(m.bias, 0)
         elif isinstance(m, super_core.SuperLinear):
-            trunc_normal_(m._super_weight, std=0.02)
+            weight_init.trunc_normal_(m._super_weight, std=0.02)
             if m._super_bias is not None:
                 nn.init.constant_(m._super_bias, 0)
         elif isinstance(m, super_core.SuperLayerNorm1D):
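The self.apply(self._init_weights) call above relies on nn.Module.apply, which walks the module tree and invokes the function once per submodule (children first, then the root), so _init_weights only needs the per-type branches shown in this hunk. A small illustration with standard layers:

import torch.nn as nn

net = nn.Sequential(nn.Linear(4, 4), nn.LayerNorm(4))

def report(m):
    print(type(m).__name__)

net.apply(report)  # prints: Linear, LayerNorm, Sequential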