Skip to content

Commit

Permalink
Add precision and recall to fairness indicators metrics.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 447761681
  • Loading branch information
embr authored and tfx-copybara committed May 10, 2022
1 parent b499acf commit dfe9d55
Show file tree
Hide file tree
Showing 2 changed files with 68 additions and 11 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
'true_positive_rate', 'true_negative_rate',
'positive_rate', 'negative_rate',
'false_discovery_rate',
'false_omission_rate')
'false_omission_rate', 'precision', 'recall')

DEFAULT_THRESHOLDS = (0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9)

Expand Down Expand Up @@ -113,6 +113,8 @@ def result(
(num_positives + num_negatives) or float('nan'))
nr = (metric.tn[i] + metric.fn[i]) / (
(num_positives + num_negatives) or float('nan'))
precision = metric.tp[i] / ((metric.tp[i] + metric.fp[i]) or float('nan'))
recall = metric.tp[i] / ((metric.tp[i] + metric.fn[i]) or float('nan'))

fdr = metric.fp[i] / ((metric.fp[i] + metric.tp[i]) or float('nan'))
fomr = metric.fn[i] / ((metric.fn[i] + metric.tn[i]) or float('nan'))
Expand All @@ -131,6 +133,9 @@ def result(
['false_discovery_rate']] = fdr
output[metric_key_by_name_by_threshold[threshold]
['false_omission_rate']] = fomr
output[metric_key_by_name_by_threshold[threshold]['precision']] = (
precision)
output[metric_key_by_name_by_threshold[threshold]['recall']] = recall

return output

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -72,8 +72,7 @@ def check_result(got):
self.assertLen(got, 1)
got_slice_key, got_metrics = got[0]
self.assertEqual(got_slice_key, ())
self.assertLen(got_metrics, 16) # 2 thresholds * 8 metrics
self.assertDictElementsAlmostEqual(
np.testing.assert_equal(
got_metrics, {
metric_types.MetricKey(
name='fairness_indicators_metrics/[email protected]'
Expand Down Expand Up @@ -105,6 +104,12 @@ def check_result(got):
name='fairness_indicators_metrics/[email protected]'
):
0.0,
metric_types.MetricKey(
name='fairness_indicators_metrics/[email protected]'):
2.0 / 3.0,
metric_types.MetricKey(
name='fairness_indicators_metrics/[email protected]'):
1.0,
metric_types.MetricKey(
name='fairness_indicators_metrics/[email protected]'
):
Expand Down Expand Up @@ -134,7 +139,13 @@ def check_result(got):
metric_types.MetricKey(
name='fairness_indicators_metrics/[email protected]'
):
1.0 / 3.0
1.0 / 3.0,
metric_types.MetricKey(
name='fairness_indicators_metrics/[email protected]'):
1.0,
metric_types.MetricKey(
name='fairness_indicators_metrics/[email protected]'):
0.5
})
except AssertionError as err:
raise util.BeamAssertException(err)
Expand Down Expand Up @@ -176,7 +187,7 @@ def check_result(got):
self.assertLen(got, 1)
got_slice_key, got_metrics = got[0]
self.assertEqual(got_slice_key, ())
self.assertLen(got_metrics, 8) # 1 threshold * 8 metrics
self.assertLen(got_metrics, 10) # 1 threshold * 10 metrics
self.assertTrue(
math.isnan(got_metrics[metric_types.MetricKey(
name='fairness_indicators_metrics/[email protected]')]))
Expand All @@ -190,10 +201,10 @@ def check_result(got):
util.assert_that(result, check_result, label='result')

@parameterized.named_parameters(
('_default_threshold', {}, 72, ()),
('_default_threshold', {}, 90, ()),
('_thresholds_with_different_digits', {
'thresholds': [0.1, 0.22, 0.333]
}, 24, (metric_types.MetricKey(
}, 30, (metric_types.MetricKey(
name='fairness_indicators_metrics/[email protected]',
example_weighted=True),
metric_types.MetricKey(
Expand Down Expand Up @@ -276,6 +287,14 @@ def check_result(got):
name='fairness_indicators_metrics/[email protected]',
example_weighted=True):
0.25,
metric_types.MetricKey(
name='fairness_indicators_metrics/[email protected]',
example_weighted=True):
float('nan'),
metric_types.MetricKey(
name='fairness_indicators_metrics/[email protected]',
example_weighted=True):
float('nan'),
metric_types.MetricKey(
name='fairness_indicators_metrics/[email protected]',
example_weighted=True):
Expand All @@ -284,6 +303,17 @@ def check_result(got):
name='fairness_indicators_metrics/[email protected]',
example_weighted=True):
1.0,
metric_types.MetricKey(
name='fairness_indicators_metrics/[email protected]',
example_weighted=True):
0.0,
metric_types.MetricKey(
name='fairness_indicators_metrics/[email protected]',
example_weighted=True):
0.0,
metric_types.MetricKey(
name='fairness_indicators_metrics/[email protected]', example_weighted=True):
float('nan'),
}), ('_has_model_name', [{
'labels': np.array([0.0]),
'predictions': {
Expand Down Expand Up @@ -315,6 +345,16 @@ def check_result(got):
model_name='model1',
example_weighted=True):
0.25,
metric_types.MetricKey(
name='fairness_indicators_metrics/[email protected]',
model_name='model1',
example_weighted=True):
float('nan'),
metric_types.MetricKey(
name='fairness_indicators_metrics/[email protected]',
model_name='model1',
example_weighted=True):
float('nan'),
metric_types.MetricKey(
name='fairness_indicators_metrics/[email protected]',
model_name='model1',
Expand All @@ -325,6 +365,21 @@ def check_result(got):
model_name='model1',
example_weighted=True):
1.0,
metric_types.MetricKey(
name='fairness_indicators_metrics/[email protected]',
model_name='model1',
example_weighted=True):
0.0,
metric_types.MetricKey(
name='fairness_indicators_metrics/[email protected]',
model_name='model1',
example_weighted=True):
0.0,
metric_types.MetricKey(
name='fairness_indicators_metrics/[email protected]',
model_name='model1',
example_weighted=True):
float('nan'),
}))
def testFairessIndicatorsMetricsWithInput(self, input_examples,
computations_kwargs,
Expand Down Expand Up @@ -360,10 +415,7 @@ def check_result(got):
self.assertLen(got, 1)
got_slice_key, got_metrics = got[0]
self.assertEqual(got_slice_key, ())
self.assertLen(got_metrics, 8) # 1 threshold * 8 metrics
for metrics_key in expected_result:
self.assertEqual(got_metrics[metrics_key],
expected_result[metrics_key])
np.testing.assert_equal(got_metrics, expected_result)
except AssertionError as err:
raise util.BeamAssertException(err)

Expand Down

0 comments on commit dfe9d55

Please sign in to comment.