Add unit tests for super-linear
This commit is contained in:
@@ -30,7 +30,7 @@ class SuperLinear(SuperModule):
|
||||
self._in_features = in_features
|
||||
self._out_features = out_features
|
||||
self._bias = bias
|
||||
|
||||
# weights to be optimized
|
||||
self._super_weight = torch.nn.Parameter(
|
||||
torch.Tensor(self.out_features, self.in_features)
|
||||
)
|
||||
@@ -53,7 +53,14 @@ class SuperLinear(SuperModule):
|
||||
return spaces.has_categorical(self._bias, True)
|
||||
|
||||
def abstract_search_space(self):
|
||||
print('-')
|
||||
root_node = spaces.VirtualNode(id(self))
|
||||
if not spaces.is_determined(self._in_features):
|
||||
root_node.append("_in_features", self._in_features)
|
||||
if not spaces.is_determined(self._out_features):
|
||||
root_node.append("_out_features", self._out_features)
|
||||
if not spaces.is_determined(self._bias):
|
||||
root_node.append("_bias", self._bias)
|
||||
return root_node
|
||||
|
||||
def reset_parameters(self) -> None:
|
||||
nn.init.kaiming_uniform_(self._super_weight, a=math.sqrt(5))
|
||||
|
Reference in New Issue
Block a user