Fix small bugs

This commit is contained in:
D-X-Y
2021-08-14 16:01:07 -07:00
parent 58733c18be
commit d04edcd211
12 changed files with 95 additions and 18 deletions

View File

@@ -14,20 +14,24 @@ def count_parameters(model_or_parameters, unit="mb", deprecated=False):
if isinstance(model_or_parameters, nn.Module):
counts = sum(np.prod(v.size()) for v in model_or_parameters.parameters())
elif isinstance(model_or_parameters, nn.Parameter):
counts = models_or_parameters.numel()
counts = model_or_parameters.numel()
elif isinstance(model_or_parameters, (list, tuple)):
counts = sum(
count_parameters(x, None, deprecated) for x in models_or_parameters
count_parameters(x, None, deprecated) for x in model_or_parameters
)
else:
counts = sum(np.prod(v.size()) for v in model_or_parameters)
if unit.lower() == "kb" or unit.lower() == "k":
if not isinstance(unit, str) and unit is not None:
raise ValueError("Unknow type of unit: {:}".format(unit))
elif unit is None:
counts = counts
elif unit.lower() == "kb" or unit.lower() == "k":
counts /= 1e3 if deprecated else 2 ** 10 # changed from 1e3 to 2^10
elif unit.lower() == "mb" or unit.lower() == "m":
counts /= 1e6 if deprecated else 2 ** 20 # changed from 1e6 to 2^20
elif unit.lower() == "gb" or unit.lower() == "g":
counts /= 1e9 if deprecated else 2 ** 30 # changed from 1e9 to 2^30
elif unit is not None:
else:
raise ValueError("Unknow unit: {:}".format(unit))
return counts