Refine Transformer

This commit is contained in:
D-X-Y
2021-07-04 11:59:06 +00:00
parent 9136f33684
commit 11f313288a
10 changed files with 160 additions and 28 deletions

View File

@@ -4,5 +4,4 @@
# This file is expected to be self-contained, except
# for importing from spaces to include search space.
#####################################################
from .weight_init import trunc_normal_
from .super_core import *

View File

@@ -1,8 +1,12 @@
# Borrowed from https://github.com/rwightman/pytorch-image-models
import torch
import torch.nn as nn
import math
import warnings
# setup for xlayers
from . import super_core
def _no_grad_trunc_normal_(tensor, mean, std, a, b):
# Cut & paste from PyTorch official master until it's in a few official releases - RW
@@ -64,3 +68,17 @@ def trunc_normal_(tensor, mean=0.0, std=1.0, a=-2.0, b=2.0):
return [_no_grad_trunc_normal_(x, mean, std, a, b) for x in tensor]
else:
return _no_grad_trunc_normal_(tensor, mean, std, a, b)
def init_transformer(m):
    """Initialize the parameters of a single transformer sub-module in place.

    The function takes exactly one module, so it can be passed to
    ``nn.Module.apply`` to initialize a whole network.

    Handled module types:

    * ``nn.Linear``: the weight is filled via ``trunc_normal_`` with
      ``std=0.02``; the bias, when present, is set to 0.
    * ``super_core.SuperLinear``: the same scheme is applied to
      ``_super_weight`` and ``_super_bias``.
    * ``super_core.SuperLayerNorm1D``: the weight is set to 1 and the bias
      to 0, each only when it is not ``None``.

    Any other module type is left untouched.

    Args:
        m: The module whose parameters should be initialized.
    """
    if isinstance(m, nn.Linear):
        trunc_normal_(m.weight, std=0.02)
        # The enclosing branch already guarantees an nn.Linear, so only the
        # optional bias needs to be checked here.
        if m.bias is not None:
            nn.init.constant_(m.bias, 0)
    elif isinstance(m, super_core.SuperLinear):
        trunc_normal_(m._super_weight, std=0.02)
        if m._super_bias is not None:
            nn.init.constant_(m._super_bias, 0)
    elif isinstance(m, super_core.SuperLayerNorm1D):
        # NOTE(review): presumably weight/bias are None when the layer is
        # built without affine parameters (as with torch.nn.LayerNorm) --
        # confirm against SuperLayerNorm1D. Guarding keeps this branch
        # consistent with the optional-bias handling of the linear branches.
        if m.weight is not None:
            nn.init.constant_(m.weight, 1.0)
        if m.bias is not None:
            nn.init.constant_(m.bias, 0)