Skip to content

Commit

Permalink
Updated docstring for Model.compute_metrics(). (keras-team#19980)
Browse files Browse the repository at this point in the history
Added details and a more complete example of how to use custom metrics.

Also corrected inaccurate statement about the tracking of custom metrics.
  • Loading branch information
hertschuh authored Jul 12, 2024
1 parent de066e5 commit 998b392
Showing 1 changed file with 12 additions and 6 deletions.
18 changes: 12 additions & 6 deletions keras/src/trainers/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -282,7 +282,7 @@ def __init__(self, *args, **kwargs):
self.loss_tracker = metrics.Mean(name='loss')
def compute_loss(self, x, y, y_pred, sample_weight, training=True):
loss = ops.means((y_pred - y) ** 2)
loss = ops.mean((y_pred - y) ** 2)
loss += ops.sum(self.losses)
self.loss_tracker.update_state(loss)
return loss
Expand Down Expand Up @@ -408,22 +408,28 @@ def compute_metrics(self, x, y, y_pred, sample_weight=None):
"""Update metric states and collect all metrics to be returned.
Subclasses can optionally override this method to provide custom metric
updating and collection logic.
updating and collection logic. Custom metrics are not passed in
`compile()`, they can be created in `__init__` or `build`. They are
automatically tracked and returned by `self.metrics`.
Example:
```python
class MyModel(Sequential):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.custom_metric = MyMetric(name="custom_metric")
def compute_metrics(self, x, y, y_pred, sample_weight):
# This super call updates `self.compiled_metrics` and returns
# This super call updates metrics from `compile` and returns
# results for all metrics listed in `self.metrics`.
metric_results = super().compute_metrics(
x, y, y_pred, sample_weight)
# Note that `self.custom_metric` is not listed
# in `self.metrics`.
# `metric_results` contains the previous result for
# `custom_metric`, this is where we update it.
self.custom_metric.update_state(x, y, y_pred, sample_weight)
metric_results['metric_name'] = self.custom_metric.result()
metric_results['custom_metric'] = self.custom_metric.result()
return metric_results
```
Expand Down

0 comments on commit 998b392

Please sign in to comment.