General stateful metrics fixes (keras-team#9446)
* Require stateful metrics layers to be actually stateful

* Prevent stateful metrics from leaking np.floats into the History object

* Progbar: Format stateful metric values as floats, like other metrics

* test_stateful_metrics: Also test validation set evaluation

This makes sure the metric is reset before evaluating the validation set.

* Add support for stateful metrics in fit_generator() and evaluate_generator()

* Document stateful metrics

It would be even better to have full-fledged documentation for stateful layers, but I lack the knowledge and experience to explain that well.

* evaluate_generator(): Do not leak np.float into the History object here either

* Revert stateful metrics documentation until the API stabilizes

* Progbar: Explain stateful metrics handling

* Model.evaluate_generator(): More consistent stateful metrics handling

Use metrics_names, rather than metrics plus index juggling, to skip the loss.

Make loss-only output handling consistent with other Model methods.

Rename all_outs to outs_per_batch to avoid confusion; all_outs has swapped dimensions in predict_generator().
pasky authored and fchollet committed Mar 22, 2018
1 parent 369854e commit 2f4685c
Showing 3 changed files with 66 additions and 19 deletions.
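For background, a stateful metric in Keras is a Layer subclass that accumulates its value across batches rather than being averaged per batch; this commit keys that behavior off the layer's `stateful` attribute instead of treating every Layer as stateful. A minimal sketch of the pattern, modeled on the BinaryTruePositives layer in the test diff below (the `__call__` body here is illustrative context, not part of this diff):

import keras
import keras.backend as K


class BinaryTruePositives(keras.layers.Layer):
    """Counts true positives across all batches seen since the last reset."""

    def __init__(self, name='true_positives', **kwargs):
        super(BinaryTruePositives, self).__init__(name=name, **kwargs)
        self.stateful = True  # required by this commit; being a Layer no longer suffices
        self.true_positives = K.variable(value=0, dtype='int32')

    def reset_states(self):
        # Called at the start of each epoch and before each evaluation run.
        K.set_value(self.true_positives, 0)

    def __call__(self, y_true, y_pred):
        y_true = K.cast(y_true, 'int32')
        y_pred = K.cast(K.round(y_pred), 'int32')
        true_pos = K.sum(y_true * y_pred)
        current = self.true_positives * 1  # read the pre-update value
        # Register the state update; Keras runs it alongside the metric op.
        self.add_update(K.update_add(self.true_positives, true_pos),
                        inputs=[y_true, y_pred])
        return current + true_pos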
keras/engine/training.py (45 changes: 31 additions & 14 deletions)
@@ -927,7 +927,7 @@ def handle_metrics(metrics, weights=None):
 
                 # Keep track of state updates created by
                 # stateful metrics (i.e. metrics layers).
-                if isinstance(metric_fn, Layer):
+                if isinstance(metric_fn, Layer) and metric_fn.stateful:
                     self.stateful_metric_names.append(metric_name)
                     self.metrics_updates += metric_fn.updates
@@ -1175,7 +1175,7 @@ def _fit_loop(self, f, ins, out_labels=None, batch_size=None,
         for epoch in range(initial_epoch, epochs):
             # Reset stateful metrics
             for m in self.metrics:
-                if isinstance(m, Layer):
+                if isinstance(m, Layer) and m.stateful:
                     m.reset_states()
             callbacks.on_epoch_begin(epoch)
             epoch_logs = {}
@@ -1364,7 +1364,7 @@ def _test_loop(self, f, ins, batch_size=None, verbose=0, steps=None):
 
         if hasattr(self, 'metrics'):
             for m in self.metrics:
-                if isinstance(m, Layer):
+                if isinstance(m, Layer) and m.stateful:
                     m.reset_states()
         stateful_metric_indices = [
             i for i, name in enumerate(self.metrics_names)
@@ -1398,7 +1398,7 @@ def _test_loop(self, f, ins, batch_size=None, verbose=0, steps=None):
                         outs.append(0.)
                     for i, batch_out in enumerate(batch_outs):
                         if i in stateful_metric_indices:
-                            outs[i] = batch_out
+                            outs[i] = float(batch_out)
                         else:
                             outs[i] += batch_out
                 else:
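The motivation for the float() cast above: batch outputs arrive as NumPy scalars, and assigning them straight into outs lets np.float32 values leak into logs and the History object. A tiny standalone illustration (not from the diff):

import numpy as np

batch_out = np.float32(0.25)   # the kind of value a test function returns
print(type(batch_out))         # <class 'numpy.float32'>
print(type(float(batch_out)))  # <class 'float'>, a plain Python float for History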
@@ -2185,6 +2185,9 @@ def generate_arrays_from_file(path):
         # Construct epoch logs.
         epoch_logs = {}
         while epoch < epochs:
+            for m in self.metrics:
+                if isinstance(m, Layer) and m.stateful:
+                    m.reset_states()
             callbacks.on_epoch_begin(epoch)
             steps_done = 0
             batch_index = 0
@@ -2320,9 +2323,20 @@ def evaluate_generator(self, generator, steps=None,
         """
         self._make_test_function()
 
+        stateful_metric_indices = []
+        if hasattr(self, 'metrics'):
+            for i, m in enumerate(self.metrics):
+                if isinstance(m, Layer) and m.stateful:
+                    m.reset_states()
+            stateful_metric_indices = [
+                i for i, name in enumerate(self.metrics_names)
+                if str(name) in self.stateful_metric_names]
+        else:
+            stateful_metric_indices = []
+
         steps_done = 0
         wait_time = 0.01
-        all_outs = []
+        outs_per_batch = []
         batch_sizes = []
         is_sequence = isinstance(generator, Sequence)
         if not is_sequence and use_multiprocessing and workers > 1:
@@ -2376,6 +2390,9 @@ def evaluate_generator(self, generator, steps=None,
                                      'or (x, y). Found: ' +
                                      str(generator_output))
                 outs = self.test_on_batch(x, y, sample_weight=sample_weight)
+                if not isinstance(outs, list):
+                    outs = [outs]
+                outs_per_batch.append(outs)
 
                 if isinstance(x, list):
                     batch_size = x[0].shape[0]
@@ -2386,7 +2403,6 @@ def evaluate_generator(self, generator, steps=None,
                 if batch_size == 0:
                     raise ValueError('Received an empty batch. '
                                      'Batches should at least contain one item.')
-                all_outs.append(outs)
 
                 steps_done += 1
                 batch_sizes.append(batch_size)
@@ -2395,15 +2411,16 @@ def evaluate_generator(self, generator, steps=None,
             if enqueuer is not None:
                 enqueuer.stop()
 
-        if not isinstance(outs, list):
-            return np.average(np.asarray(all_outs),
-                              weights=batch_sizes)
-        else:
-            averages = []
-            for i in range(len(outs)):
-                averages.append(np.average([out[i] for out in all_outs],
-                                           weights=batch_sizes))
-            return averages
+        averages = []
+        for i in range(len(outs)):
+            if i not in stateful_metric_indices:
+                averages.append(np.average([out[i] for out in outs_per_batch],
+                                           weights=batch_sizes))
+            else:
+                averages.append(float(outs_per_batch[-1][i]))
+        if len(averages) == 1:
+            return averages[0]
+        return averages
 
     @interfaces.legacy_generator_methods_support
     def predict_generator(self, generator, steps=None,
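To see what the rewritten tail of evaluate_generator() computes, here is a self-contained sketch with made-up numbers: ordinary outputs get a batch-size-weighted average, while a stateful metric already holds its accumulated value, so the last batch's entry is taken as-is.

import numpy as np

# Hypothetical per-batch outputs: [loss, true_positives (stateful)]
outs_per_batch = [[0.9, 3.0], [0.7, 5.0], [0.5, 6.0]]
batch_sizes = [10, 10, 5]
stateful_metric_indices = [1]

averages = []
for i in range(2):
    if i not in stateful_metric_indices:
        # Regular outputs: weighted mean over all batches.
        averages.append(np.average([out[i] for out in outs_per_batch],
                                   weights=batch_sizes))
    else:
        # Stateful metrics accumulate across batches; the last value is final.
        averages.append(float(outs_per_batch[-1][i]))

print(averages)  # [0.74, 6.0]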
keras/utils/generic_utils.py (5 changes: 4 additions & 1 deletion)
@@ -337,7 +337,10 @@ def update(self, current, values=None):
                 self._values[k][0] += v * (current - self._seen_so_far)
                 self._values[k][1] += (current - self._seen_so_far)
             else:
-                self._values[k] = v
+                # Stateful metrics output a numeric value. This representation
+                # means "take an average from a single value" but keeps the
+                # numeric formatting.
+                self._values[k] = [v, 1]
         self._seen_so_far = current
 
         now = time.time()
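The effect of storing [v, 1]: Progbar formats every entry in self._values as a running average (sum divided by count), so a stateful metric's single value is divided by 1 and printed with the same float formatting as the averaged metrics. A simplified sketch of that display step (not the actual Progbar code):

values = {'loss': [7.5, 25],           # accumulated sum and sample count
          'true_positives': [6.0, 1]}  # stateful metric stored as [v, 1]

for name, (total, count) in values.items():
    print('%s: %.4f' % (name, total / count))
# loss: 0.3000
# true_positives: 6.0000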
tests/keras/metrics_test.py (35 changes: 31 additions & 4 deletions)
@@ -116,13 +116,12 @@ class BinaryTruePositives(keras.layers.Layer):
     Assumes predictions and targets of shape `(samples, 1)`.
 
     # Arguments
-        threshold: Float, lower limit on prediction value that counts as a
-            positive class prediction.
         name: String, name for the metric.
     """
 
     def __init__(self, name='true_positives', **kwargs):
         super(BinaryTruePositives, self).__init__(name=name, **kwargs)
+        self.stateful = True
         self.true_positives = K.variable(value=0, dtype='int32')
 
     def reset_states(self):
@@ -162,11 +161,16 @@ def __call__(self, y_true, y_pred):
                   loss='binary_crossentropy',
                   metrics=['acc', metric_fn])
 
-    # Test fit, evaluate
     samples = 1000
     x = np.random.random((samples, 2))
     y = np.random.randint(2, size=(samples, 1))
-    model.fit(x, y, epochs=1, batch_size=10)
+
+    val_samples = 10
+    val_x = np.random.random((val_samples, 2))
+    val_y = np.random.randint(2, size=(val_samples, 1))
+
+    # Test fit and evaluate
+    history = model.fit(x, y, validation_data=(val_x, val_y), epochs=1, batch_size=10)
     outs = model.evaluate(x, y, batch_size=10)
     preds = model.predict(x)
 
@@ -176,6 +180,29 @@ def ref_true_pos(y_true, y_pred):
     # Test correctness (e.g. updates should have been run)
     np.testing.assert_allclose(outs[2], ref_true_pos(y, preds), atol=1e-5)
 
+    # Test correctness of the validation metric computation
+    val_preds = model.predict(val_x)
+    val_outs = model.evaluate(val_x, val_y, batch_size=10)
+    np.testing.assert_allclose(val_outs[2], ref_true_pos(val_y, val_preds), atol=1e-5)
+    np.testing.assert_allclose(val_outs[2], history.history['val_true_positives'][-1], atol=1e-5)
+
+    # Test with generators
+    gen = [(np.array([x0]), np.array([y0])) for x0, y0 in zip(x, y)]
+    val_gen = [(np.array([x0]), np.array([y0])) for x0, y0 in zip(val_x, val_y)]
+    history = model.fit_generator(iter(gen), epochs=1, steps_per_epoch=samples,
+                                  validation_data=iter(val_gen), validation_steps=val_samples)
+    outs = model.evaluate_generator(iter(gen), steps=samples)
+    preds = model.predict_generator(iter(gen), steps=samples)
+
+    # Test correctness of the metric re ref_true_pos()
+    np.testing.assert_allclose(outs[2], ref_true_pos(y, preds), atol=1e-5)
+
+    # Test correctness of the validation metric computation
+    val_preds = model.predict_generator(iter(val_gen), steps=val_samples)
+    val_outs = model.evaluate_generator(iter(val_gen), steps=val_samples)
+    np.testing.assert_allclose(val_outs[2], ref_true_pos(val_y, val_preds), atol=1e-5)
+    np.testing.assert_allclose(val_outs[2], history.history['val_true_positives'][-1], atol=1e-5)
+
 
 if __name__ == '__main__':
     pytest.main([__file__])

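Putting it all together, a hedged end-to-end sketch of how such a metric surfaces during training (the model and data here are made up; the metric's name 'true_positives' is what appears in History, prefixed with 'val_' for the validation pass):

import numpy as np
from keras.models import Sequential
from keras.layers import Dense

metric_fn = BinaryTruePositives()  # the stateful layer sketched above

model = Sequential([Dense(1, input_shape=(2,), activation='sigmoid')])
model.compile(optimizer='sgd', loss='binary_crossentropy',
              metrics=['acc', metric_fn])

x = np.random.random((100, 2))
y = np.random.randint(2, size=(100, 1))
history = model.fit(x, y, validation_split=0.1, epochs=2, batch_size=10)

# The metric is reset at each epoch start and before each evaluation,
# then reported as one accumulated value rather than a per-batch average.
print(history.history['val_true_positives'])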