Add the SuperMLP class
@@ -3,4 +3,5 @@
 #####################################################
 from .super_module import SuperRunMode
 from .super_module import SuperModule
-from .super_mlp import SuperLinear
+from .super_linear import SuperLinear
+from .super_linear import SuperMLP
@@ -6,7 +6,7 @@ import torch.nn as nn
 import torch.nn.functional as F
 
 import math
-from typing import Optional, Union
+from typing import Optional, Union, Callable
 
 import spaces
 from .super_module import SuperModule
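
Note: the tightened act_layer annotation below reads as "any zero-argument callable that returns an nn.Module", so the class itself (nn.GELU), a functools.partial, or a lambda all qualify. A minimal sketch, independent of this repo:

import torch.nn as nn
from typing import Callable

def build_act(act_layer: Callable[[], nn.Module] = nn.GELU) -> nn.Module:
    # The annotation admits anything callable with no arguments that
    # yields a module: classes, partials, or lambdas.
    return act_layer()

print(build_act())                           # GELU()
print(build_act(nn.ReLU))                    # ReLU()
print(build_act(lambda: nn.LeakyReLU(0.1)))  # LeakyReLU(negative_slope=0.1)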
@@ -57,11 +57,15 @@ class SuperLinear(SuperModule):
     def abstract_search_space(self):
         root_node = spaces.VirtualNode(id(self))
         if not spaces.is_determined(self._in_features):
-            root_node.append("_in_features", self._in_features.abstract())
+            root_node.append(
+                "_in_features", self._in_features.abstract(reuse_last=True)
+            )
         if not spaces.is_determined(self._out_features):
-            root_node.append("_out_features", self._out_features.abstract())
+            root_node.append(
+                "_out_features", self._out_features.abstract(reuse_last=True)
+            )
         if not spaces.is_determined(self._bias):
-            root_node.append("_bias", self._bias.abstract())
+            root_node.append("_bias", self._bias.abstract(reuse_last=True))
         return root_node
 
     def reset_parameters(self) -> None:
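
Note: the diff alone does not show what reuse_last=True means inside the spaces module; a plausible reading is that repeated abstractions of the same space hand back the previously built abstract instead of a fresh one, so every module that references a shared dimension also shares one decision. A toy sketch under that assumption (this Categorical is a hypothetical stand-in, not the repo's spaces API):

import random

class Categorical:
    # Hypothetical stand-in for a searchable dimension.
    def __init__(self, *candidates):
        self.candidates = list(candidates)
        self._last = None

    def abstract(self, reuse_last: bool = False):
        # With reuse_last=True, a second abstraction of the same space
        # returns the cached result, so e.g. fc1's out_features and
        # fc2's in_features cannot drift apart during a search.
        if reuse_last and self._last is not None:
            return self._last
        self._last = {"candidates": self.candidates,
                      "picked": random.choice(self.candidates)}
        return self._last

hidden = Categorical(96, 192, 384)
assert hidden.abstract(reuse_last=True) is hidden.abstract(reuse_last=True)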
@@ -116,24 +120,51 @@ class SuperMLP(SuperModule):
 
     def __init__(
         self,
-        in_features,
-        hidden_features: Optional[int] = None,
-        out_features: Optional[int] = None,
-        act_layer=nn.GELU,
+        in_features: IntSpaceType,
+        hidden_features: IntSpaceType,
+        out_features: IntSpaceType,
+        act_layer: Callable[[], nn.Module] = nn.GELU,
         drop: Optional[float] = None,
     ):
         super(SuperMLP, self).__init__()
-        out_features = out_features or in_features
-        hidden_features = hidden_features or in_features
-        self.fc1 = nn.Linear(in_features, hidden_features)
+        self._in_features = in_features
+        self._hidden_features = hidden_features
+        self._out_features = out_features
+        self._drop_rate = drop
+        self.fc1 = SuperLinear(in_features, hidden_features)
         self.act = act_layer()
-        self.fc2 = nn.Linear(hidden_features, out_features)
+        self.fc2 = SuperLinear(hidden_features, out_features)
         self.drop = nn.Dropout(drop or 0.0)
 
-    def forward(self, x):
+    @property
+    def abstract_search_space(self):
+        root_node = spaces.VirtualNode(id(self))
+        space_fc1 = self.fc1.abstract_search_space
+        space_fc2 = self.fc2.abstract_search_space
+        if not spaces.is_determined(space_fc1):
+            root_node.append("fc1", space_fc1)
+        if not spaces.is_determined(space_fc2):
+            root_node.append("fc2", space_fc2)
+        return root_node
+
+    def forward_candidate(self, input: torch.Tensor) -> torch.Tensor:
+        return self._unified_forward(input)
+
+    def forward_raw(self, input: torch.Tensor) -> torch.Tensor:
+        return self._unified_forward(input)
+
+    def _unified_forward(self, x):
         x = self.fc1(x)
         x = self.act(x)
         x = self.drop(x)
         x = self.fc2(x)
         x = self.drop(x)
         return x
+
+    def extra_repr(self) -> str:
+        return "in_features={:}, hidden_features={:}, out_features={:}, drop={:}, fc1 -> act -> drop -> fc2 -> drop,".format(
+            self._in_features,
+            self._hidden_features,
+            self._out_features,
+            self._drop_rate,
+        )
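
Note: SuperMLP never overrides forward directly; forward_candidate and forward_raw both funnel into _unified_forward, and the base SuperModule presumably dispatches between them depending on the active run mode. A self-contained toy mirroring that pattern (TinySuperModule and TinyMLP are illustrative, not the repo's classes):

import torch
import torch.nn as nn

class TinySuperModule(nn.Module):
    # Hypothetical mirror of the SuperModule dispatch idea.
    def __init__(self):
        super().__init__()
        self._use_candidate = False

    def forward(self, x):
        # Route to the candidate path when a sub-architecture is active,
        # otherwise run the raw (full) model.
        return self.forward_candidate(x) if self._use_candidate else self.forward_raw(x)

class TinyMLP(TinySuperModule):
    def __init__(self, in_f, hid_f, out_f, drop=0.0):
        super().__init__()
        self.fc1 = nn.Linear(in_f, hid_f)
        self.act = nn.GELU()
        self.fc2 = nn.Linear(hid_f, out_f)
        self.drop = nn.Dropout(drop)

    # Both paths share one implementation, as in SuperMLP:
    # fc1 -> act -> drop -> fc2 -> drop.
    def forward_candidate(self, input):
        return self._unified_forward(input)

    def forward_raw(self, input):
        return self._unified_forward(input)

    def _unified_forward(self, x):
        return self.drop(self.fc2(self.drop(self.act(self.fc1(x)))))

mlp = TinyMLP(8, 16, 8, drop=0.1)
print(mlp(torch.randn(2, 8)).shape)  # torch.Size([2, 8])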