Update Q Model
lib/layers/__init__.py
@@ -1,4 +1,5 @@
 from .drop import DropBlock2d, DropPath
+from .mlp import MLP
 from .weight_init import trunc_normal_
 
 from .positional_embedding import PositionalEncoder
lib/layers/mlp.py (new file, 24 lines)
@@ -0,0 +1,24 @@
+import torch.nn as nn
+from typing import Optional
+
+class MLP(nn.Module):
+    # MLP: FC -> Activation -> Drop -> FC -> Drop
+    def __init__(self, in_features, hidden_features: Optional[int] = None,
+                 out_features: Optional[int] = None,
+                 act_layer=nn.GELU,
+                 drop: Optional[float] = None):
+        super(MLP, self).__init__()
+        out_features = out_features or in_features
+        hidden_features = hidden_features or in_features
+        self.fc1 = nn.Linear(in_features, hidden_features)
+        self.act = act_layer()
+        self.fc2 = nn.Linear(hidden_features, out_features)
+        self.drop = nn.Dropout(drop or 0)
+
+    def forward(self, x):
+        x = self.fc1(x)
+        x = self.act(x)
+        x = self.drop(x)
+        x = self.fc2(x)
+        x = self.drop(x)
+        return x
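
For context, a minimal usage sketch of the new layer (not part of the commit; the tensor sizes, dropout rate, and the `from lib.layers import MLP` import path are illustrative assumptions):

import torch
from lib.layers import MLP  # assumes lib/ is importable; MLP is re-exported by the updated __init__.py

mlp = MLP(in_features=64, hidden_features=128, drop=0.1)  # hypothetical sizes
x = torch.randn(8, 64)   # batch of 8 vectors, 64 features each
y = mlp(x)               # FC -> GELU -> Dropout -> FC -> Dropout
print(y.shape)           # torch.Size([8, 64]); out_features defaults to in_features

Note that a single nn.Dropout module is reused after both linear layers; Dropout has no learned parameters, so this only shares the drop rate, and passing drop=None yields nn.Dropout(0), i.e. a no-op.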