Fix 1-element in norm bug
This commit is contained in:
@@ -40,13 +40,10 @@ def get_model(config: Dict[Text, Any], **kwargs):
|
||||
norm_cls = super_name2norm[kwargs["norm_cls"]]
|
||||
sub_layers, last_dim = [], kwargs["input_dim"]
|
||||
for i, hidden_dim in enumerate(kwargs["hidden_dims"]):
|
||||
sub_layers.extend(
|
||||
[
|
||||
norm_cls(last_dim, elementwise_affine=False),
|
||||
SuperLinear(last_dim, hidden_dim),
|
||||
act_cls(),
|
||||
]
|
||||
)
|
||||
if last_dim > 1:
|
||||
sub_layers.append(norm_cls(last_dim, elementwise_affine=False))
|
||||
sub_layers.append(SuperLinear(last_dim, hidden_dim))
|
||||
sub_layers.append(act_cls())
|
||||
last_dim = hidden_dim
|
||||
sub_layers.append(SuperLinear(last_dim, kwargs["output_dim"]))
|
||||
model = SuperSequential(*sub_layers)
|
||||
|
Reference in New Issue
Block a user