Add visualize codes for Q

This commit is contained in:
D-X-Y
2021-04-11 21:45:20 +08:00
parent e777f38233
commit 0e2dd13762
16 changed files with 570 additions and 125 deletions

View File

@@ -48,7 +48,11 @@ def get_model_infos(model, shape):
if hasattr(model, "auxiliary_param"):
aux_params = count_parameters_in_MB(model.auxiliary_param())
print("The auxiliary params of this model is : {:}".format(aux_params))
print("We remove the auxiliary params from the total params ({:}) when counting".format(Param))
print(
"We remove the auxiliary params from the total params ({:}) when counting".format(
Param
)
)
Param = Param - aux_params
# print_log('FLOPs : {:} MB'.format(FLOPs), log)
@@ -92,7 +96,9 @@ def pool_flops_counter_hook(pool_module, inputs, output):
out_C, output_height, output_width = output.shape[1:]
assert out_C == inputs[0].size(1), "{:} vs. {:}".format(out_C, inputs[0].size())
overall_flops = batch_size * out_C * output_height * output_width * kernel_size * kernel_size
overall_flops = (
batch_size * out_C * output_height * output_width * kernel_size * kernel_size
)
pool_module.__flops__ += overall_flops
@@ -104,7 +110,9 @@ def self_calculate_flops_counter_hook(self_module, inputs, output):
def fc_flops_counter_hook(fc_module, inputs, output):
batch_size = inputs[0].size(0)
xin, xout = fc_module.in_features, fc_module.out_features
assert xin == inputs[0].size(1) and xout == output.size(1), "IO=({:}, {:})".format(xin, xout)
assert xin == inputs[0].size(1) and xout == output.size(1), "IO=({:}, {:})".format(
xin, xout
)
overall_flops = batch_size * xin * xout
if fc_module.bias is not None:
overall_flops += batch_size * xout
@@ -136,7 +144,9 @@ def conv2d_flops_counter_hook(conv_module, inputs, output):
in_channels = conv_module.in_channels
out_channels = conv_module.out_channels
groups = conv_module.groups
conv_per_position_flops = kernel_height * kernel_width * in_channels * out_channels / groups
conv_per_position_flops = (
kernel_height * kernel_width * in_channels * out_channels / groups
)
active_elements_count = batch_size * output_height * output_width
overall_flops = conv_per_position_flops * active_elements_count
@@ -184,7 +194,9 @@ def add_flops_counter_hook_function(module):
if not hasattr(module, "__flops_handle__"):
handle = module.register_forward_hook(fc_flops_counter_hook)
module.__flops_handle__ = handle
elif isinstance(module, torch.nn.AvgPool2d) or isinstance(module, torch.nn.MaxPool2d):
elif isinstance(module, torch.nn.AvgPool2d) or isinstance(
module, torch.nn.MaxPool2d
):
if not hasattr(module, "__flops_handle__"):
handle = module.register_forward_hook(pool_flops_counter_hook)
module.__flops_handle__ = handle