Refine lib -> xautodl
This commit is contained in:
@@ -10,9 +10,9 @@ def count_parameters_in_MB(model):
|
||||
def count_parameters(model_or_parameters, unit="mb"):
|
||||
if isinstance(model_or_parameters, nn.Module):
|
||||
counts = sum(np.prod(v.size()) for v in model_or_parameters.parameters())
|
||||
elif isinstance(models_or_parameters, nn.Parameter):
|
||||
elif isinstance(model_or_parameters, nn.Parameter):
|
||||
counts = models_or_parameters.numel()
|
||||
elif isinstance(models_or_parameters, (list, tuple)):
|
||||
elif isinstance(model_or_parameters, (list, tuple)):
|
||||
counts = sum(count_parameters(x, None) for x in models_or_parameters)
|
||||
else:
|
||||
counts = sum(np.prod(v.size()) for v in model_or_parameters)
|
||||
|
Reference in New Issue
Block a user