Clean unnecessary files
This commit is contained in:
@@ -1,6 +0,0 @@
|
||||
##################################################
|
||||
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019 #
|
||||
##################################################
|
||||
from .meter import AverageMeter
|
||||
from .time_utils import time_for_file, time_string, time_string_short, time_print, convert_size2str, convert_secs2time
|
||||
from .data_utils import reader_creator
|
@@ -1,64 +0,0 @@
|
||||
import random, tarfile
|
||||
import numpy, six
|
||||
from six.moves import cPickle as pickle
|
||||
from PIL import Image, ImageOps
|
||||
|
||||
|
||||
def train_cifar_augmentation(image):
|
||||
# flip
|
||||
if random.random() < 0.5: image1 = image.transpose(Image.FLIP_LEFT_RIGHT)
|
||||
else: image1 = image
|
||||
# random crop
|
||||
image2 = ImageOps.expand(image1, border=4, fill=0)
|
||||
i = random.randint(0, 40 - 32)
|
||||
j = random.randint(0, 40 - 32)
|
||||
image3 = image2.crop((j,i,j+32,i+32))
|
||||
# to numpy
|
||||
image3 = numpy.array(image3) / 255.0
|
||||
mean = numpy.array([x / 255 for x in [125.3, 123.0, 113.9]]).reshape(1, 1, 3)
|
||||
std = numpy.array([x / 255 for x in [63.0, 62.1, 66.7]]).reshape(1, 1, 3)
|
||||
return (image3 - mean) / std
|
||||
|
||||
|
||||
def valid_cifar_augmentation(image):
|
||||
image3 = numpy.array(image) / 255.0
|
||||
mean = numpy.array([x / 255 for x in [125.3, 123.0, 113.9]]).reshape(1, 1, 3)
|
||||
std = numpy.array([x / 255 for x in [63.0, 62.1, 66.7]]).reshape(1, 1, 3)
|
||||
return (image3 - mean) / std
|
||||
|
||||
|
||||
def reader_creator(filename, sub_name, is_train, cycle=False):
|
||||
def read_batch(batch):
|
||||
data = batch[six.b('data')]
|
||||
labels = batch.get(
|
||||
six.b('labels'), batch.get(six.b('fine_labels'), None))
|
||||
assert labels is not None
|
||||
for sample, label in six.moves.zip(data, labels):
|
||||
sample = sample.reshape(3, 32, 32)
|
||||
sample = sample.transpose((1, 2, 0))
|
||||
image = Image.fromarray(sample)
|
||||
if is_train:
|
||||
ximage = train_cifar_augmentation(image)
|
||||
else:
|
||||
ximage = valid_cifar_augmentation(image)
|
||||
ximage = ximage.transpose((2, 0, 1))
|
||||
yield ximage.astype(numpy.float32), int(label)
|
||||
|
||||
def reader():
|
||||
with tarfile.open(filename, mode='r') as f:
|
||||
names = (each_item.name for each_item in f
|
||||
if sub_name in each_item.name)
|
||||
|
||||
while True:
|
||||
for name in names:
|
||||
if six.PY2:
|
||||
batch = pickle.load(f.extractfile(name))
|
||||
else:
|
||||
batch = pickle.load(
|
||||
f.extractfile(name), encoding='bytes')
|
||||
for item in read_batch(batch):
|
||||
yield item
|
||||
if not cycle:
|
||||
break
|
||||
|
||||
return reader
|
@@ -1,23 +0,0 @@
|
||||
import time, sys
|
||||
import numpy as np
|
||||
|
||||
|
||||
class AverageMeter(object):
|
||||
"""Computes and stores the average and current value"""
|
||||
def __init__(self):
|
||||
self.reset()
|
||||
|
||||
def reset(self):
|
||||
self.val = 0.0
|
||||
self.avg = 0.0
|
||||
self.sum = 0.0
|
||||
self.count = 0.0
|
||||
|
||||
def update(self, val, n=1):
|
||||
self.val = val
|
||||
self.sum += val * n
|
||||
self.count += n
|
||||
self.avg = self.sum / self.count
|
||||
|
||||
def __repr__(self):
|
||||
return ('{name}(val={val}, avg={avg}, count={count})'.format(name=self.__class__.__name__, **self.__dict__))
|
@@ -1,46 +0,0 @@
|
||||
import time, sys
|
||||
import numpy as np
|
||||
|
||||
def time_for_file():
|
||||
ISOTIMEFORMAT='%d-%h-at-%H-%M-%S'
|
||||
return '{}'.format(time.strftime( ISOTIMEFORMAT, time.gmtime(time.time()) ))
|
||||
|
||||
def time_string():
|
||||
ISOTIMEFORMAT='%Y-%m-%d %X'
|
||||
string = '[{}]'.format(time.strftime( ISOTIMEFORMAT, time.gmtime(time.time()) ))
|
||||
return string
|
||||
|
||||
def time_string_short():
|
||||
ISOTIMEFORMAT='%Y%m%d'
|
||||
string = '{}'.format(time.strftime( ISOTIMEFORMAT, time.gmtime(time.time()) ))
|
||||
return string
|
||||
|
||||
def time_print(string, is_print=True):
|
||||
if (is_print):
|
||||
print('{} : {}'.format(time_string(), string))
|
||||
|
||||
def convert_size2str(torch_size):
|
||||
dims = len(torch_size)
|
||||
string = '['
|
||||
for idim in range(dims):
|
||||
string = string + ' {}'.format(torch_size[idim])
|
||||
return string + ']'
|
||||
|
||||
def convert_secs2time(epoch_time, return_str=False):
|
||||
need_hour = int(epoch_time / 3600)
|
||||
need_mins = int((epoch_time - 3600*need_hour) / 60)
|
||||
need_secs = int(epoch_time - 3600*need_hour - 60*need_mins)
|
||||
if return_str:
|
||||
str = '[{:02d}:{:02d}:{:02d}]'.format(need_hour, need_mins, need_secs)
|
||||
return str
|
||||
else:
|
||||
return need_hour, need_mins, need_secs
|
||||
|
||||
def print_log(print_string, log):
|
||||
#if isinstance(log, Logger): log.log('{:}'.format(print_string))
|
||||
if hasattr(log, 'log'): log.log('{:}'.format(print_string))
|
||||
else:
|
||||
print("{:}".format(print_string))
|
||||
if log is not None:
|
||||
log.write('{:}\n'.format(print_string))
|
||||
log.flush()
|
Reference in New Issue
Block a user