Add unit tests for super-linear

This commit is contained in:
D-X-Y
2021-03-18 20:44:22 +08:00
parent badb6cf51d
commit ca22d61259
7 changed files with 57 additions and 17 deletions

View File

@@ -30,7 +30,7 @@ class SuperLinear(SuperModule):
self._in_features = in_features
self._out_features = out_features
self._bias = bias
# weights to be optimized
self._super_weight = torch.nn.Parameter(
torch.Tensor(self.out_features, self.in_features)
)
@@ -53,7 +53,14 @@ class SuperLinear(SuperModule):
return spaces.has_categorical(self._bias, True)
def abstract_search_space(self):
print('-')
root_node = spaces.VirtualNode(id(self))
if not spaces.is_determined(self._in_features):
root_node.append("_in_features", self._in_features)
if not spaces.is_determined(self._out_features):
root_node.append("_out_features", self._out_features)
if not spaces.is_determined(self._bias):
root_node.append("_bias", self._bias)
return root_node
def reset_parameters(self) -> None:
nn.init.kaiming_uniform_(self._super_weight, a=math.sqrt(5))