Update xlayers

This commit is contained in:
D-X-Y
2021-05-07 10:26:35 +08:00
parent f6a024a6ff
commit 80aaac4dfa
9 changed files with 333 additions and 83 deletions

View File

@@ -31,6 +31,9 @@ class SuperReLU(SuperModule):
def forward_raw(self, input: torch.Tensor) -> torch.Tensor:
return F.relu(input, inplace=self._inplace)
def forward_with_container(self, input, container, prefix=[]):
return self.forward_raw(input)
def extra_repr(self) -> str:
return "inplace=True" if self._inplace else ""
@@ -53,6 +56,29 @@ class SuperLeakyReLU(SuperModule):
def forward_raw(self, input: torch.Tensor) -> torch.Tensor:
return F.leaky_relu(input, self._negative_slope, self._inplace)
def forward_with_container(self, input, container, prefix=[]):
return self.forward_raw(input)
def extra_repr(self) -> str:
inplace_str = "inplace=True" if self._inplace else ""
return "negative_slope={}{}".format(self._negative_slope, inplace_str)
class SuperTanh(SuperModule):
"""Applies a the Tanh function element-wise."""
def __init__(self) -> None:
super(SuperTanh, self).__init__()
@property
def abstract_search_space(self):
return spaces.VirtualNode(id(self))
def forward_candidate(self, input: torch.Tensor) -> torch.Tensor:
return self.forward_raw(input)
def forward_raw(self, input: torch.Tensor) -> torch.Tensor:
return torch.tanh(input)
def forward_with_container(self, input, container, prefix=[]):
return self.forward_raw(input)