From db0707b7b25d16d3a26c8c9651987a7d0a441e5b Mon Sep 17 00:00:00 2001 From: Francois Chollet Date: Mon, 29 Jan 2018 18:30:15 -0800 Subject: [PATCH] Allow inputs of model training/testing methods to be a list of int/float values. --- keras/engine/training.py | 10 ++++++---- tests/keras/engine/test_training.py | 11 +++++++++++ 2 files changed, 17 insertions(+), 4 deletions(-) diff --git a/keras/engine/training.py b/keras/engine/training.py index 5a5d85545d6..22bb77ca7e9 100644 --- a/keras/engine/training.py +++ b/keras/engine/training.py @@ -60,17 +60,19 @@ def _standardize_input_data(data, names, shapes=None, if isinstance(data, dict): try: data = [data[x].values if data[x].__class__.__name__ == 'DataFrame' else data[x] for x in names] - data = [np.expand_dims(x, 1) if x.ndim == 1 else x for x in data] except KeyError as e: raise ValueError( 'No data provided for "' + e.args[0] + '". Need data ' 'for each key in: ' + str(names)) elif isinstance(data, list): - data = [x.values if x.__class__.__name__ == 'DataFrame' else x for x in data] - data = [np.expand_dims(x, 1) if x is not None and x.ndim == 1 else x for x in data] + if len(names) == 1 and data and isinstance(data[0], (float, int)): + data = [np.asarray(data)] + else: + data = [x.values if x.__class__.__name__ == 'DataFrame' else x for x in data] else: data = data.values if data.__class__.__name__ == 'DataFrame' else data - data = [np.expand_dims(data, 1)] if data.ndim == 1 else [data] + data = [data] + data = [np.expand_dims(x, 1) if x is not None and x.ndim == 1 else x for x in data] if len(data) != len(names): if data and hasattr(data[0], 'shape'): diff --git a/tests/keras/engine/test_training.py b/tests/keras/engine/test_training.py index 6846a682d70..a928bdc65f5 100644 --- a/tests/keras/engine/test_training.py +++ b/tests/keras/engine/test_training.py @@ -568,6 +568,17 @@ def test_trainable_argument(): assert_allclose(out, out_2) +@keras_test +def test_with_list_as_targets(): + model = Sequential() + model.add(Dense(1, input_dim=3, trainable=False)) + model.compile('rmsprop', 'mse') + + x = np.random.random((2, 3)) + y = [0, 1] + model.train_on_batch(x, y) + + @keras_test def test_check_not_failing(): a = np.random.random((2, 1, 3))