Add visualize codes for Q
This commit is contained in:
@@ -48,7 +48,11 @@ def get_model_infos(model, shape):
|
||||
if hasattr(model, "auxiliary_param"):
|
||||
aux_params = count_parameters_in_MB(model.auxiliary_param())
|
||||
print("The auxiliary params of this model is : {:}".format(aux_params))
|
||||
print("We remove the auxiliary params from the total params ({:}) when counting".format(Param))
|
||||
print(
|
||||
"We remove the auxiliary params from the total params ({:}) when counting".format(
|
||||
Param
|
||||
)
|
||||
)
|
||||
Param = Param - aux_params
|
||||
|
||||
# print_log('FLOPs : {:} MB'.format(FLOPs), log)
|
||||
@@ -92,7 +96,9 @@ def pool_flops_counter_hook(pool_module, inputs, output):
|
||||
out_C, output_height, output_width = output.shape[1:]
|
||||
assert out_C == inputs[0].size(1), "{:} vs. {:}".format(out_C, inputs[0].size())
|
||||
|
||||
overall_flops = batch_size * out_C * output_height * output_width * kernel_size * kernel_size
|
||||
overall_flops = (
|
||||
batch_size * out_C * output_height * output_width * kernel_size * kernel_size
|
||||
)
|
||||
pool_module.__flops__ += overall_flops
|
||||
|
||||
|
||||
@@ -104,7 +110,9 @@ def self_calculate_flops_counter_hook(self_module, inputs, output):
|
||||
def fc_flops_counter_hook(fc_module, inputs, output):
|
||||
batch_size = inputs[0].size(0)
|
||||
xin, xout = fc_module.in_features, fc_module.out_features
|
||||
assert xin == inputs[0].size(1) and xout == output.size(1), "IO=({:}, {:})".format(xin, xout)
|
||||
assert xin == inputs[0].size(1) and xout == output.size(1), "IO=({:}, {:})".format(
|
||||
xin, xout
|
||||
)
|
||||
overall_flops = batch_size * xin * xout
|
||||
if fc_module.bias is not None:
|
||||
overall_flops += batch_size * xout
|
||||
@@ -136,7 +144,9 @@ def conv2d_flops_counter_hook(conv_module, inputs, output):
|
||||
in_channels = conv_module.in_channels
|
||||
out_channels = conv_module.out_channels
|
||||
groups = conv_module.groups
|
||||
conv_per_position_flops = kernel_height * kernel_width * in_channels * out_channels / groups
|
||||
conv_per_position_flops = (
|
||||
kernel_height * kernel_width * in_channels * out_channels / groups
|
||||
)
|
||||
|
||||
active_elements_count = batch_size * output_height * output_width
|
||||
overall_flops = conv_per_position_flops * active_elements_count
|
||||
@@ -184,7 +194,9 @@ def add_flops_counter_hook_function(module):
|
||||
if not hasattr(module, "__flops_handle__"):
|
||||
handle = module.register_forward_hook(fc_flops_counter_hook)
|
||||
module.__flops_handle__ = handle
|
||||
elif isinstance(module, torch.nn.AvgPool2d) or isinstance(module, torch.nn.MaxPool2d):
|
||||
elif isinstance(module, torch.nn.AvgPool2d) or isinstance(
|
||||
module, torch.nn.MaxPool2d
|
||||
):
|
||||
if not hasattr(module, "__flops_handle__"):
|
||||
handle = module.register_forward_hook(pool_flops_counter_hook)
|
||||
module.__flops_handle__ = handle
|
||||
|
Reference in New Issue
Block a user