Add more algorithms

This commit is contained in:
D-X-Y
2019-09-28 18:24:47 +10:00
parent bfd6b648fd
commit cfb462e463
286 changed files with 10557 additions and 122955 deletions

View File

@@ -0,0 +1,6 @@
##################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019 #
##################################################
from .meter import AverageMeter
from .time_utils import time_for_file, time_string, time_string_short, time_print, convert_size2str, convert_secs2time
from .data_utils import reader_creator

View File

@@ -0,0 +1,64 @@
import random, tarfile
import numpy, six
from six.moves import cPickle as pickle
from PIL import Image, ImageOps
def train_cifar_augmentation(image):
# flip
if random.random() < 0.5: image1 = image.transpose(Image.FLIP_LEFT_RIGHT)
else: image1 = image
# random crop
image2 = ImageOps.expand(image1, border=4, fill=0)
i = random.randint(0, 40 - 32)
j = random.randint(0, 40 - 32)
image3 = image2.crop((j,i,j+32,i+32))
# to numpy
image3 = numpy.array(image3) / 255.0
mean = numpy.array([x / 255 for x in [125.3, 123.0, 113.9]]).reshape(1, 1, 3)
std = numpy.array([x / 255 for x in [63.0, 62.1, 66.7]]).reshape(1, 1, 3)
return (image3 - mean) / std
def valid_cifar_augmentation(image):
image3 = numpy.array(image) / 255.0
mean = numpy.array([x / 255 for x in [125.3, 123.0, 113.9]]).reshape(1, 1, 3)
std = numpy.array([x / 255 for x in [63.0, 62.1, 66.7]]).reshape(1, 1, 3)
return (image3 - mean) / std
def reader_creator(filename, sub_name, is_train, cycle=False):
def read_batch(batch):
data = batch[six.b('data')]
labels = batch.get(
six.b('labels'), batch.get(six.b('fine_labels'), None))
assert labels is not None
for sample, label in six.moves.zip(data, labels):
sample = sample.reshape(3, 32, 32)
sample = sample.transpose((1, 2, 0))
image = Image.fromarray(sample)
if is_train:
ximage = train_cifar_augmentation(image)
else:
ximage = valid_cifar_augmentation(image)
ximage = ximage.transpose((2, 0, 1))
yield ximage.astype(numpy.float32), int(label)
def reader():
with tarfile.open(filename, mode='r') as f:
names = (each_item.name for each_item in f
if sub_name in each_item.name)
while True:
for name in names:
if six.PY2:
batch = pickle.load(f.extractfile(name))
else:
batch = pickle.load(
f.extractfile(name), encoding='bytes')
for item in read_batch(batch):
yield item
if not cycle:
break
return reader

View File

@@ -0,0 +1,26 @@
##################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019 #
##################################################
import time, sys
import numpy as np
class AverageMeter(object):
"""Computes and stores the average and current value"""
def __init__(self):
self.reset()
def reset(self):
self.val = 0.0
self.avg = 0.0
self.sum = 0.0
self.count = 0.0
def update(self, val, n=1):
self.val = val
self.sum += val * n
self.count += n
self.avg = self.sum / self.count
def __repr__(self):
return ('{name}(val={val}, avg={avg}, count={count})'.format(name=self.__class__.__name__, **self.__dict__))

View File

@@ -0,0 +1,52 @@
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
#
import time, sys
import numpy as np
def time_for_file():
ISOTIMEFORMAT='%d-%h-at-%H-%M-%S'
return '{}'.format(time.strftime( ISOTIMEFORMAT, time.gmtime(time.time()) ))
def time_string():
ISOTIMEFORMAT='%Y-%m-%d %X'
string = '[{}]'.format(time.strftime( ISOTIMEFORMAT, time.gmtime(time.time()) ))
return string
def time_string_short():
ISOTIMEFORMAT='%Y%m%d'
string = '{}'.format(time.strftime( ISOTIMEFORMAT, time.gmtime(time.time()) ))
return string
def time_print(string, is_print=True):
if (is_print):
print('{} : {}'.format(time_string(), string))
def convert_size2str(torch_size):
dims = len(torch_size)
string = '['
for idim in range(dims):
string = string + ' {}'.format(torch_size[idim])
return string + ']'
def convert_secs2time(epoch_time, return_str=False):
need_hour = int(epoch_time / 3600)
need_mins = int((epoch_time - 3600*need_hour) / 60)
need_secs = int(epoch_time - 3600*need_hour - 60*need_mins)
if return_str:
str = '[{:02d}:{:02d}:{:02d}]'.format(need_hour, need_mins, need_secs)
return str
else:
return need_hour, need_mins, need_secs
def print_log(print_string, log):
#if isinstance(log, Logger): log.log('{:}'.format(print_string))
if hasattr(log, 'log'): log.log('{:}'.format(print_string))
else:
print("{:}".format(print_string))
if log is not None:
log.write('{:}\n'.format(print_string))
log.flush()