Add SuperSimpleNorm and update synthetic env
This commit is contained in:
@@ -82,3 +82,43 @@ class SuperLayerNorm1D(SuperModule):
|
||||
elementwise_affine=self._elementwise_affine,
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
class SuperSimpleNorm(SuperModule):
|
||||
"""Super simple normalization."""
|
||||
|
||||
def __init__(self, mean, std, inplace=False) -> None:
|
||||
super(SuperSimpleNorm, self).__init__()
|
||||
self._mean = mean
|
||||
self._std = std
|
||||
self._inplace = inplace
|
||||
|
||||
@property
|
||||
def abstract_search_space(self):
|
||||
return spaces.VirtualNode(id(self))
|
||||
|
||||
def forward_candidate(self, input: torch.Tensor) -> torch.Tensor:
|
||||
# check inputs ->
|
||||
return self.forward_raw(input)
|
||||
|
||||
def forward_raw(self, input: torch.Tensor) -> torch.Tensor:
|
||||
if not self._inplace:
|
||||
tensor = input.clone()
|
||||
else:
|
||||
tensor = input
|
||||
mean = torch.as_tensor(self._mean, dtype=tensor.dtype, device=tensor.device)
|
||||
std = torch.as_tensor(self._std, dtype=tensor.dtype, device=tensor.device)
|
||||
if (std == 0).any():
|
||||
raise ValueError(
|
||||
"std evaluated to zero after conversion to {}, leading to division by zero.".format(
|
||||
dtype
|
||||
)
|
||||
)
|
||||
while mean.ndim < tensor.ndim:
|
||||
mean, std = torch.unsqueeze(mean, dim=0), torch.unsqueeze(std, dim=0)
|
||||
return tensor.sub_(mean).div_(std)
|
||||
|
||||
def extra_repr(self) -> str:
|
||||
return "mean={mean}, std={mean}, inplace={inplace}".format(
|
||||
mean=self._mean, std=self._std, inplace=self._inplace
|
||||
)
|
||||
|
Reference in New Issue
Block a user