Update LFNA

This commit is contained in:
D-X-Y
2021-05-12 20:32:50 +08:00
parent 06f4a1f1cf
commit 0b1ca45c44
8 changed files with 121 additions and 15 deletions

View File

@@ -11,6 +11,7 @@ __all__ = ["get_model"]
from xlayers.super_core import SuperSequential
from xlayers.super_core import SuperLinear
from xlayers.super_core import SuperDropout
from xlayers.super_core import super_name2norm
from xlayers.super_core import super_name2activation
@@ -47,7 +48,20 @@ def get_model(config: Dict[Text, Any], **kwargs):
last_dim = hidden_dim
sub_layers.append(SuperLinear(last_dim, kwargs["output_dim"]))
model = SuperSequential(*sub_layers)
elif model_type == "dual_norm_mlp":
act_cls = super_name2activation[kwargs["act_cls"]]
norm_cls = super_name2norm[kwargs["norm_cls"]]
sub_layers, last_dim = [], kwargs["input_dim"]
for i, hidden_dim in enumerate(kwargs["hidden_dims"]):
if i > 0:
sub_layers.append(norm_cls(last_dim, elementwise_affine=False))
sub_layers.append(SuperLinear(last_dim, hidden_dim))
sub_layers.append(SuperDropout(kwargs["dropout"]))
sub_layers.append(SuperLinear(hidden_dim, hidden_dim))
sub_layers.append(act_cls())
last_dim = hidden_dim
sub_layers.append(SuperLinear(last_dim, kwargs["output_dim"]))
model = SuperSequential(*sub_layers)
else:
raise TypeError("Unkonwn model type: {:}".format(model_type))
return model

View File

@@ -14,6 +14,7 @@ from .super_norm import SuperSimpleNorm
from .super_norm import SuperLayerNorm1D
from .super_norm import SuperSimpleLearnableNorm
from .super_norm import SuperIdentity
from .super_dropout import SuperDropout
super_name2norm = {
"simple_norm": SuperSimpleNorm,

View File

@@ -0,0 +1,40 @@
#####################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2021.03 #
#####################################################
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
from typing import Optional, Callable
import spaces
from .super_module import SuperModule
from .super_module import IntSpaceType
from .super_module import BoolSpaceType
class SuperDropout(SuperModule):
"""Applies a the dropout function element-wise."""
def __init__(self, p: float = 0.5, inplace: bool = False) -> None:
super(SuperDropout, self).__init__()
self._p = p
self._inplace = inplace
@property
def abstract_search_space(self):
return spaces.VirtualNode(id(self))
def forward_candidate(self, input: torch.Tensor) -> torch.Tensor:
return self.forward_raw(input)
def forward_raw(self, input: torch.Tensor) -> torch.Tensor:
return F.dropout(input, self._p, self.training, self._inplace)
def forward_with_container(self, input, container, prefix=[]):
return self.forward_raw(input)
def extra_repr(self) -> str:
xstr = "inplace=True" if self._inplace else ""
return "p={:}".format(self._p) + ", " + xstr

View File

@@ -74,6 +74,19 @@ class SuperLayerNorm1D(SuperModule):
def forward_raw(self, input: torch.Tensor) -> torch.Tensor:
return F.layer_norm(input, (self.in_dim,), self.weight, self.bias, self.eps)
def forward_with_container(self, input, container, prefix=[]):
super_weight_name = ".".join(prefix + ["weight"])
if container.has(super_weight_name):
weight = container.query(super_weight_name)
else:
weight = None
super_bias_name = ".".join(prefix + ["bias"])
if container.has(super_bias_name):
bias = container.query(super_bias_name)
else:
bias = None
return F.layer_norm(input, (self.in_dim,), weight, bias, self.eps)
def extra_repr(self) -> str:
return (
"shape={in_dim}, eps={eps}, elementwise_affine={elementwise_affine}".format(