Refine Transformer

This commit is contained in:
D-X-Y
2021-07-04 11:59:06 +00:00
parent 9136f33684
commit 11f313288a
10 changed files with 160 additions and 28 deletions

View File

@@ -21,10 +21,10 @@ import torch.nn.functional as F
import torch.optim as optim
import torch.utils.data as th_data
from log_utils import AverageMeter
from utils import count_parameters
from xautodl.xmisc import AverageMeter
from xautodl.xmisc import count_parameters
from xlayers import super_core
from xautodl.xlayers import super_core
from .transformers import DEFAULT_NET_CONFIG
from .transformers import get_transformer

View File

@@ -13,7 +13,7 @@ import torch.nn as nn
import torch.nn.functional as F
from xautodl import spaces
from xautodl.xlayers import trunc_normal_
from xautodl.xlayers import weight_init
from xautodl.xlayers import super_core
@@ -104,7 +104,7 @@ class SuperTransformer(super_core.SuperModule):
self.head = super_core.SuperSequential(
super_core.SuperLayerNorm1D(embed_dim), super_core.SuperLinear(embed_dim, 1)
)
trunc_normal_(self.cls_token, std=0.02)
weight_init.trunc_normal_(self.cls_token, std=0.02)
self.apply(self._init_weights)
@property
@@ -136,11 +136,11 @@ class SuperTransformer(super_core.SuperModule):
def _init_weights(self, m):
if isinstance(m, nn.Linear):
trunc_normal_(m.weight, std=0.02)
weight_init.trunc_normal_(m.weight, std=0.02)
if isinstance(m, nn.Linear) and m.bias is not None:
nn.init.constant_(m.bias, 0)
elif isinstance(m, super_core.SuperLinear):
trunc_normal_(m._super_weight, std=0.02)
weight_init.trunc_normal_(m._super_weight, std=0.02)
if m._super_bias is not None:
nn.init.constant_(m._super_bias, 0)
elif isinstance(m, super_core.SuperLayerNorm1D):