Refine Transformer
This commit is contained in:
@@ -4,5 +4,4 @@
|
||||
# This file is expected to be self-contained, except
|
||||
# for importing from spaces to include search space.
|
||||
#####################################################
|
||||
from .weight_init import trunc_normal_
|
||||
from .super_core import *
|
||||
|
@@ -1,8 +1,12 @@
|
||||
# Borrowed from https://github.com/rwightman/pytorch-image-models
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import math
|
||||
import warnings
|
||||
|
||||
# setup for xlayers
|
||||
from . import super_core
|
||||
|
||||
|
||||
def _no_grad_trunc_normal_(tensor, mean, std, a, b):
|
||||
# Cut & paste from PyTorch official master until it's in a few official releases - RW
|
||||
@@ -64,3 +68,17 @@ def trunc_normal_(tensor, mean=0.0, std=1.0, a=-2.0, b=2.0):
|
||||
return [_no_grad_trunc_normal_(x, mean, std, a, b) for x in tensor]
|
||||
else:
|
||||
return _no_grad_trunc_normal_(tensor, mean, std, a, b)
|
||||
|
||||
|
||||
def init_transformer(m):
    """Initialize the parameters of a single module in place.

    Dispatches on the type of ``m``:

    * ``nn.Linear``: weight is drawn with ``trunc_normal_(std=0.02)``; the
      bias, when present, is set to 0.
    * ``super_core.SuperLinear``: same scheme, applied to ``_super_weight``
      and ``_super_bias``.
    * ``super_core.SuperLayerNorm1D``: weight is set to 1.0 and bias to 0.

    Modules of any other type are left untouched.

    Args:
        m: The module whose parameters should be initialized.

    Returns:
        None. The parameters of ``m`` are modified in place.
    """
    if isinstance(m, nn.Linear):
        trunc_normal_(m.weight, std=0.02)
        # The branch already guarantees an nn.Linear, so only the optional
        # bias needs checking here.
        if m.bias is not None:
            nn.init.constant_(m.bias, 0)
    elif isinstance(m, super_core.SuperLinear):
        trunc_normal_(m._super_weight, std=0.02)
        if m._super_bias is not None:
            nn.init.constant_(m._super_bias, 0)
    elif isinstance(m, super_core.SuperLayerNorm1D):
        # NOTE(review): guarded like the linear branches above; presumably
        # the affine parameters can be None when the layer is built without
        # an elementwise affine transform -- confirm against SuperLayerNorm1D.
        if m.weight is not None:
            nn.init.constant_(m.weight, 1.0)
        if m.bias is not None:
            nn.init.constant_(m.bias, 0)
|
||||
|
Reference in New Issue
Block a user