Add the SuperMLP class

This commit is contained in:
D-X-Y
2021-03-19 03:22:58 -07:00
parent 51c626c96d
commit 31b8122cc1
6 changed files with 195 additions and 53 deletions

View File

@@ -48,7 +48,7 @@ class TestBasicSpace(unittest.TestCase):
space = Continuous(lower, upper, log=False)
values = []
for i in range(1000000):
x = space.random().value
x = space.random(reuse_last=False).value
self.assertGreaterEqual(x, lower)
self.assertGreaterEqual(upper, x)
values.append(x)
@@ -97,6 +97,12 @@ class TestBasicSpace(unittest.TestCase):
self.assertTrue(is_determined(1))
self.assertFalse(is_determined(nested_space))
def test_duplicate(self):
space = Categorical(1, 2, 3, 4)
x = space.random()
for _ in range(100):
self.assertEqual(x, space.random(reuse_last=True))
class TestAbstractSpace(unittest.TestCase):
"""Test the abstract search spaces."""

View File

@@ -48,3 +48,29 @@ class TestSuperLinear(unittest.TestCase):
output_shape = (32, abstract_child["_out_features"].value)
outputs = model(inputs)
self.assertEqual(tuple(outputs.shape), output_shape)
def test_super_mlp(self):
hidden_features = spaces.Categorical(12, 24, 36)
out_features = spaces.Categorical(12, 24, 36)
mlp = super_core.SuperMLP(10, hidden_features, out_features)
print(mlp)
self.assertTrue(mlp.fc1._out_features, mlp.fc2._in_features)
abstract_space = mlp.abstract_search_space
print("The abstract search space for SuperMLP is:\n{:}".format(abstract_space))
self.assertEqual(
abstract_space["fc1"]["_out_features"],
abstract_space["fc2"]["_in_features"],
)
self.assertTrue(
abstract_space["fc1"]["_out_features"]
is abstract_space["fc2"]["_in_features"]
)
abstract_space.clean_last_sample()
abstract_child = abstract_space.random(reuse_last=True)
print("The abstract child program is:\n{:}".format(abstract_child))
self.assertEqual(
abstract_child["fc1"]["_out_features"].value,
abstract_child["fc2"]["_in_features"].value,
)