Skip to content

Commit

Permalink
Allow inputs of model training/testing methods to be a list of int/fl…
Browse files Browse the repository at this point in the history
…oat values.
  • Loading branch information
fchollet committed Jan 30, 2018
1 parent b3e10cd commit db0707b
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 4 deletions.
10 changes: 6 additions & 4 deletions keras/engine/training.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,17 +60,19 @@ def _standardize_input_data(data, names, shapes=None,
if isinstance(data, dict):
try:
data = [data[x].values if data[x].__class__.__name__ == 'DataFrame' else data[x] for x in names]
data = [np.expand_dims(x, 1) if x.ndim == 1 else x for x in data]
except KeyError as e:
raise ValueError(
'No data provided for "' + e.args[0] + '". Need data '
'for each key in: ' + str(names))
elif isinstance(data, list):
data = [x.values if x.__class__.__name__ == 'DataFrame' else x for x in data]
data = [np.expand_dims(x, 1) if x is not None and x.ndim == 1 else x for x in data]
if len(names) == 1 and data and isinstance(data[0], (float, int)):
data = [np.asarray(data)]
else:
data = [x.values if x.__class__.__name__ == 'DataFrame' else x for x in data]
else:
data = data.values if data.__class__.__name__ == 'DataFrame' else data
data = [np.expand_dims(data, 1)] if data.ndim == 1 else [data]
data = [data]
data = [np.expand_dims(x, 1) if x is not None and x.ndim == 1 else x for x in data]

if len(data) != len(names):
if data and hasattr(data[0], 'shape'):
Expand Down
11 changes: 11 additions & 0 deletions tests/keras/engine/test_training.py
Original file line number Diff line number Diff line change
Expand Up @@ -568,6 +568,17 @@ def test_trainable_argument():
assert_allclose(out, out_2)


@keras_test
def test_with_list_as_targets():
model = Sequential()
model.add(Dense(1, input_dim=3, trainable=False))
model.compile('rmsprop', 'mse')

x = np.random.random((2, 3))
y = [0, 1]
model.train_on_batch(x, y)


@keras_test
def test_check_not_failing():
a = np.random.random((2, 1, 3))
Expand Down

0 comments on commit db0707b

Please sign in to comment.