Skip to content

Commit

Permalink
Refactor ImageDataGenerator, add directory support
Browse files Browse the repository at this point in the history
  • Loading branch information
fchollet committed Jun 5, 2016
1 parent e5b99c7 commit 0e18e34
Showing 1 changed file with 230 additions and 78 deletions.
308 changes: 230 additions & 78 deletions keras/preprocessing/image.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
new preprocessing methods, etc...
'''
from __future__ import absolute_import
from __future__ import print_function

import numpy as np
import re
Expand Down Expand Up @@ -135,7 +136,6 @@ def array_to_img(x, dim_ordering=K.image_dim_ordering(), scale=True):
raise Exception('Unsupported channel number: ', x.shape[2])


# only used by tests/keras/preprocessing/test_image.py to convert PIL.Image to numpy array
def img_to_array(img, dim_ordering=K.image_dim_ordering()):
if dim_ordering not in ['th', 'tf']:
raise Exception('Unknown dim_ordering: ', dim_ordering)
Expand All @@ -154,13 +154,15 @@ def img_to_array(img, dim_ordering=K.image_dim_ordering()):
return x


def load_img(path, grayscale=False):
def load_img(path, grayscale=False, target_size=None):
from PIL import Image
img = Image.open(path)
if grayscale:
img = img.convert('L')
else: # Ensure 3 channel even when loaded image is grayscale
img = img.convert('RGB')
if target_size:
img = img.resize(target_size)
return img


Expand Down Expand Up @@ -216,23 +218,24 @@ def __init__(self,
cval=0.,
horizontal_flip=False,
vertical_flip=False,
rescale=None,
dim_ordering=K.image_dim_ordering()):
self.__dict__.update(locals())
self.mean = None
self.std = None
self.principal_components = None
self.lock = threading.Lock()
self.rescale = rescale

if dim_ordering not in {'tf', 'th'}:
raise Exception('dim_ordering should be "tf" (channel after row and '
'column) or "th" (channel before row and column). '
'Received arg: ', dim_ordering)
self.dim_ordering = dim_ordering
if dim_ordering == "th":
if dim_ordering == 'th':
self.channel_index = 1
self.row_index = 2
self.col_index = 3
if dim_ordering == "tf":
if dim_ordering == 'tf':
self.channel_index = 3
self.row_index = 1
self.col_index = 2
Expand All @@ -246,82 +249,30 @@ def __init__(self,
'a tuple or list of two floats. '
'Received arg: ', zoom_range)

self.batch_index = 0
self.total_batches_seen = 0

def reset(self):
self.batch_index = 0

def _flow_index(self, N, batch_size=32, shuffle=False, seed=None):
# ensure self.batch_index is 0
self.reset()

while 1:
if self.batch_index == 0:
index_array = np.arange(N)
if shuffle:
if seed is not None:
np.random.seed(seed + self.total_batches_seen)
index_array = np.random.permutation(N)

current_index = (self.batch_index * batch_size) % N
if N >= current_index + batch_size:
current_batch_size = batch_size
self.batch_index += 1
else:
current_batch_size = N - current_index
self.batch_index = 0
self.total_batches_seen += 1
yield (index_array[current_index: current_index + current_batch_size],
current_index, current_batch_size)

def flow(self, X, y, batch_size=32, shuffle=False, seed=None,
def flow(self, X, y, batch_size=32, shuffle=True, seed=None,
save_to_dir=None, save_prefix='', save_format='jpeg'):
assert len(X) == len(y)
self.X = X
self.y = y
self.save_to_dir = save_to_dir
self.save_prefix = save_prefix
self.save_format = save_format
self.reset()
self.flow_generator = self._flow_index(X.shape[0], batch_size,
shuffle, seed)
return self

def __iter__(self):
# needed if we want to do something like:
# for x, y in data_gen.flow(...):
return self

def next(self):
# for python 2.x.
# Keeps under lock only the mechanism which advances
# the indexing of each batch
# see # http://anandology.com/blog/using-iterators-and-generators/
with self.lock:
index_array, current_index, current_batch_size = next(self.flow_generator)
# The transformation of images is not under thread lock so it can be done in parallel
bX = np.zeros(tuple([current_batch_size] + list(self.X.shape)[1:]))
for i, j in enumerate(index_array):
x = self.X[j]
x = self.random_transform(x.astype('float32'))
x = self.standardize(x)
bX[i] = x
if self.save_to_dir:
for i in range(current_batch_size):
img = array_to_img(bX[i], self.dim_ordering, scale=True)
fname = '{prefix}_{index}.{format}'.format(prefix=self.save_prefix,
index=current_index + i,
format=self.save_format)
img.save(os.path.join(self.save_to_dir, fname))
bY = self.y[index_array]
return bX, bY

def __next__(self):
# for python 3.x.
return self.next()
return NumpyArrayIterator(
X, y, self,
batch_size=batch_size, shuffle=shuffle, seed=seed,
dim_ordering=self.dim_ordering,
save_to_dir=save_to_dir, save_prefix=save_prefix, save_format=save_format)

def flow_from_directory(self, directory,
target_size=(256, 256), color_mode='rgb',
classes=None, class_mode='categorical',
batch_size=32, shuffle=True, seed=None,
save_to_dir=None, save_prefix='', save_format='jpeg'):
return DirectoryIterator(
directory, self,
target_size=target_size, color_mode=color_mode,
classes=classes, class_mode=class_mode,
dim_ordering=self.dim_ordering,
batch_size=batch_size, shuffle=shuffle, seed=seed,
save_to_dir=save_to_dir, save_prefix=save_prefix, save_format=save_format)

def standardize(self, x):
if self.rescale:
x *= self.rescale
# x is a single image, so it doesn't have image number at index 0
img_channel_index = self.channel_index - 1
if self.samplewise_center:
Expand Down Expand Up @@ -441,3 +392,204 @@ def fit(self, X,
sigma = np.dot(flatX.T, flatX) / flatX.shape[1]
U, S, V = linalg.svd(sigma)
self.principal_components = np.dot(np.dot(U, np.diag(1. / np.sqrt(S + 10e-7))), U.T)


class Iterator(object):

def __init__(self, N, batch_size, shuffle, seed):
self.N = N
self.batch_size = batch_size
self.shuffle = shuffle
self.batch_index = 0
self.total_batches_seen = 0
self.lock = threading.Lock()
self.index_generator = self._flow_index(N, batch_size, shuffle, seed)

def reset(self):
self.batch_index = 0

def _flow_index(self, N, batch_size=32, shuffle=False, seed=None):
# ensure self.batch_index is 0
self.reset()
while 1:
if self.batch_index == 0:
index_array = np.arange(N)
if shuffle:
if seed is not None:
np.random.seed(seed + self.total_batches_seen)
index_array = np.random.permutation(N)

current_index = (self.batch_index * batch_size) % N
if N >= current_index + batch_size:
current_batch_size = batch_size
self.batch_index += 1
else:
current_batch_size = N - current_index
self.batch_index = 0
self.total_batches_seen += 1
yield (index_array[current_index: current_index + current_batch_size],
current_index, current_batch_size)

def __iter__(self):
# needed if we want to do something like:
# for x, y in data_gen.flow(...):
return self


class NumpyArrayIterator(Iterator):

def __init__(self, X, y, image_data_generator,
batch_size=32, shuffle=False, seed=None,
dim_ordering=K.image_dim_ordering(),
save_to_dir=None, save_prefix='', save_format='jpeg'):
if len(X) != len(y):
raise Exception('X (images tensor) and y (labels) '
'should have the same length. '
'Found: X.shape = %s, y.shape = %s' % (np.asarray(X).shape, np.asarray(y).shape))
self.X = X
self.y = y
self.image_data_generator = image_data_generator
self.dim_ordering
self.save_to_dir = save_to_dir
self.save_prefix = save_prefix
self.save_format = save_format
super(NumpyArrayIterator, self).__init__(X.shape[0], batch_size, shuffle, seed)

def next(self):
# for python 2.x.
# Keeps under lock only the mechanism which advances
# the indexing of each batch
# see http://anandology.com/blog/using-iterators-and-generators/
with self.lock:
index_array, current_index, current_batch_size = next(self.index_generator)
# The transformation of images is not under thread lock so it can be done in parallel
batch_x = np.zeros(tuple([current_batch_size] + list(self.X.shape)[1:]))
for i, j in enumerate(index_array):
x = self.X[j]
x = self.image_data_generator.random_transform(x.astype('float32'))
x = self.image_data_generator.standardize(x)
batch_x[i] = x
if self.save_to_dir:
for i in range(current_batch_size):
img = array_to_img(batch_x[i], self.dim_ordering, scale=True)
fname = '{prefix}_{index}.{format}'.format(prefix=self.save_prefix,
index=current_index + i,
format=self.save_format)
img.save(os.path.join(self.save_to_dir, fname))
batch_y = self.y[index_array]
return batch_x, batch_y


class DirectoryIterator(Iterator):

def __init__(self, directory, image_data_generator,
target_size=(256, 256), color_mode='rgb',
dim_ordering=K.image_dim_ordering,
classes=None, class_mode='categorical',
batch_size=32, shuffle=True, seed=None,
save_to_dir=None, save_prefix='', save_format='jpeg'):
self.directory = directory
self.image_data_generator = image_data_generator
self.target_size = tuple(target_size)
if color_mode not in {'rgb', 'grayscale'}:
raise ValueError('Invalid color mode:', color_mode,
'; expected "rgb" or "grayscale".')
self.color_mode = color_mode
self.dim_ordering = dim_ordering
if self.color_mode == 'rgb':
if self.dim_ordering == 'tf':
self.image_shape = self.target_size + (3,)
else:
self.image_shape = (3,) + self.target_size
else:
if self.dim_ordering == 'tf':
self.image_shape = self.target_size + (1,)
else:
self.image_shape = (1,) + self.target_size
self.classes = classes
if class_mode not in {'categorical', 'binary', 'sparse', None}:
raise ValueError('Invalid class_mode:', class_mode,
'; expected one of "categorical", '
'"binary", "sparse", or None.')
self.class_mode = class_mode
self.save_to_dir = save_to_dir
self.save_prefix = save_prefix
self.save_format = save_format

white_list_formats = {'png', 'jpg', 'jpeg', 'bmp'}

# first, count the number of samples and classes
self.nb_sample = 0

if not classes:
classes = []
for subdir in os.listdir(directory):
if os.path.isdir(os.path.join(directory, subdir)):
classes.append(subdir)
self.nb_class = len(classes)
self.class_indices = dict(zip(classes, range(len(classes))))

for subdir in classes:
subpath = os.path.join(directory, subdir)
for fname in os.listdir(subpath):
is_valid = False
for extension in white_list_formats:
if fname.endswith('.' + extension):
is_valid = True
break
if is_valid:
self.nb_sample += 1
print('Found %d images belonging to %d classes.' % (self.nb_sample, self.nb_class))

# second, build an index of the images in the different class subfolders
self.filenames = []
self.classes = np.zeros((self.nb_sample,), dtype='int32')
i = 0
for subdir in classes:
subpath = os.path.join(directory, subdir)
for fname in os.listdir(subpath):
is_valid = False
for extension in white_list_formats:
if fname.endswith('.' + extension):
is_valid = True
break
if is_valid:
self.classes[i] = self.class_indices[subdir]
self.filenames.append(os.path.join(subdir, fname))
i += 1
super(DirectoryIterator, self).__init__(self.nb_sample, batch_size, shuffle, seed)

def next(self):
with self.lock:
index_array, current_index, current_batch_size = next(self.index_generator)
# The transformation of images is not under thread lock so it can be done in parallel
batch_x = np.zeros((current_batch_size,) + self.image_shape)
grayscale = self.color_mode == 'grayscale'
# build batch of image data
for i, j in enumerate(index_array):
fname = self.filenames[j]
img = load_img(os.path.join(self.directory, fname), grayscale=grayscale, target_size=self.target_size)
x = img_to_array(img, dim_ordering=self.dim_ordering)
x = self.image_data_generator.random_transform(x)
x = self.image_data_generator.standardize(x)
batch_x[i] = x
# optionally save augmented images to disk for debugging purposes
if self.save_to_dir:
for i in range(current_batch_size):
img = array_to_img(batch_x[i], self.dim_ordering, scale=True)
fname = '{prefix}_{index}.{format}'.format(prefix=self.save_prefix,
index=current_index + i,
format=self.save_format)
img.save(os.path.join(self.save_to_dir, fname))
# build batch of labels
if self.class_mode == 'sparse':
batch_y = self.classes[index_array]
elif self.class_mode == 'binary':
batch_y = self.classes[index_array].astype('float32')
elif self.class_mode == 'categorical':
batch_y = np.zeros((len(batch_x), self.nb_class), dtype='float32')
for i, label in enumerate(self.classes[index_array]):
batch_y[i, label] = 1.
else:
return batch_x
return batch_x, batch_y

0 comments on commit 0e18e34

Please sign in to comment.