Update LFNA
@@ -11,6 +11,7 @@ __all__ = ["get_model"]
 from xlayers.super_core import SuperSequential
 from xlayers.super_core import SuperLinear
+from xlayers.super_core import SuperDropout
 from xlayers.super_core import super_name2norm
 from xlayers.super_core import super_name2activation
@@ -47,7 +48,20 @@ def get_model(config: Dict[Text, Any], **kwargs):
             last_dim = hidden_dim
         sub_layers.append(SuperLinear(last_dim, kwargs["output_dim"]))
         model = SuperSequential(*sub_layers)
+    elif model_type == "dual_norm_mlp":
+        act_cls = super_name2activation[kwargs["act_cls"]]
+        norm_cls = super_name2norm[kwargs["norm_cls"]]
+        sub_layers, last_dim = [], kwargs["input_dim"]
+        for i, hidden_dim in enumerate(kwargs["hidden_dims"]):
+            if i > 0:
+                sub_layers.append(norm_cls(last_dim, elementwise_affine=False))
+            sub_layers.append(SuperLinear(last_dim, hidden_dim))
+            sub_layers.append(SuperDropout(kwargs["dropout"]))
+            sub_layers.append(SuperLinear(hidden_dim, hidden_dim))
+            sub_layers.append(act_cls())
+            last_dim = hidden_dim
+        sub_layers.append(SuperLinear(last_dim, kwargs["output_dim"]))
+        model = SuperSequential(*sub_layers)
     else:
         raise TypeError("Unknown model type: {:}".format(model_type))
     return model
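The new "dual_norm_mlp" branch builds, for each hidden width, an optional pre-block norm (from the second block onward), a linear projection, dropout, a second linear layer of the same width, and the chosen activation, and finishes with a linear output head. For illustration only, a plain torch.nn sketch of the same block pattern; the Super* classes above are the actual implementation, and the sizes and activation below are made up:

import torch
import torch.nn as nn

def dual_norm_mlp_sketch(input_dim, hidden_dims, output_dim, dropout):
    # mirrors the loop in the diff above, using standard torch.nn layers
    layers, last_dim = [], input_dim
    for i, hidden_dim in enumerate(hidden_dims):
        if i > 0:
            # normalize the features entering every block after the first
            layers.append(nn.LayerNorm(last_dim, elementwise_affine=False))
        layers.append(nn.Linear(last_dim, hidden_dim))
        layers.append(nn.Dropout(dropout))
        layers.append(nn.Linear(hidden_dim, hidden_dim))
        layers.append(nn.ReLU())
        last_dim = hidden_dim
    layers.append(nn.Linear(last_dim, output_dim))
    return nn.Sequential(*layers)

model = dual_norm_mlp_sketch(32, [64, 64], 10, dropout=0.1)
out = model(torch.randn(4, 32))  # -> shape (4, 10)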
@@ -14,6 +14,7 @@ from .super_norm import SuperSimpleNorm
 from .super_norm import SuperLayerNorm1D
 from .super_norm import SuperSimpleLearnableNorm
 from .super_norm import SuperIdentity
+from .super_dropout import SuperDropout

 super_name2norm = {
     "simple_norm": SuperSimpleNorm,
lib/xlayers/super_dropout.py (new file, 40 lines)
@@ -0,0 +1,40 @@
+#####################################################
+# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2021.03 #
+#####################################################
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+import math
+from typing import Optional, Callable
+
+import spaces
+from .super_module import SuperModule
+from .super_module import IntSpaceType
+from .super_module import BoolSpaceType
+
+
+class SuperDropout(SuperModule):
+    """Applies the dropout function element-wise."""
+
+    def __init__(self, p: float = 0.5, inplace: bool = False) -> None:
+        super(SuperDropout, self).__init__()
+        self._p = p
+        self._inplace = inplace
+
+    @property
+    def abstract_search_space(self):
+        return spaces.VirtualNode(id(self))
+
+    def forward_candidate(self, input: torch.Tensor) -> torch.Tensor:
+        return self.forward_raw(input)
+
+    def forward_raw(self, input: torch.Tensor) -> torch.Tensor:
+        return F.dropout(input, self._p, self.training, self._inplace)
+
+    def forward_with_container(self, input, container, prefix=[]):
+        return self.forward_raw(input)
+
+    def extra_repr(self) -> str:
+        xstr = "inplace=True" if self._inplace else ""
+        return "p={:}".format(self._p) + ", " + xstr
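A minimal usage sketch for the new module (assumes lib/ is on the import path, matching the `from xlayers.super_core import SuperDropout` line added above; forward_raw is called directly because SuperModule's dispatch is not part of this diff):

import torch
from xlayers.super_core import SuperDropout

drop = SuperDropout(p=0.5)
x = torch.randn(4, 8)

drop.train()
y_train = drop.forward_raw(x)  # roughly half the entries zeroed, the rest scaled by 1 / (1 - p)

drop.eval()
y_eval = drop.forward_raw(x)   # self.training is False, so F.dropout returns the input unchanged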
@@ -74,6 +74,19 @@ class SuperLayerNorm1D(SuperModule):
     def forward_raw(self, input: torch.Tensor) -> torch.Tensor:
         return F.layer_norm(input, (self.in_dim,), self.weight, self.bias, self.eps)
+
+    def forward_with_container(self, input, container, prefix=[]):
+        super_weight_name = ".".join(prefix + ["weight"])
+        if container.has(super_weight_name):
+            weight = container.query(super_weight_name)
+        else:
+            weight = None
+        super_bias_name = ".".join(prefix + ["bias"])
+        if container.has(super_bias_name):
+            bias = container.query(super_bias_name)
+        else:
+            bias = None
+        return F.layer_norm(input, (self.in_dim,), weight, bias, self.eps)

     def extra_repr(self) -> str:
         return (
             "shape={in_dim}, eps={eps}, elementwise_affine={elementwise_affine}".format(
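The new forward_with_container lets SuperLayerNorm1D run with weight/bias tensors supplied by an external container (looked up by the dotted name built from prefix) instead of its own parameters, falling back to None when a name is absent. The container class itself is not part of this diff; only its has() and query() methods are used. A hypothetical stand-in just to show the lookup pattern:

import torch
import torch.nn.functional as F

class DictContainer:
    """Minimal stand-in exposing the has()/query() interface used above."""
    def __init__(self, tensors):
        self._tensors = dict(tensors)
    def has(self, name):
        return name in self._tensors
    def query(self, name):
        return self._tensors[name]

container = DictContainer({"norm.weight": torch.ones(16), "norm.bias": torch.zeros(16)})
name = ".".join(["norm"] + ["weight"])                       # -> "norm.weight"
weight = container.query(name) if container.has(name) else None
out = F.layer_norm(torch.randn(4, 16), (16,), weight, None, 1e-6)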