Add more algorithms

D-X-Y
2019-09-28 18:24:47 +10:00
parent bfd6b648fd
commit cfb462e463
286 changed files with 10557 additions and 122955 deletions

View File

@@ -1,16 +1,6 @@
##################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019 #
##################################################
from .utils import AverageMeter, RecorderMeter, convert_secs2time
from .utils import time_file_str, time_string
from .utils import test_imagenet_data
from .utils import print_log
from .evaluation_utils import obtain_accuracy
#from .draw_pts import draw_points
from .gpu_manager import GPUManager
from .save_meta import Save_Meta
from .model_utils import count_parameters_in_MB
from .model_utils import Cutout
from .flop_benchmark import print_FLOPs
from .gpu_manager import GPUManager
from .flop_benchmark import get_model_infos

View File

@@ -1,41 +0,0 @@
import os, sys, time
import numpy as np
import matplotlib
import random
matplotlib.use('agg')
import matplotlib.pyplot as plt
import matplotlib.cm as cm
def draw_points(points, labels, save_path):
title = 'the visualized features'
dpi = 100
width, height = 1000, 1000
legend_fontsize = 10
figsize = width / float(dpi), height / float(dpi)
fig = plt.figure(figsize=figsize)
classes = np.unique(labels).tolist()
colors = cm.rainbow(np.linspace(0, 1, len(classes)))
legends = []
legendnames = []
for cls, c in zip(classes, colors):
indexes = labels == cls
ptss = points[indexes, :]
x = ptss[:,0]
y = ptss[:,1]
if cls % 2 == 0: marker = 'x'
else: marker = 'o'
legend = plt.scatter(x, y, color=c, s=1, marker=marker)
legendname = '{:02d}'.format(cls+1)
legends.append( legend )
legendnames.append( legendname )
plt.legend(legends, legendnames, scatterpoints=1, ncol=5, fontsize=8)
if save_path is not None:
fig.savefig(save_path, dpi=dpi, bbox_inches='tight')
print ('---- save figure {} into {}'.format(title, save_path))
plt.close(fig)
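
For illustration (not part of this commit): a minimal sketch of how the removed draw_points helper could be invoked, using synthetic 2-D features and labels; the array sizes and output file name are placeholders.

# Illustrative only: random 2-D points for 10 hypothetical classes.
import numpy as np
points = np.random.randn(200, 2)             # 200 two-dimensional feature vectors
labels = np.random.randint(0, 10, 200)       # integer class labels in [0, 10)
draw_points(points, labels, 'features.png')  # writes the scatter plot to features.png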

View File

@@ -3,21 +3,44 @@
##################################################
# modified from https://github.com/warmspringwinds/pytorch-segmentation-detection/blob/master/pytorch_segmentation_detection/utils/flops_benchmark.py
import copy, torch
import torch.nn as nn
import numpy as np
def print_FLOPs(model, shape, logs):
print_log, log = logs
model = copy.deepcopy( model )
def count_parameters_in_MB(model):
if isinstance(model, nn.Module):
return np.sum(np.prod(v.size()) for v in model.parameters())/1e6
else:
return np.sum(np.prod(v.size()) for v in model)/1e6
def get_model_infos(model, shape):
#model = copy.deepcopy( model )
model = add_flops_counting_methods(model)
model = model.cuda()
#model = model.cuda()
model.eval()
cache_inputs = torch.zeros(*shape).cuda()
#cache_inputs = torch.zeros(*shape).cuda()
#cache_inputs = torch.zeros(*shape)
cache_inputs = torch.rand(*shape)
if next(model.parameters()).is_cuda: cache_inputs = cache_inputs.cuda()
#print_log('In the calculating function : cache input size : {:}'.format(cache_inputs.size()), log)
_ = model(cache_inputs)
with torch.no_grad():
_____ = model(cache_inputs)
FLOPs = compute_average_flops_cost( model ) / 1e6
print_log('FLOPs : {:} MB'.format(FLOPs), log)
Param = count_parameters_in_MB(model)
if hasattr(model, 'auxiliary_param'):
aux_params = count_parameters_in_MB(model.auxiliary_param())
print ('The auxiliary params of this model is : {:}'.format(aux_params))
print ('We remove the auxiliary params from the total params ({:}) when counting'.format(Param))
Param = Param - aux_params
#print_log('FLOPs : {:} MB'.format(FLOPs), log)
torch.cuda.empty_cache()
model.apply( remove_hook_function )
return FLOPs, Param
# ---- Public functions
@@ -37,8 +60,11 @@ def compute_average_flops_cost(model):
"""
batches_count = model.__batch_counter__
flops_sum = 0
#or isinstance(module, torch.nn.AvgPool2d) or isinstance(module, torch.nn.MaxPool2d) \
for module in model.modules():
if isinstance(module, torch.nn.Conv2d) or isinstance(module, torch.nn.Linear):
if isinstance(module, torch.nn.Conv2d) or isinstance(module, torch.nn.Linear) \
or isinstance(module, torch.nn.Conv1d) \
or hasattr(module, 'calculate_flop_self'):
flops_sum += module.__flops__
return flops_sum / batches_count
@@ -54,6 +80,11 @@ def pool_flops_counter_hook(pool_module, inputs, output):
pool_module.__flops__ += overall_flops
def self_calculate_flops_counter_hook(self_module, inputs, output):
overall_flops = self_module.calculate_flop_self(inputs[0].shape, output.shape)
self_module.__flops__ += overall_flops
def fc_flops_counter_hook(fc_module, inputs, output):
batch_size = inputs[0].size(0)
xin, xout = fc_module.in_features, fc_module.out_features
@@ -64,7 +95,24 @@ def fc_flops_counter_hook(fc_module, inputs, output):
fc_module.__flops__ += overall_flops
def conv_flops_counter_hook(conv_module, inputs, output):
def conv1d_flops_counter_hook(conv_module, inputs, outputs):
batch_size = inputs[0].size(0)
outL = outputs.shape[-1]
[kernel] = conv_module.kernel_size
in_channels = conv_module.in_channels
out_channels = conv_module.out_channels
groups = conv_module.groups
conv_per_position_flops = kernel * in_channels * out_channels / groups
active_elements_count = batch_size * outL
overall_flops = conv_per_position_flops * active_elements_count
if conv_module.bias is not None:
overall_flops += out_channels * active_elements_count
conv_module.__flops__ += overall_flops
def conv2d_flops_counter_hook(conv_module, inputs, output):
batch_size = inputs[0].size(0)
output_height, output_width = output.shape[2:]
@@ -97,14 +145,20 @@ def add_batch_counter_hook_function(module):
def add_flops_counter_variable_or_reset(module):
if isinstance(module, torch.nn.Conv2d) or isinstance(module, torch.nn.Linear) \
or isinstance(module, torch.nn.AvgPool2d) or isinstance(module, torch.nn.MaxPool2d):
or isinstance(module, torch.nn.Conv1d) \
or isinstance(module, torch.nn.AvgPool2d) or isinstance(module, torch.nn.MaxPool2d) \
or hasattr(module, 'calculate_flop_self'):
module.__flops__ = 0
def add_flops_counter_hook_function(module):
if isinstance(module, torch.nn.Conv2d):
if not hasattr(module, '__flops_handle__'):
handle = module.register_forward_hook(conv_flops_counter_hook)
handle = module.register_forward_hook(conv2d_flops_counter_hook)
module.__flops_handle__ = handle
elif isinstance(module, torch.nn.Conv1d):
if not hasattr(module, '__flops_handle__'):
handle = module.register_forward_hook(conv1d_flops_counter_hook)
module.__flops_handle__ = handle
elif isinstance(module, torch.nn.Linear):
if not hasattr(module, '__flops_handle__'):
@@ -114,3 +168,18 @@ def add_flops_counter_hook_function(module):
if not hasattr(module, '__flops_handle__'):
handle = module.register_forward_hook(pool_flops_counter_hook)
module.__flops_handle__ = handle
elif hasattr(module, 'calculate_flop_self'): # self-defined module
if not hasattr(module, '__flops_handle__'):
handle = module.register_forward_hook(self_calculate_flops_counter_hook)
module.__flops_handle__ = handle
def remove_hook_function(module):
hookers = ['__batch_counter_handle__', '__flops_handle__']
for hooker in hookers:
if hasattr(module, hooker):
handle = getattr(module, hooker)
handle.remove()
keys = ['__flops__', '__batch_counter__'] + hookers
for ckey in keys:
if hasattr(module, ckey): delattr(module, ckey)
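
A minimal usage sketch of the new get_model_infos API defined above (not from this commit; the torchvision ResNet-18 and the input shape are only illustrative). Both returned values are scaled by 1e6, as in the code above.

# Sketch: report FLOPs and parameter count for an example model on CPU.
import torch
import torchvision.models as models
net = models.resnet18()                                       # any nn.Module works here
flops, params = get_model_infos(net, shape=(1, 3, 224, 224))
print('FLOPs: {:.2f} M, Params: {:.2f} M'.format(flops, params))
# For a single Conv1d, the hook above counts roughly
#   kernel * in_channels * out_channels / groups * batch_size * out_length
# e.g. kernel=3, 64 -> 128 channels, groups=1, out_length=100, batch_size=1
# gives 3 * 64 * 128 * 100 = 2,457,600 position-wise multiply-adds (plus bias adds).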

View File

@@ -1,35 +0,0 @@
import torch
import torch.nn as nn
import numpy as np
def count_parameters_in_MB(model):
if isinstance(model, nn.Module):
return np.sum(np.prod(v.size()) for v in model.parameters())/1e6
else:
return np.sum(np.prod(v.size()) for v in model)/1e6
class Cutout(object):
def __init__(self, length):
self.length = length
def __repr__(self):
return ('{name}(length={length})'.format(name=self.__class__.__name__, **self.__dict__))
def __call__(self, img):
h, w = img.size(1), img.size(2)
mask = np.ones((h, w), np.float32)
y = np.random.randint(h)
x = np.random.randint(w)
y1 = np.clip(y - self.length // 2, 0, h)
y2 = np.clip(y + self.length // 2, 0, h)
x1 = np.clip(x - self.length // 2, 0, w)
x2 = np.clip(x + self.length // 2, 0, w)
mask[y1: y2, x1: x2] = 0.
mask = torch.from_numpy(mask)
mask = mask.expand_as(img)
img *= mask
return img
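
For reference, a minimal sketch of how a Cutout transform like the one above is typically chained after ToTensor in a torchvision pipeline, since __call__ indexes a (C, H, W) tensor; the CIFAR-style pipeline and length=16 are illustrative, not from this repository.

# Sketch: training transform with Cutout applied to the tensor image.
import torchvision.transforms as transforms
train_transform = transforms.Compose([
    transforms.RandomCrop(32, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),              # Cutout expects a (C, H, W) tensor
    Cutout(length=16),                  # masks one random 16x16 square per image
])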

View File

@@ -1,53 +0,0 @@
##################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019 #
##################################################
import torch
import os, sys
import os.path as osp
import numpy as np
def tensor2np(x):
if isinstance(x, np.ndarray): return x
if x.is_cuda: x = x.cpu()
return x.numpy()
class Save_Meta():
def __init__(self):
self.reset()
def __repr__(self):
return ('{name}'.format(name=self.__class__.__name__)+'(number of data = {})'.format(len(self)))
def reset(self):
self.predictions = []
self.groundtruth = []
def __len__(self):
return len(self.predictions)
def append(self, _pred, _ground):
_pred, _ground = tensor2np(_pred), tensor2np(_ground)
assert _ground.shape[0] == _pred.shape[0] and len(_pred.shape) == 2 and len(_ground.shape) == 1, 'The shapes are wrong : {} & {}'.format(_pred.shape, _ground.shape)
self.predictions.append(_pred)
self.groundtruth.append(_ground)
def save(self, save_dir, filename, test=True):
meta = {'predictions': self.predictions,
'groundtruth': self.groundtruth}
filename = osp.join(save_dir, filename)
torch.save(meta, filename)
if test:
predictions = np.concatenate(self.predictions)
groundtruth = np.concatenate(self.groundtruth)
predictions = np.argmax(predictions, axis=1)
accuracy = np.sum(groundtruth==predictions) * 100.0 / predictions.size
else:
accuracy = None
print ('save save_meta into {} with accuracy = {}'.format(filename, accuracy))
def load(self, filename):
assert os.path.isfile(filename), '{} is not a file'.format(filename)
checkpoint = torch.load(filename)
self.predictions = checkpoint['predictions']
self.groundtruth = checkpoint['groundtruth']
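
A hypothetical Save_Meta workflow matching the class above: accumulate logits and labels batch by batch, then dump them; the tensor shapes and file name are illustrative.

# Sketch: collect (N, num_classes) logits with (N,) labels, then save and report accuracy.
import torch
meta = Save_Meta()
for _ in range(3):                       # e.g. three validation batches
    logits = torch.randn(8, 10)          # batch of 8 samples, 10 classes
    labels = torch.randint(0, 10, (8,))
    meta.append(logits, labels)
meta.save('.', 'val-meta.pth', test=True)  # writes ./val-meta.pth and prints accuracy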

View File

@@ -1,140 +0,0 @@
##################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019 #
##################################################
import os, sys, time
import numpy as np
import random
class AverageMeter(object):
"""Computes and stores the average and current value"""
def __init__(self):
self.reset()
def reset(self):
self.val = 0
self.avg = 0
self.sum = 0
self.count = 0
def update(self, val, n=1):
self.val = val
self.sum += val * n
self.count += n
self.avg = self.sum / self.count
class RecorderMeter(object):
"""Computes and stores the minimum loss value and its epoch index"""
def __init__(self, total_epoch):
self.reset(total_epoch)
def reset(self, total_epoch):
assert total_epoch > 0
self.total_epoch = total_epoch
self.current_epoch = 0
self.epoch_losses = np.zeros((self.total_epoch, 2), dtype=np.float32) # [epoch, train/val]
self.epoch_losses = self.epoch_losses - 1
self.epoch_accuracy= np.zeros((self.total_epoch, 2), dtype=np.float32) # [epoch, train/val]
self.epoch_accuracy= self.epoch_accuracy
def update(self, idx, train_loss, train_acc, val_loss, val_acc):
assert idx >= 0 and idx < self.total_epoch, 'total_epoch : {} , but update with the {} index'.format(self.total_epoch, idx)
self.epoch_losses [idx, 0] = train_loss
self.epoch_losses [idx, 1] = val_loss
self.epoch_accuracy[idx, 0] = train_acc
self.epoch_accuracy[idx, 1] = val_acc
self.current_epoch = idx + 1
return self.max_accuracy(False) == self.epoch_accuracy[idx, 1]
def max_accuracy(self, istrain):
if self.current_epoch <= 0: return 0
if istrain: return self.epoch_accuracy[:self.current_epoch, 0].max()
else: return self.epoch_accuracy[:self.current_epoch, 1].max()
def plot_curve(self, save_path):
import matplotlib
matplotlib.use('agg')
import matplotlib.pyplot as plt
title = 'the accuracy/loss curve of train/val'
dpi = 100
width, height = 1600, 1000
legend_fontsize = 10
figsize = width / float(dpi), height / float(dpi)
fig = plt.figure(figsize=figsize)
x_axis = np.array([i for i in range(self.total_epoch)]) # epochs
y_axis = np.zeros(self.total_epoch)
plt.xlim(0, self.total_epoch)
plt.ylim(0, 100)
interval_y = 5
interval_x = 5
plt.xticks(np.arange(0, self.total_epoch + interval_x, interval_x))
plt.yticks(np.arange(0, 100 + interval_y, interval_y))
plt.grid()
plt.title(title, fontsize=20)
plt.xlabel('the training epoch', fontsize=16)
plt.ylabel('accuracy', fontsize=16)
y_axis[:] = self.epoch_accuracy[:, 0]
plt.plot(x_axis, y_axis, color='g', linestyle='-', label='train-accuracy', lw=2)
plt.legend(loc=4, fontsize=legend_fontsize)
y_axis[:] = self.epoch_accuracy[:, 1]
plt.plot(x_axis, y_axis, color='y', linestyle='-', label='valid-accuracy', lw=2)
plt.legend(loc=4, fontsize=legend_fontsize)
y_axis[:] = self.epoch_losses[:, 0]
plt.plot(x_axis, y_axis*50, color='g', linestyle=':', label='train-loss-x50', lw=2)
plt.legend(loc=4, fontsize=legend_fontsize)
y_axis[:] = self.epoch_losses[:, 1]
plt.plot(x_axis, y_axis*50, color='y', linestyle=':', label='valid-loss-x50', lw=2)
plt.legend(loc=4, fontsize=legend_fontsize)
if save_path is not None:
fig.savefig(save_path, dpi=dpi, bbox_inches='tight')
print ('---- save figure {} into {}'.format(title, save_path))
plt.close(fig)
def print_log(print_string, log):
print ("{:}".format(print_string))
if log is not None:
log.write('{}\n'.format(print_string))
log.flush()
def time_file_str():
ISOTIMEFORMAT='%Y-%m-%d'
string = '{}'.format(time.strftime( ISOTIMEFORMAT, time.gmtime(time.time()) ))
return string + '-{}'.format(random.randint(1, 10000))
def time_string():
ISOTIMEFORMAT='%Y-%m-%d-%X'
string = '[{}]'.format(time.strftime( ISOTIMEFORMAT, time.gmtime(time.time()) ))
return string
def convert_secs2time(epoch_time, return_str=False):
need_hour = int(epoch_time / 3600)
need_mins = int((epoch_time - 3600*need_hour) / 60)
need_secs = int(epoch_time - 3600*need_hour - 60*need_mins)
if return_str == False:
return need_hour, need_mins, need_secs
else:
return '[Need: {:02d}:{:02d}:{:02d}]'.format(need_hour, need_mins, need_secs)
def test_imagenet_data(imagenet):
total_length = len(imagenet)
assert total_length == 1281166 or total_length == 50000, 'The length of ImageNet is wrong : {}'.format(total_length)
map_id = {}
for index in range(total_length):
path, target = imagenet.imgs[index]
folder, image_name = os.path.split(path)
_, folder = os.path.split(folder)
if folder not in map_id:
map_id[folder] = target
else:
assert map_id[folder] == target, 'Class : {} is not {}'.format(folder, target)
assert image_name.find(folder) == 0, '{} is wrong.'.format(path)
print ('Check ImageNet Dataset OK')
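
Finally, an illustrative sketch of how these logging helpers are typically combined in a training loop; the epoch count and loss values are placeholders.

# Sketch: track the running loss and report a timestamp plus an ETA per epoch.
import time
losses, epoch_time = AverageMeter(), AverageMeter()
last = time.time()
total_epochs = 5                                         # placeholder
for epoch in range(total_epochs):
    losses.update(1.0 / (epoch + 1), n=128)              # placeholder loss over 128 samples
    epoch_time.update(time.time() - last)
    last = time.time()
    eta = convert_secs2time(epoch_time.avg * (total_epochs - epoch - 1), return_str=True)
    print_log('{:} epoch {:03d} loss {:.3f} {:}'.format(time_string(), epoch, losses.avg, eta), None)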