This commit is contained in:
D-X-Y
2021-05-26 01:53:44 -07:00
parent 30fb8fad67
commit 299c8a085b
12 changed files with 137 additions and 115 deletions

View File

@@ -39,9 +39,9 @@ def get_model(config: Dict[Text, Any], **kwargs):
norm_cls = super_name2norm[kwargs["norm_cls"]]
sub_layers, last_dim = [], kwargs["input_dim"]
for i, hidden_dim in enumerate(kwargs["hidden_dims"]):
sub_layers.append(SuperLinear(last_dim, hidden_dim))
if hidden_dim > 1:
sub_layers.append(norm_cls(hidden_dim, elementwise_affine=False))
sub_layers.append(SuperLinear(last_dim, hidden_dim))
sub_layers.append(act_cls())
last_dim = hidden_dim
sub_layers.append(SuperLinear(last_dim, kwargs["output_dim"]))