Refine lib -> xautodl

This commit is contained in:
D-X-Y
2021-05-19 07:23:50 +00:00
parent 5b9a028e60
commit bd407ac4dc
3 changed files with 4 additions and 4 deletions

View File

@@ -10,9 +10,9 @@ def count_parameters_in_MB(model):
def count_parameters(model_or_parameters, unit="mb"):
if isinstance(model_or_parameters, nn.Module):
counts = sum(np.prod(v.size()) for v in model_or_parameters.parameters())
elif isinstance(models_or_parameters, nn.Parameter):
elif isinstance(model_or_parameters, nn.Parameter):
counts = models_or_parameters.numel()
elif isinstance(models_or_parameters, (list, tuple)):
elif isinstance(model_or_parameters, (list, tuple)):
counts = sum(count_parameters(x, None) for x in models_or_parameters)
else:
counts = sum(np.prod(v.size()) for v in model_or_parameters)