This commit is contained in:
D-X-Y
2021-03-24 05:33:52 -07:00
parent 379b904203
commit 15dda79e3b
6 changed files with 60 additions and 58 deletions

View File

@@ -53,11 +53,13 @@ class TestSuperAttention(unittest.TestCase):
@parameterized.expand([[6], [12], [24], [48]])
def test_transformer_encoder(self, input_dim):
output_dim = spaces.Categorical(12, 24, 36)
model = super_core.SuperTransformerEncoderLayer(
input_dim,
output_dim=output_dim,
num_heads=spaces.Categorical(2, 4, 6),
mlp_hidden_multiplier=spaces.Categorical(1, 2, 4),
model = super_core.SuperSequential(
super_core.SuperLinear(input_dim, output_dim),
super_core.SuperTransformerEncoderLayer(
output_dim,
num_heads=spaces.Categorical(2, 4, 6),
mlp_hidden_multiplier=spaces.Categorical(1, 2, 4),
),
)
print(model)
model.apply_verbose(True)