Skip to content

Commit

Permalink
Improve type-check for Sequence (keras-team#11468)
Browse files Browse the repository at this point in the history
  • Loading branch information
Frédéric Branchaud-Charron authored and fchollet committed Oct 24, 2018
1 parent 36b9e4c commit d6b5c5e
Show file tree
Hide file tree
Showing 6 changed files with 75 additions and 21 deletions.
40 changes: 21 additions & 19 deletions keras/engine/training_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import warnings
import numpy as np

from .training_utils import is_sequence
from .training_utils import iter_sequence_infinite
from .. import backend as K
from ..utils.data_utils import Sequence
Expand Down Expand Up @@ -40,15 +41,15 @@ def fit_generator(model,
if do_validation:
model._make_test_function()

is_sequence = isinstance(generator, Sequence)
if not is_sequence and use_multiprocessing and workers > 1:
use_sequence_api = is_sequence(generator)
if not use_sequence_api and use_multiprocessing and workers > 1:
warnings.warn(
UserWarning('Using a generator with `use_multiprocessing=True`'
' and multiple workers may duplicate your data.'
' Please consider using the`keras.utils.Sequence'
' class.'))
if steps_per_epoch is None:
if is_sequence:
if use_sequence_api:
steps_per_epoch = len(generator)
else:
raise ValueError('`steps_per_epoch=None` is only valid for a'
Expand All @@ -59,10 +60,11 @@ def fit_generator(model,

# python 2 has 'next', 3 has '__next__'
# avoid any explicit version checks
val_use_sequence_api = is_sequence(validation_data)
val_gen = (hasattr(validation_data, 'next') or
hasattr(validation_data, '__next__') or
isinstance(validation_data, Sequence))
if (val_gen and not isinstance(validation_data, Sequence) and
val_use_sequence_api)
if (val_gen and not val_use_sequence_api and
not validation_steps):
raise ValueError('`validation_steps=None` is only valid for a'
' generator based on the `keras.utils.Sequence`'
Expand Down Expand Up @@ -108,7 +110,7 @@ def fit_generator(model,
if val_gen and workers > 0:
# Create an Enqueuer that can be reused
val_data = validation_data
if isinstance(val_data, Sequence):
if is_sequence(val_data):
val_enqueuer = OrderedEnqueuer(
val_data,
use_multiprocessing=use_multiprocessing)
Expand All @@ -122,7 +124,7 @@ def fit_generator(model,
val_enqueuer_gen = val_enqueuer.get()
elif val_gen:
val_data = validation_data
if isinstance(val_data, Sequence):
if is_sequence(val_data):
val_enqueuer_gen = iter_sequence_infinite(val_data)
validation_steps = validation_steps or len(val_data)
else:
Expand All @@ -149,7 +151,7 @@ def fit_generator(model,
cbk.validation_data = val_data

if workers > 0:
if is_sequence:
if use_sequence_api:
enqueuer = OrderedEnqueuer(
generator,
use_multiprocessing=use_multiprocessing,
Expand All @@ -161,7 +163,7 @@ def fit_generator(model,
enqueuer.start(workers=workers, max_queue_size=max_queue_size)
output_generator = enqueuer.get()
else:
if is_sequence:
if use_sequence_api:
output_generator = iter_sequence_infinite(generator)
else:
output_generator = generator
Expand Down Expand Up @@ -284,15 +286,15 @@ def evaluate_generator(model, generator,
steps_done = 0
outs_per_batch = []
batch_sizes = []
is_sequence = isinstance(generator, Sequence)
if not is_sequence and use_multiprocessing and workers > 1:
use_sequence_api = is_sequence(generator)
if not use_sequence_api and use_multiprocessing and workers > 1:
warnings.warn(
UserWarning('Using a generator with `use_multiprocessing=True`'
' and multiple workers may duplicate your data.'
' Please consider using the`keras.utils.Sequence'
' class.'))
if steps is None:
if is_sequence:
if use_sequence_api:
steps = len(generator)
else:
raise ValueError('`steps=None` is only valid for a generator'
Expand All @@ -303,7 +305,7 @@ def evaluate_generator(model, generator,

try:
if workers > 0:
if is_sequence:
if use_sequence_api:
enqueuer = OrderedEnqueuer(
generator,
use_multiprocessing=use_multiprocessing)
Expand All @@ -314,7 +316,7 @@ def evaluate_generator(model, generator,
enqueuer.start(workers=workers, max_queue_size=max_queue_size)
output_generator = enqueuer.get()
else:
if is_sequence:
if use_sequence_api:
output_generator = iter_sequence_infinite(generator)
else:
output_generator = generator
Expand Down Expand Up @@ -387,15 +389,15 @@ def predict_generator(model, generator,

steps_done = 0
all_outs = []
is_sequence = isinstance(generator, Sequence)
if not is_sequence and use_multiprocessing and workers > 1:
use_sequence_api = is_sequence(generator)
if not use_sequence_api and use_multiprocessing and workers > 1:
warnings.warn(
UserWarning('Using a generator with `use_multiprocessing=True`'
' and multiple workers may duplicate your data.'
' Please consider using the`keras.utils.Sequence'
' class.'))
if steps is None:
if is_sequence:
if use_sequence_api:
steps = len(generator)
else:
raise ValueError('`steps=None` is only valid for a generator'
Expand All @@ -406,7 +408,7 @@ def predict_generator(model, generator,

try:
if workers > 0:
if is_sequence:
if use_sequence_api:
enqueuer = OrderedEnqueuer(
generator,
use_multiprocessing=use_multiprocessing)
Expand All @@ -417,7 +419,7 @@ def predict_generator(model, generator,
enqueuer.start(workers=workers, max_queue_size=max_queue_size)
output_generator = enqueuer.get()
else:
if is_sequence:
if use_sequence_api:
output_generator = iter_sequence_infinite(generator)
else:
output_generator = generator
Expand Down
15 changes: 15 additions & 0 deletions keras/engine/training_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@

from .. import backend as K
from .. import losses
from ..utils import Sequence
from ..utils.generic_utils import to_list


Expand Down Expand Up @@ -589,3 +590,17 @@ def iter_sequence_infinite(seq):
while True:
for item in seq:
yield item


def is_sequence(seq):
    """Tell whether an object can be driven through the Sequence API.

    # Arguments
        seq: object that may or may not behave like a `Sequence`.

    # Returns
        The object's own `use_sequence_api` attribute when that attribute
        is truthy. Otherwise a boolean: True when every attribute name
        listed by `dir()` on a `Sequence` instance is also listed on `seq`
        (`use_sequence_api` itself is not required on `seq`).
    """
    # TODO Dref360: Decide which pattern to follow. First needs a new TF Version.
    declared_flag = getattr(seq, 'use_sequence_api', False)
    if declared_flag:
        # Explicit opt-in: no structural comparison needed.
        return declared_flag

    # Structural fallback: compare attribute names. `use_sequence_api` is
    # added to the candidate's names so that its absence alone does not
    # make the subset test fail.
    required_names = set(dir(Sequence()))
    candidate_names = set(dir(seq) + ['use_sequence_api'])
    return required_names.issubset(candidate_names)
2 changes: 2 additions & 0 deletions keras/utils/data_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -341,6 +341,8 @@ def __getitem__(self, idx):
```
"""

use_sequence_api = True

@abstractmethod
def __getitem__(self, index):
"""Gets batch at position `index`.
Expand Down
35 changes: 35 additions & 0 deletions tests/integration_tests/test_image_data_tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import numpy as np
import pytest

from keras.preprocessing.image import ImageDataGenerator
from keras.utils.test_utils import get_test_data
from keras.models import Sequential
from keras import layers
Expand Down Expand Up @@ -41,5 +42,39 @@ def test_image_classification():
model = Sequential.from_config(config)


def test_image_data_generator_training():
    """Train and evaluate a small conv net fed by `ImageDataGenerator.flow`.

    The training data and `validation_data` are both passed to
    `fit_generator` as flow iterators, with neither `steps_per_epoch` nor
    `validation_steps` given. The test then checks the last-epoch
    validation accuracy and runs `evaluate_generator` on a flow iterator
    without `steps`.
    """
    np.random.seed(1337)
    batch_size = 16
    image_shape = (16, 16, 3)
    augmenter = ImageDataGenerator(rescale=1.)  # Dummy ImageDataGenerator

    train_split, test_split = get_test_data(num_train=500,
                                            num_test=200,
                                            input_shape=image_shape,
                                            classification=True,
                                            num_classes=4)
    x_train, train_labels = train_split
    x_test, test_labels = test_split
    y_train = to_categorical(train_labels)
    y_test = to_categorical(test_labels)

    # Width of the softmax head follows the one-hot label width.
    num_outputs = y_test.shape[-1]
    stack = [
        layers.Conv2D(filters=8, kernel_size=3,
                      activation='relu',
                      input_shape=image_shape),
        layers.MaxPooling2D(pool_size=2),
        layers.Conv2D(filters=4, kernel_size=(3, 3),
                      activation='relu', padding='same'),
        layers.GlobalAveragePooling2D(),
        layers.Dense(num_outputs, activation='softmax'),
    ]
    model = Sequential(stack)
    model.compile(loss='categorical_crossentropy',
                  optimizer='rmsprop',
                  metrics=['accuracy'])

    # Flow iterators are created right before use, training one first.
    train_flow = augmenter.flow(x_train, y_train, batch_size=batch_size)
    val_flow = augmenter.flow(x_test, y_test, batch_size=batch_size)
    history = model.fit_generator(train_flow,
                                  epochs=10,
                                  validation_data=val_flow,
                                  verbose=0)
    final_val_acc = history.history['val_acc'][-1]
    assert final_val_acc > 0.75

    eval_flow = augmenter.flow(x_train, y_train, batch_size=batch_size)
    model.evaluate_generator(eval_flow)


if __name__ == '__main__':
pytest.main([__file__])
2 changes: 1 addition & 1 deletion tests/keras/utils/data_utils_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
from keras import backend as K

pytestmark = pytest.mark.skipif(
K.backend() == 'tensorflow',
K.backend() == 'tensorflow' and 'TRAVIS_PYTHON_VERSION' in os.environ,
reason='Temporarily disabled until the use_multiprocessing problem is solved')

if sys.version_info < (3,):
Expand Down
2 changes: 1 addition & 1 deletion tests/test_multiprocessing.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from keras import backend as K

pytestmark = pytest.mark.skipif(
K.backend() == 'tensorflow',
K.backend() == 'tensorflow' and 'TRAVIS_PYTHON_VERSION' in os.environ,
reason='Temporarily disabled until the use_multiprocessing problem is solved')

STEPS_PER_EPOCH = 100
Expand Down

0 comments on commit d6b5c5e

Please sign in to comment.