Add the SuperMLP class

This commit is contained in:
D-X-Y
2021-03-19 03:22:58 -07:00
parent 51c626c96d
commit 31b8122cc1
6 changed files with 195 additions and 53 deletions

View File

@@ -48,3 +48,29 @@ class TestSuperLinear(unittest.TestCase):
output_shape = (32, abstract_child["_out_features"].value)
outputs = model(inputs)
self.assertEqual(tuple(outputs.shape), output_shape)
def test_super_mlp(self):
hidden_features = spaces.Categorical(12, 24, 36)
out_features = spaces.Categorical(12, 24, 36)
mlp = super_core.SuperMLP(10, hidden_features, out_features)
print(mlp)
self.assertTrue(mlp.fc1._out_features, mlp.fc2._in_features)
abstract_space = mlp.abstract_search_space
print("The abstract search space for SuperMLP is:\n{:}".format(abstract_space))
self.assertEqual(
abstract_space["fc1"]["_out_features"],
abstract_space["fc2"]["_in_features"],
)
self.assertTrue(
abstract_space["fc1"]["_out_features"]
is abstract_space["fc2"]["_in_features"]
)
abstract_space.clean_last_sample()
abstract_child = abstract_space.random(reuse_last=True)
print("The abstract child program is:\n{:}".format(abstract_child))
self.assertEqual(
abstract_child["fc1"]["_out_features"].value,
abstract_child["fc2"]["_in_features"].value,
)