Add SuperSimpleNorm and update synthetic env

This commit is contained in:
D-X-Y
2021-04-23 02:12:11 -07:00
parent a5b7d986b3
commit 9b895bdf2e
13 changed files with 238 additions and 519 deletions

View File

@@ -91,6 +91,8 @@ class SuperSequential(SuperModule):
def abstract_search_space(self):
root_node = spaces.VirtualNode(id(self))
for index, module in enumerate(self):
if not isinstance(module, SuperModule):
continue
space = module.abstract_search_space
if not spaces.is_determined(space):
root_node.append(str(index), space)
@@ -98,9 +100,9 @@ class SuperSequential(SuperModule):
def apply_candidate(self, abstract_child: spaces.VirtualNode):
super(SuperSequential, self).apply_candidate(abstract_child)
for index in range(len(self)):
for index, module in enumerate(self):
if str(index) in abstract_child:
self.__getitem__(index).apply_candidate(abstract_child[str(index)])
module.apply_candidate(abstract_child[str(index)])
def forward_candidate(self, input):
return self.forward_raw(input)

View File

@@ -9,6 +9,7 @@ from .super_module import SuperModule
from .super_container import SuperSequential
from .super_linear import SuperLinear
from .super_linear import SuperMLPv1, SuperMLPv2
from .super_norm import SuperSimpleNorm
from .super_norm import SuperLayerNorm1D
from .super_attention import SuperAttention
from .super_transformer import SuperTransformerEncoderLayer

View File

@@ -3,6 +3,7 @@
#####################################################
import abc
import warnings
from typing import Optional, Union, Callable
import torch
import torch.nn as nn
@@ -45,6 +46,17 @@ class SuperModule(abc.ABC, nn.Module):
self.apply(_reset_super_run)
def add_module(self, name: str, module: Optional[torch.nn.Module]) -> None:
if not isinstance(module, SuperModule):
warnings.warn(
"Add {:} module, which is not SuperModule, into {:}".format(
name, self.__class__.__name__
)
+ "\n"
+ "It may cause some functions invalid."
)
super(SuperModule, self).add_module(name, module)
def apply_verbose(self, verbose):
def _reset_verbose(m):
if isinstance(m, SuperModule):

View File

@@ -82,3 +82,43 @@ class SuperLayerNorm1D(SuperModule):
elementwise_affine=self._elementwise_affine,
)
)
class SuperSimpleNorm(SuperModule):
"""Super simple normalization."""
def __init__(self, mean, std, inplace=False) -> None:
super(SuperSimpleNorm, self).__init__()
self._mean = mean
self._std = std
self._inplace = inplace
@property
def abstract_search_space(self):
return spaces.VirtualNode(id(self))
def forward_candidate(self, input: torch.Tensor) -> torch.Tensor:
# check inputs ->
return self.forward_raw(input)
def forward_raw(self, input: torch.Tensor) -> torch.Tensor:
if not self._inplace:
tensor = input.clone()
else:
tensor = input
mean = torch.as_tensor(self._mean, dtype=tensor.dtype, device=tensor.device)
std = torch.as_tensor(self._std, dtype=tensor.dtype, device=tensor.device)
if (std == 0).any():
raise ValueError(
"std evaluated to zero after conversion to {}, leading to division by zero.".format(
dtype
)
)
while mean.ndim < tensor.ndim:
mean, std = torch.unsqueeze(mean, dim=0), torch.unsqueeze(std, dim=0)
return tensor.sub_(mean).div_(std)
def extra_repr(self) -> str:
return "mean={mean}, std={mean}, inplace={inplace}".format(
mean=self._mean, std=self._std, inplace=self._inplace
)