Add the SuperMLP class

D-X-Y
2021-03-19 03:22:58 -07:00
parent 51c626c96d
commit 31b8122cc1
6 changed files with 195 additions and 53 deletions


@@ -3,4 +3,5 @@
 #####################################################
 from .super_module import SuperRunMode
 from .super_module import SuperModule
-from .super_mlp import SuperLinear
+from .super_linear import SuperLinear
+from .super_linear import SuperMLP
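SuperLinear now comes from the renamed super_linear module, and the new SuperMLP is re-exported next to it, so both layers are importable from the package root. A minimal sketch of the resulting import surface (the package name is not shown in this view and is assumed here):

    # Assumed package name; the commit view does not display file paths.
    from super_core import SuperRunMode, SuperModule, SuperLinear, SuperMLP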


@@ -6,7 +6,7 @@ import torch.nn as nn
 import torch.nn.functional as F
 import math
-from typing import Optional, Union
+from typing import Optional, Union, Callable

 import spaces
 from .super_module import SuperModule
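The new Callable import backs the act_layer annotation added later in this diff: Callable[[], nn.Module] accepts anything that, called with no arguments, yields a module, which a class such as nn.GELU satisfies. A small self-contained illustration (make_act is a hypothetical helper, not part of this commit):

    from typing import Callable
    import torch.nn as nn

    def make_act(act_layer: Callable[[], nn.Module] = nn.GELU) -> nn.Module:
        # Passing the class itself satisfies Callable[[], nn.Module];
        # invoking it with no arguments constructs the activation module.
        return act_layer()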
@@ -57,11 +57,15 @@ class SuperLinear(SuperModule):
     def abstract_search_space(self):
         root_node = spaces.VirtualNode(id(self))
         if not spaces.is_determined(self._in_features):
-            root_node.append("_in_features", self._in_features.abstract())
+            root_node.append(
+                "_in_features", self._in_features.abstract(reuse_last=True)
+            )
         if not spaces.is_determined(self._out_features):
-            root_node.append("_out_features", self._out_features.abstract())
+            root_node.append(
+                "_out_features", self._out_features.abstract(reuse_last=True)
+            )
         if not spaces.is_determined(self._bias):
-            root_node.append("_bias", self._bias.abstract())
+            root_node.append("_bias", self._bias.abstract(reuse_last=True))
         return root_node

     def reset_parameters(self) -> None:
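The only behavioral change in this hunk is abstract(reuse_last=True). Presumably the flag makes a space hand back the abstract variable it built on the previous call instead of minting a fresh one, so repeated queries of the same dimension share a single node. A toy sketch of that idea (a hypothetical Space class, not the repo's actual spaces API):

    class Space:
        # Toy stand-in for a searchable dimension; hypothetical,
        # not the actual `spaces` package used by this repo.
        def __init__(self, candidates):
            self._candidates = list(candidates)
            self._last_abstract = None

        def abstract(self, reuse_last=False):
            # With reuse_last=True, return the previously built abstract
            # variable so callers that query twice see one shared node.
            if reuse_last and self._last_abstract is not None:
                return self._last_abstract
            self._last_abstract = (
                "dim@{:x}".format(id(self)),
                tuple(self._candidates),
            )
            return self._last_abstract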
@@ -116,24 +120,51 @@ class SuperMLP(SuperModule):
     def __init__(
         self,
-        in_features,
-        hidden_features: Optional[int] = None,
-        out_features: Optional[int] = None,
-        act_layer=nn.GELU,
+        in_features: IntSpaceType,
+        hidden_features: IntSpaceType,
+        out_features: IntSpaceType,
+        act_layer: Callable[[], nn.Module] = nn.GELU,
         drop: Optional[float] = None,
     ):
         super(SuperMLP, self).__init__()
-        out_features = out_features or in_features
-        hidden_features = hidden_features or in_features
-        self.fc1 = nn.Linear(in_features, hidden_features)
+        self._in_features = in_features
+        self._hidden_features = hidden_features
+        self._out_features = out_features
+        self._drop_rate = drop
+        self.fc1 = SuperLinear(in_features, hidden_features)
         self.act = act_layer()
-        self.fc2 = nn.Linear(hidden_features, out_features)
+        self.fc2 = SuperLinear(hidden_features, out_features)
         self.drop = nn.Dropout(drop or 0.0)

-    def forward(self, x):
+    @property
+    def abstract_search_space(self):
+        root_node = spaces.VirtualNode(id(self))
+        space_fc1 = self.fc1.abstract_search_space
+        space_fc2 = self.fc2.abstract_search_space
+        if not spaces.is_determined(space_fc1):
+            root_node.append("fc1", space_fc1)
+        if not spaces.is_determined(space_fc2):
+            root_node.append("fc2", space_fc2)
+        return root_node
+
+    def forward_candidate(self, input: torch.Tensor) -> torch.Tensor:
+        return self._unified_forward(input)
+
+    def forward_raw(self, input: torch.Tensor) -> torch.Tensor:
+        return self._unified_forward(input)
+
+    def _unified_forward(self, x):
         x = self.fc1(x)
         x = self.act(x)
         x = self.drop(x)
         x = self.fc2(x)
         x = self.drop(x)
         return x
+
+    def extra_repr(self) -> str:
+        return "in_features={:}, hidden_features={:}, out_features={:}, drop={:}, fc1 -> act -> drop -> fc2 -> drop".format(
+            self._in_features,
+            self._hidden_features,
+            self._out_features,
+            self._drop_rate,
+        )
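Taken together, SuperMLP now wires two SuperLinear layers and exposes their joint search space through abstract_search_space, attaching only the children whose dimensions are not yet determined. A hedged usage sketch (it assumes spaces.Categorical exists for searchable widths and that SuperModule.forward dispatches to forward_raw in the default run mode; neither detail is shown in this diff):

    import torch
    import spaces

    # Hidden width is searchable; input/output widths stay fixed integers.
    mlp = SuperMLP(
        in_features=64,
        hidden_features=spaces.Categorical(128, 192, 256),
        out_features=64,
        drop=0.1,
    )

    # Nested VirtualNode: only the non-determined children ("fc1", "fc2")
    # are appended by abstract_search_space.
    print(mlp.abstract_search_space)

    # In the default (raw) run mode the module behaves like a plain MLP.
    out = mlp(torch.randn(2, 16, 64))
    print(out.shape)  # torch.Size([2, 16, 64])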