Skip to content

Commit

Permalink
Fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
fchollet committed Aug 25, 2019
1 parent 2ccb69d commit 21efc67
Show file tree
Hide file tree
Showing 2 changed files with 9 additions and 10 deletions.
6 changes: 4 additions & 2 deletions keras/engine/training.py
Original file line number Diff line number Diff line change
Expand Up @@ -680,8 +680,10 @@ def _prepare_total_loss(self, masks=None):
y_true, y_pred, sample_weight=sample_weight)

if len(self.outputs) > 1:
update_ops = self._output_loss_metrics[i].update_state(output_loss)
self._metric_updates += update_ops
# TODO
# update_ops = self._output_loss_metrics[i].update_state(output_loss)
# self._metric_updates += update_ops
self._output_loss_metrics[i](output_loss)

if total_loss is None:
total_loss = loss_weight * output_loss
Expand Down
13 changes: 5 additions & 8 deletions tests/keras/metrics_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,8 +30,8 @@ def test_sum(self):
assert K.eval(m.total) == 100

# check update_state() and result() + state accumulation + tensor input
K.eval(m.update_state([1, 5]))
assert np.isclose(K.eval(m.result()), 106)
result = m([1, 5])
assert np.isclose(K.eval(result), 106)
assert K.eval(m.total) == 106 # 100 + 1 + 5

# check reset_states()
Expand Down Expand Up @@ -97,9 +97,8 @@ def test_mean(self):
assert K.eval(m.count) == 1

# check update_state() and result()
update_op = m.update_state([1, 5])
K.eval(update_op)
assert np.isclose(K.eval(m.result()), 106 / 3)
result = m([1, 5])
assert np.isclose(K.eval(result), 106 / 3)
assert K.eval(m.total) == 106 # 100 + 1 + 5
assert K.eval(m.count) == 3

Expand Down Expand Up @@ -195,9 +194,7 @@ def test_unweighted(self):
y_true = ((0, 1, 0, 1, 0), (0, 0, 1, 1, 1), (1, 1, 1, 1, 0), (0, 0, 0, 0, 1))
y_pred = ((0, 0, 1, 1, 0), (1, 1, 1, 1, 1), (0, 1, 0, 1, 0), (1, 1, 1, 1, 1))

update_op = mse_obj.update_state(y_true, y_pred)
K.eval(update_op)
result = mse_obj.result()
result = mse_obj(y_true, y_pred)
np.isclose(0.5, K.eval(result), atol=1e-5)

def test_weighted(self):
Expand Down

0 comments on commit 21efc67

Please sign in to comment.