Add SuperSimpleNorm and update synthetic env
This commit is contained in:
@@ -91,6 +91,8 @@ class SuperSequential(SuperModule):
|
||||
def abstract_search_space(self):
|
||||
root_node = spaces.VirtualNode(id(self))
|
||||
for index, module in enumerate(self):
|
||||
if not isinstance(module, SuperModule):
|
||||
continue
|
||||
space = module.abstract_search_space
|
||||
if not spaces.is_determined(space):
|
||||
root_node.append(str(index), space)
|
||||
@@ -98,9 +100,9 @@ class SuperSequential(SuperModule):
|
||||
|
||||
def apply_candidate(self, abstract_child: spaces.VirtualNode):
|
||||
super(SuperSequential, self).apply_candidate(abstract_child)
|
||||
for index in range(len(self)):
|
||||
for index, module in enumerate(self):
|
||||
if str(index) in abstract_child:
|
||||
self.__getitem__(index).apply_candidate(abstract_child[str(index)])
|
||||
module.apply_candidate(abstract_child[str(index)])
|
||||
|
||||
def forward_candidate(self, input):
|
||||
return self.forward_raw(input)
|
||||
|
@@ -9,6 +9,7 @@ from .super_module import SuperModule
|
||||
from .super_container import SuperSequential
|
||||
from .super_linear import SuperLinear
|
||||
from .super_linear import SuperMLPv1, SuperMLPv2
|
||||
from .super_norm import SuperSimpleNorm
|
||||
from .super_norm import SuperLayerNorm1D
|
||||
from .super_attention import SuperAttention
|
||||
from .super_transformer import SuperTransformerEncoderLayer
|
||||
|
@@ -3,6 +3,7 @@
|
||||
#####################################################
|
||||
|
||||
import abc
|
||||
import warnings
|
||||
from typing import Optional, Union, Callable
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
@@ -45,6 +46,17 @@ class SuperModule(abc.ABC, nn.Module):
|
||||
|
||||
self.apply(_reset_super_run)
|
||||
|
||||
def add_module(self, name: str, module: Optional[torch.nn.Module]) -> None:
|
||||
if not isinstance(module, SuperModule):
|
||||
warnings.warn(
|
||||
"Add {:} module, which is not SuperModule, into {:}".format(
|
||||
name, self.__class__.__name__
|
||||
)
|
||||
+ "\n"
|
||||
+ "It may cause some functions invalid."
|
||||
)
|
||||
super(SuperModule, self).add_module(name, module)
|
||||
|
||||
def apply_verbose(self, verbose):
|
||||
def _reset_verbose(m):
|
||||
if isinstance(m, SuperModule):
|
||||
|
@@ -82,3 +82,43 @@ class SuperLayerNorm1D(SuperModule):
|
||||
elementwise_affine=self._elementwise_affine,
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
class SuperSimpleNorm(SuperModule):
|
||||
"""Super simple normalization."""
|
||||
|
||||
def __init__(self, mean, std, inplace=False) -> None:
|
||||
super(SuperSimpleNorm, self).__init__()
|
||||
self._mean = mean
|
||||
self._std = std
|
||||
self._inplace = inplace
|
||||
|
||||
@property
|
||||
def abstract_search_space(self):
|
||||
return spaces.VirtualNode(id(self))
|
||||
|
||||
def forward_candidate(self, input: torch.Tensor) -> torch.Tensor:
|
||||
# check inputs ->
|
||||
return self.forward_raw(input)
|
||||
|
||||
def forward_raw(self, input: torch.Tensor) -> torch.Tensor:
|
||||
if not self._inplace:
|
||||
tensor = input.clone()
|
||||
else:
|
||||
tensor = input
|
||||
mean = torch.as_tensor(self._mean, dtype=tensor.dtype, device=tensor.device)
|
||||
std = torch.as_tensor(self._std, dtype=tensor.dtype, device=tensor.device)
|
||||
if (std == 0).any():
|
||||
raise ValueError(
|
||||
"std evaluated to zero after conversion to {}, leading to division by zero.".format(
|
||||
dtype
|
||||
)
|
||||
)
|
||||
while mean.ndim < tensor.ndim:
|
||||
mean, std = torch.unsqueeze(mean, dim=0), torch.unsqueeze(std, dim=0)
|
||||
return tensor.sub_(mean).div_(std)
|
||||
|
||||
def extra_repr(self) -> str:
|
||||
return "mean={mean}, std={mean}, inplace={inplace}".format(
|
||||
mean=self._mean, std=self._std, inplace=self._inplace
|
||||
)
|
||||
|
Reference in New Issue
Block a user