Skip to content

Commit

Permalink
[WIP] Add support for sparse arrays in fit, predict and evaluate. (ke…
Browse files Browse the repository at this point in the history
…ras-team#8930)

* Support of sparse arrays as input in fit, predict and evaluate.

* Fixed small out of range bug.

* Fixed formating.

* Removed the conversion loop if training from symbolic tensors.

* Now we check for the whole feed instead of just inputs for the conversion.

* Improved tests.

* Refactoring tests.

* Duplicated tests to check sparse placeholder.

* Fixed bug. In predict(), the variable ins only have inputs.

* Fixed pep8.
  • Loading branch information
gabrieldemarmiesse authored and fchollet committed Jan 1, 2018
1 parent eb81c5a commit 45c838c
Show file tree
Hide file tree
Showing 2 changed files with 104 additions and 7 deletions.
30 changes: 30 additions & 0 deletions keras/engine/training.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from keras.utils import Sequence
from keras.utils import GeneratorEnqueuer
from keras.utils import OrderedEnqueuer
from scipy.sparse import issparse

try:
import queue
Expand Down Expand Up @@ -1157,6 +1158,13 @@ def _fit_loop(self, f, ins, out_labels=None, batch_size=None,
for cbk in callbacks:
cbk.validation_data = val_ins

# To prevent a slowdown, we find beforehand the arrays that need conversion.
feed = self._feed_inputs + self._feed_targets + self._feed_sample_weights
indices_for_conversion_to_dense = []
for i in range(len(feed)):
if issparse(ins[i]) and not K.is_sparse(feed[i]):
indices_for_conversion_to_dense.append(i)

for epoch in range(initial_epoch, epochs):
callbacks.on_epoch_begin(epoch)
epoch_logs = {}
Expand Down Expand Up @@ -1210,6 +1218,9 @@ def _fit_loop(self, f, ins, out_labels=None, batch_size=None,
batch_logs['batch'] = batch_index
batch_logs['size'] = len(batch_ids)
callbacks.on_batch_begin(batch_index, batch_logs)
for i in indices_for_conversion_to_dense:
ins_batch[i] = ins_batch[i].toarray()

outs = f(ins_batch)
if not isinstance(outs, list):
outs = [outs]
Expand Down Expand Up @@ -1261,6 +1272,12 @@ def _predict_loop(self, f, ins, batch_size=32, verbose=0, steps=None):
progbar = Progbar(target=steps)
else:
progbar = Progbar(target=num_samples)

indices_for_conversion_to_dense = []
for i in range(len(self._feed_inputs)):
if issparse(ins[i]) and not K.is_sparse(self._feed_inputs[i]):
indices_for_conversion_to_dense.append(i)

if steps is not None:
# Step-based predictions.
# Since we do not know how many samples
Expand Down Expand Up @@ -1296,6 +1313,9 @@ def _predict_loop(self, f, ins, batch_size=32, verbose=0, steps=None):
ins_batch = _slice_arrays(ins[:-1], batch_ids) + [ins[-1]]
else:
ins_batch = _slice_arrays(ins, batch_ids)
for i in indices_for_conversion_to_dense:
ins_batch[i] = ins_batch[i].toarray()

batch_outs = f(ins_batch)
if not isinstance(batch_outs, list):
batch_outs = [batch_outs]
Expand Down Expand Up @@ -1339,6 +1359,14 @@ def _test_loop(self, f, ins, batch_size=None, verbose=0, steps=None):
progbar = Progbar(target=steps)
else:
progbar = Progbar(target=num_samples)

# To prevent a slowdown, we find beforehand the arrays that need conversion.
feed = self._feed_inputs + self._feed_targets + self._feed_sample_weights
indices_for_conversion_to_dense = []
for i in range(len(feed)):
if issparse(ins[i]) and not K.is_sparse(feed[i]):
indices_for_conversion_to_dense.append(i)

if steps is not None:
for step in range(steps):
batch_outs = f(ins)
Expand Down Expand Up @@ -1366,6 +1394,8 @@ def _test_loop(self, f, ins, batch_size=None, verbose=0, steps=None):
ins_batch = _slice_arrays(ins[:-1], batch_ids) + [ins[-1]]
else:
ins_batch = _slice_arrays(ins, batch_ids)
for i in indices_for_conversion_to_dense:
ins_batch[i] = ins_batch[i].toarray()

batch_outs = f(ins_batch)
if isinstance(batch_outs, list):
Expand Down
81 changes: 74 additions & 7 deletions tests/keras/engine/test_training.py
Original file line number Diff line number Diff line change
Expand Up @@ -485,16 +485,83 @@ def gen_data(batch_sz):
assert all(['Sequence' not in str(w_.message) for w_ in w]), 'A warning was raised for Sequence.'


@keras_test
def test_sparse_input_target_fit():
test_inputs = [sparse.random(6, 3, density=0.25).tocsr() for _ in range(2)]
test_outputs = [sparse.random(6, i, density=0.25).tocsr() for i in range(3, 5)]
in1 = Input(shape=(3,))
in2 = Input(shape=(3,))
out1 = Dropout(0.5, name='dropout')(in1)
out2 = Dense(4, name='dense_1')(in2)
model = Model([in1, in2], [out1, out2])
model.compile('rmsprop', 'mse')
model.fit(test_inputs, test_outputs, epochs=1, batch_size=2, validation_split=0.2)


@keras_test
def test_sparse_input_target_evaluate():
test_inputs = [sparse.random(6, 3, density=0.25).tocsr() for _ in range(2)]
test_outputs = [sparse.random(6, i, density=0.25).tocsr() for i in range(3, 5)]
in1 = Input(shape=(3,))
in2 = Input(shape=(3,))
out1 = Dropout(0.5, name='dropout')(in1)
out2 = Dense(4, name='dense_1')(in2)
model = Model([in1, in2], [out1, out2])
model.compile('rmsprop', 'mse')
model.evaluate(test_inputs, test_outputs, batch_size=2)


@keras_test
def test_sparse_input_predict():
test_inputs = [sparse.random(6, 3, density=0.25).tocsr() for _ in range(2)]
in1 = Input(shape=(3,))
in2 = Input(shape=(3,))
out1 = Dropout(0.5, name='dropout')(in1)
out2 = Dense(4, name='dense_1')(in2)
model = Model([in1, in2], [out1, out2])
model.compile('rmsprop', 'mse')
model.predict(test_inputs, batch_size=2)


@pytest.mark.skipif(K.backend() != 'tensorflow', reason='sparse operations supported only by TensorFlow')
@keras_test
def test_sparse_placeholder_fit():
test_inputs = [sparse.random(6, 3, density=0.25).tocsr() for _ in range(2)]
test_outputs = [sparse.random(6, i, density=0.25).tocsr() for i in range(3, 5)]
in1 = Input(shape=(3,))
in2 = Input(shape=(3,), sparse=True)
out1 = Dropout(0.5, name='dropout')(in1)
out2 = Dense(4, name='dense_1')(in2)
model = Model([in1, in2], [out1, out2])
model.compile('rmsprop', 'mse')
model.fit(test_inputs, test_outputs, epochs=1, batch_size=2, validation_split=0.2)


@pytest.mark.skipif(K.backend() != 'tensorflow', reason='sparse operations supported only by TensorFlow')
@keras_test
def test_sparse_placeholder_evaluate():
test_inputs = [sparse.random(6, 3, density=0.25).tocsr() for _ in range(2)]
test_outputs = [sparse.random(6, i, density=0.25).tocsr() for i in range(3, 5)]
in1 = Input(shape=(3,))
in2 = Input(shape=(3,), sparse=True)
out1 = Dropout(0.5, name='dropout')(in1)
out2 = Dense(4, name='dense_1')(in2)
model = Model([in1, in2], [out1, out2])
model.compile('rmsprop', 'mse')
model.evaluate(test_inputs, test_outputs, batch_size=2)


@pytest.mark.skipif(K.backend() != 'tensorflow', reason='sparse operations supported only by TensorFlow')
@keras_test
def test_sparse_input_validation_split():
test_input = sparse.random(6, 3, density=0.25).tocsr()
in1 = Input(shape=(3,), sparse=True)
out1 = Dense(4)(in1)
test_output = np.random.random((6, 4))
model = Model(in1, out1)
def test_sparse_placeholder_predict():
test_inputs = [sparse.random(6, 3, density=0.25).tocsr() for _ in range(2)]
in1 = Input(shape=(3,))
in2 = Input(shape=(3,), sparse=True)
out1 = Dropout(0.5, name='dropout')(in1)
out2 = Dense(4, name='dense_1')(in2)
model = Model([in1, in2], [out1, out2])
model.compile('rmsprop', 'mse')
model.fit(test_input, test_output, epochs=1, batch_size=2, validation_split=0.2)
model.predict(test_inputs, batch_size=2)


@keras_test
Expand Down

0 comments on commit 45c838c

Please sign in to comment.