Skip to content

Commit

Permalink
training.py _slice_arrays() fix crash when arrays are None (keras-tea…
Browse files Browse the repository at this point in the history
…m#7069)

* training.py _slice_arrays() fix crash when arrays are None

* training.py test _slice_arrays()
  • Loading branch information
ahundt authored and fchollet committed Jun 22, 2017
1 parent de73eda commit 1b53999
Show file tree
Hide file tree
Showing 2 changed files with 30 additions and 4 deletions.
12 changes: 8 additions & 4 deletions keras/engine/training.py
Original file line number Diff line number Diff line change
Expand Up @@ -387,21 +387,25 @@ def _slice_arrays(arrays, start=None, stop=None):
# Returns
A slice of the array(s).
"""
if isinstance(arrays, list):
if arrays is None:
return [None]
elif isinstance(arrays, list):
if hasattr(start, '__len__'):
# hdf5 datasets only support list objects as indices
if hasattr(start, 'shape'):
start = start.tolist()
return [x[start] for x in arrays]
return [None if x is None else x[start] for x in arrays]
else:
return [x[start:stop] for x in arrays]
return [None if x is None else x[start:stop] for x in arrays]
else:
if hasattr(start, '__len__'):
if hasattr(start, 'shape'):
start = start.tolist()
return arrays[start]
else:
elif hasattr(start, '__getitem__'):
return arrays[start:stop]
else:
return [None]


def _weighted_masked_objective(fn):
Expand Down
22 changes: 22 additions & 0 deletions tests/keras/engine/test_training.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from keras.engine.training import Model
from keras.engine.training import _check_loss_and_target_compatibility
from keras.engine.training import _weighted_masked_objective
from keras.engine.training import _slice_arrays
from keras.models import Sequential
from keras import backend as K
from keras.utils import Sequence
Expand All @@ -29,6 +30,27 @@ def __getitem__(self, idx):
np.random.random((self.batch_size, 3))]


@keras_test
def test_slice_arrays():
input_a = np.random.random((10, 3))
_slice_arrays(None)
_slice_arrays(input_a, 0)
_slice_arrays(input_a, 0, 1)
_slice_arrays(input_a, stop=2)
input_a = [None, [1, 1], None, [1, 1]]
_slice_arrays(input_a, 0)
_slice_arrays(input_a, 0, 1)
_slice_arrays(input_a, stop=2)
input_a = [None]
_slice_arrays(input_a, 0)
_slice_arrays(input_a, 0, 1)
_slice_arrays(input_a, stop=2)
input_a = None
_slice_arrays(input_a, 0)
_slice_arrays(input_a, 0, 1)
_slice_arrays(input_a, stop=2)


@keras_test
def test_weighted_masked_objective():
a = Input(shape=(3,), name='input_a')
Expand Down

0 comments on commit 1b53999

Please sign in to comment.