Add more algorithms
This commit is contained in:
@@ -3,21 +3,44 @@
|
||||
##################################################
|
||||
# modified from https://github.com/warmspringwinds/pytorch-segmentation-detection/blob/master/pytorch_segmentation_detection/utils/flops_benchmark.py
|
||||
import copy, torch
|
||||
import torch.nn as nn
|
||||
import numpy as np
|
||||
|
||||
def print_FLOPs(model, shape, logs):
|
||||
print_log, log = logs
|
||||
model = copy.deepcopy( model )
|
||||
|
||||
def count_parameters_in_MB(model):
|
||||
if isinstance(model, nn.Module):
|
||||
return np.sum(np.prod(v.size()) for v in model.parameters())/1e6
|
||||
else:
|
||||
return np.sum(np.prod(v.size()) for v in model)/1e6
|
||||
|
||||
|
||||
def get_model_infos(model, shape):
|
||||
#model = copy.deepcopy( model )
|
||||
|
||||
model = add_flops_counting_methods(model)
|
||||
model = model.cuda()
|
||||
#model = model.cuda()
|
||||
model.eval()
|
||||
|
||||
cache_inputs = torch.zeros(*shape).cuda()
|
||||
#cache_inputs = torch.zeros(*shape).cuda()
|
||||
#cache_inputs = torch.zeros(*shape)
|
||||
cache_inputs = torch.rand(*shape)
|
||||
if next(model.parameters()).is_cuda: cache_inputs = cache_inputs.cuda()
|
||||
#print_log('In the calculating function : cache input size : {:}'.format(cache_inputs.size()), log)
|
||||
_ = model(cache_inputs)
|
||||
with torch.no_grad():
|
||||
_____ = model(cache_inputs)
|
||||
FLOPs = compute_average_flops_cost( model ) / 1e6
|
||||
print_log('FLOPs : {:} MB'.format(FLOPs), log)
|
||||
Param = count_parameters_in_MB(model)
|
||||
|
||||
if hasattr(model, 'auxiliary_param'):
|
||||
aux_params = count_parameters_in_MB(model.auxiliary_param())
|
||||
print ('The auxiliary params of this model is : {:}'.format(aux_params))
|
||||
print ('We remove the auxiliary params from the total params ({:}) when counting'.format(Param))
|
||||
Param = Param - aux_params
|
||||
|
||||
#print_log('FLOPs : {:} MB'.format(FLOPs), log)
|
||||
torch.cuda.empty_cache()
|
||||
model.apply( remove_hook_function )
|
||||
return FLOPs, Param
|
||||
|
||||
|
||||
# ---- Public functions
|
||||
@@ -37,8 +60,11 @@ def compute_average_flops_cost(model):
|
||||
"""
|
||||
batches_count = model.__batch_counter__
|
||||
flops_sum = 0
|
||||
#or isinstance(module, torch.nn.AvgPool2d) or isinstance(module, torch.nn.MaxPool2d) \
|
||||
for module in model.modules():
|
||||
if isinstance(module, torch.nn.Conv2d) or isinstance(module, torch.nn.Linear):
|
||||
if isinstance(module, torch.nn.Conv2d) or isinstance(module, torch.nn.Linear) \
|
||||
or isinstance(module, torch.nn.Conv1d) \
|
||||
or hasattr(module, 'calculate_flop_self'):
|
||||
flops_sum += module.__flops__
|
||||
return flops_sum / batches_count
|
||||
|
||||
@@ -54,6 +80,11 @@ def pool_flops_counter_hook(pool_module, inputs, output):
|
||||
pool_module.__flops__ += overall_flops
|
||||
|
||||
|
||||
def self_calculate_flops_counter_hook(self_module, inputs, output):
|
||||
overall_flops = self_module.calculate_flop_self(inputs[0].shape, output.shape)
|
||||
self_module.__flops__ += overall_flops
|
||||
|
||||
|
||||
def fc_flops_counter_hook(fc_module, inputs, output):
|
||||
batch_size = inputs[0].size(0)
|
||||
xin, xout = fc_module.in_features, fc_module.out_features
|
||||
@@ -64,7 +95,24 @@ def fc_flops_counter_hook(fc_module, inputs, output):
|
||||
fc_module.__flops__ += overall_flops
|
||||
|
||||
|
||||
def conv_flops_counter_hook(conv_module, inputs, output):
|
||||
def conv1d_flops_counter_hook(conv_module, inputs, outputs):
|
||||
batch_size = inputs[0].size(0)
|
||||
outL = outputs.shape[-1]
|
||||
[kernel] = conv_module.kernel_size
|
||||
in_channels = conv_module.in_channels
|
||||
out_channels = conv_module.out_channels
|
||||
groups = conv_module.groups
|
||||
conv_per_position_flops = kernel * in_channels * out_channels / groups
|
||||
|
||||
active_elements_count = batch_size * outL
|
||||
overall_flops = conv_per_position_flops * active_elements_count
|
||||
|
||||
if conv_module.bias is not None:
|
||||
overall_flops += out_channels * active_elements_count
|
||||
conv_module.__flops__ += overall_flops
|
||||
|
||||
|
||||
def conv2d_flops_counter_hook(conv_module, inputs, output):
|
||||
batch_size = inputs[0].size(0)
|
||||
output_height, output_width = output.shape[2:]
|
||||
|
||||
@@ -97,14 +145,20 @@ def add_batch_counter_hook_function(module):
|
||||
|
||||
def add_flops_counter_variable_or_reset(module):
|
||||
if isinstance(module, torch.nn.Conv2d) or isinstance(module, torch.nn.Linear) \
|
||||
or isinstance(module, torch.nn.AvgPool2d) or isinstance(module, torch.nn.MaxPool2d):
|
||||
or isinstance(module, torch.nn.Conv1d) \
|
||||
or isinstance(module, torch.nn.AvgPool2d) or isinstance(module, torch.nn.MaxPool2d) \
|
||||
or hasattr(module, 'calculate_flop_self'):
|
||||
module.__flops__ = 0
|
||||
|
||||
|
||||
def add_flops_counter_hook_function(module):
|
||||
if isinstance(module, torch.nn.Conv2d):
|
||||
if not hasattr(module, '__flops_handle__'):
|
||||
handle = module.register_forward_hook(conv_flops_counter_hook)
|
||||
handle = module.register_forward_hook(conv2d_flops_counter_hook)
|
||||
module.__flops_handle__ = handle
|
||||
elif isinstance(module, torch.nn.Conv1d):
|
||||
if not hasattr(module, '__flops_handle__'):
|
||||
handle = module.register_forward_hook(conv1d_flops_counter_hook)
|
||||
module.__flops_handle__ = handle
|
||||
elif isinstance(module, torch.nn.Linear):
|
||||
if not hasattr(module, '__flops_handle__'):
|
||||
@@ -114,3 +168,18 @@ def add_flops_counter_hook_function(module):
|
||||
if not hasattr(module, '__flops_handle__'):
|
||||
handle = module.register_forward_hook(pool_flops_counter_hook)
|
||||
module.__flops_handle__ = handle
|
||||
elif hasattr(module, 'calculate_flop_self'): # self-defined module
|
||||
if not hasattr(module, '__flops_handle__'):
|
||||
handle = module.register_forward_hook(self_calculate_flops_counter_hook)
|
||||
module.__flops_handle__ = handle
|
||||
|
||||
|
||||
def remove_hook_function(module):
|
||||
hookers = ['__batch_counter_handle__', '__flops_handle__']
|
||||
for hooker in hookers:
|
||||
if hasattr(module, hooker):
|
||||
handle = getattr(module, hooker)
|
||||
handle.remove()
|
||||
keys = ['__flops__', '__batch_counter__', '__flops__'] + hookers
|
||||
for ckey in keys:
|
||||
if hasattr(module, ckey): delattr(module, ckey)
|
||||
|
Reference in New Issue
Block a user