Add super/norm layers in xcore

D-X-Y
2021-05-06 16:38:58 +08:00
parent ff5e544240
commit 4c14c7b85b
6 changed files with 392 additions and 13 deletions

View File

@@ -10,21 +10,26 @@ __all__ = ["get_model"]
 from xlayers.super_core import SuperSequential
-from xlayers.super_core import SuperSimpleNorm
-from xlayers.super_core import SuperLeakyReLU
 from xlayers.super_core import SuperLinear
+from xlayers.super_core import super_name2norm
+from xlayers.super_core import super_name2activation

 def get_model(config: Dict[Text, Any], **kwargs):
     model_type = config.get("model_type", "simple_mlp")
     if model_type == "simple_mlp":
+        act_cls = super_name2activation[kwargs["act_cls"]]
+        norm_cls = super_name2norm[kwargs["norm_cls"]]
+        mean, std = kwargs.get("mean", None), kwargs.get("std", None)
+        hidden_dim1 = kwargs.get("hidden_dim1", 200)
+        hidden_dim2 = kwargs.get("hidden_dim2", 100)
         model = SuperSequential(
-            SuperSimpleNorm(kwargs["mean"], kwargs["std"]),
-            SuperLinear(kwargs["input_dim"], 200),
-            SuperLeakyReLU(),
-            SuperLinear(200, 100),
-            SuperLeakyReLU(),
-            SuperLinear(100, kwargs["output_dim"]),
+            norm_cls(mean=mean, std=std),
+            SuperLinear(kwargs["input_dim"], hidden_dim1),
+            act_cls(),
+            SuperLinear(hidden_dim1, hidden_dim2),
+            act_cls(),
+            SuperLinear(hidden_dim2, kwargs["output_dim"]),
         )
     else:
raise TypeError("Unkonwn model type: {:}".format(model_type))

View File

@@ -9,13 +9,27 @@ from .super_module import SuperModule
 from .super_container import SuperSequential
 from .super_linear import SuperLinear
 from .super_linear import SuperMLPv1, SuperMLPv2
 from .super_norm import SuperSimpleNorm
+from .super_norm import SuperLayerNorm1D
+from .super_norm import SuperSimpleLearnableNorm
+from .super_norm import SuperIdentity
+
+super_name2norm = {
+    "simple_norm": SuperSimpleNorm,
+    "simple_learn_norm": SuperSimpleLearnableNorm,
+    "layer_norm_1d": SuperLayerNorm1D,
+    "identity": SuperIdentity,
+}
 from .super_attention import SuperAttention
 from .super_transformer import SuperTransformerEncoderLayer
 from .super_activations import SuperReLU
 from .super_activations import SuperLeakyReLU
+super_name2activation = {"relu": SuperReLU, "leaky_relu": SuperLeakyReLU}
 from .super_trade_stem import SuperAlphaEBDv1
 from .super_positional_embedding import SuperPositionalEncoder
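For context, a small sketch of how these registries might resolve layer classes from config strings; the keys are the ones defined above, and use outside of get_model is an assumption.

from xlayers.super_core import super_name2norm, super_name2activation

norm_cls = super_name2norm["simple_learn_norm"]   # -> SuperSimpleLearnableNorm
act_cls = super_name2activation["relu"]           # -> SuperReLU
norm_layer = norm_cls(mean=0.0, std=1.0)
act_layer = act_cls()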

View File

@@ -30,6 +30,45 @@ class SuperRunMode(Enum):
     Default = "fullmodel"

+class TensorContainer:
+    """A class to maintain both parameters and buffers for a model."""
+
+    def __init__(self):
+        self._names = []
+        self._tensors = []
+        self._param_or_buffers = []
+        self._name2index = dict()
+
+    def append(self, name, tensor, param_or_buffer):
+        if not isinstance(tensor, torch.Tensor):
+            raise TypeError(
+                "The input tensor must be torch.Tensor instead of {:}".format(
+                    type(tensor)
+                )
+            )
+        self._names.append(name)
+        self._tensors.append(tensor)
+        self._param_or_buffers.append(param_or_buffer)
+        assert name not in self._name2index, "The [{:}] has already been added.".format(
+            name
+        )
+        self._name2index[name] = len(self._names) - 1
+
+    def numel(self):
+        total = 0
+        for tensor in self._tensors:
+            total += tensor.numel()
+        return total
+
+    def __len__(self):
+        return len(self._names)
+
+    def __repr__(self):
+        return "{name}({num} tensors)".format(
+            name=self.__class__.__name__, num=len(self)
+        )

 class SuperModule(abc.ABC, nn.Module):
     """This class equips the nn.Module class with the ability to apply AutoDL."""
@@ -71,6 +110,14 @@ class SuperModule(abc.ABC, nn.Module):
             )
         self._abstract_child = abstract_child

+    def named_parameters_buffers(self):
+        container = TensorContainer()
+        for name, param in self.named_parameters():
+            container.append(name, param, True)
+        for name, buf in self.named_buffers():
+            container.append(name, buf, False)
+        return container
+
     @property
     def abstract_search_space(self):
         raise NotImplementedError
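A brief sketch of the new helper on a concrete SuperModule; SuperLinear is only an example here and is assumed to register a weight and a bias parameter like nn.Linear.

from xlayers.super_core import SuperLinear

layer = SuperLinear(16, 8)
container = layer.named_parameters_buffers()
print(container)          # e.g. "TensorContainer(2 tensors)"
print(len(container))     # number of tracked tensors
print(container.numel())  # total element count, 16 * 8 + 8 = 136 in this case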

View File

@@ -89,8 +89,8 @@ class SuperSimpleNorm(SuperModule):
     def __init__(self, mean, std, inplace=False) -> None:
         super(SuperSimpleNorm, self).__init__()
-        self._mean = mean
-        self._std = std
+        self.register_buffer("_mean", torch.tensor(mean, dtype=torch.float))
+        self.register_buffer("_std", torch.tensor(std, dtype=torch.float))
         self._inplace = inplace

     @property
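Registering _mean and _std as buffers (instead of plain attributes) makes them part of the module's state dict and lets them follow device moves. A quick sketch of that effect, assuming SuperSimpleNorm is exported from xlayers.super_core:

import torch
from xlayers.super_core import SuperSimpleNorm

norm = SuperSimpleNorm(mean=0.5, std=2.0)
print(sorted(norm.state_dict().keys()))  # ['_mean', '_std']
if torch.cuda.is_available():
    norm = norm.to("cuda:0")             # buffers move together with the module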
@@ -111,7 +111,7 @@ class SuperSimpleNorm(SuperModule):
         if (std == 0).any():
             raise ValueError(
                 "std evaluated to zero after conversion to {}, leading to division by zero.".format(
-                    dtype
+                    tensor.dtype
                 )
             )
         while mean.ndim < tensor.ndim:
@@ -119,6 +119,75 @@ class SuperSimpleNorm(SuperModule):
         return tensor.sub_(mean).div_(std)

     def extra_repr(self) -> str:
-        return "mean={mean}, std={mean}, inplace={inplace}".format(
-            mean=self._mean, std=self._std, inplace=self._inplace
+        return "mean={mean}, std={std}, inplace={inplace}".format(
+            mean=self._mean.item(), std=self._std.item(), inplace=self._inplace
         )

+class SuperSimpleLearnableNorm(SuperModule):
+    """Super simple normalization."""
+
+    def __init__(self, mean=0, std=1, eps=1e-6, inplace=False) -> None:
+        super(SuperSimpleLearnableNorm, self).__init__()
+        self.register_parameter(
+            "_mean", nn.Parameter(torch.tensor(mean, dtype=torch.float))
+        )
+        self.register_parameter(
+            "_std", nn.Parameter(torch.tensor(std, dtype=torch.float))
+        )
+        self._eps = eps
+        self._inplace = inplace
+
+    @property
+    def abstract_search_space(self):
+        return spaces.VirtualNode(id(self))
+
+    def forward_candidate(self, input: torch.Tensor) -> torch.Tensor:
+        # check inputs ->
+        return self.forward_raw(input)
+
+    def forward_raw(self, input: torch.Tensor) -> torch.Tensor:
+        if not self._inplace:
+            tensor = input.clone()
+        else:
+            tensor = input
+        mean, std = (
+            self._mean.to(tensor.device),
+            torch.abs(self._std.to(tensor.device)) + self._eps,
+        )
+        if (std == 0).any():
+            raise ValueError("std leads to division by zero.")
+        while mean.ndim < tensor.ndim:
+            mean, std = torch.unsqueeze(mean, dim=0), torch.unsqueeze(std, dim=0)
+        return tensor.sub_(mean).div_(std)
+
+    def extra_repr(self) -> str:
+        return "mean={mean}, std={std}, inplace={inplace}".format(
+            mean=self._mean.item(), std=self._std.item(), inplace=self._inplace
+        )

+class SuperIdentity(SuperModule):
+    """Super identity mapping layer."""
+
+    def __init__(self, inplace=False, **kwargs) -> None:
+        super(SuperIdentity, self).__init__()
+        self._inplace = inplace
+
+    @property
+    def abstract_search_space(self):
+        return spaces.VirtualNode(id(self))
+
+    def forward_candidate(self, input: torch.Tensor) -> torch.Tensor:
+        # check inputs ->
+        return self.forward_raw(input)
+
+    def forward_raw(self, input: torch.Tensor) -> torch.Tensor:
+        if not self._inplace:
+            tensor = input.clone()
+        else:
+            tensor = input
+        return tensor
+
+    def extra_repr(self) -> str:
+        return "inplace={inplace}".format(inplace=self._inplace)