Skip to content

Commit

Permalink
Add support for non-scalar example weights.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 401122273
  • Loading branch information
mdreves authored and tfx-copybara committed Oct 6, 2021
1 parent 7193c9c commit d9ca830
Show file tree
Hide file tree
Showing 10 changed files with 94 additions and 30 deletions.
2 changes: 2 additions & 0 deletions RELEASE.md
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@

## Bug fixes and other Changes

* Added support for example_weights that are arrays.

## Breaking Changes

## Deprecations
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -230,7 +230,8 @@ def get_by_keys(value: Any, keys: List[str]) -> Any:
output_name=output_name,
fractional_labels=False, # Labels are ignored for flip counts.
flatten=False, # Flattened below
allow_none=True)) # Allow None labels
allow_none=True, # Allow None labels
require_single_example_weight=True))

if prediction.size != counterfactual_prediction.size:
raise ValueError(
Expand Down
3 changes: 2 additions & 1 deletion tensorflow_model_analysis/metrics/aggregation.py
Original file line number Diff line number Diff line change
Expand Up @@ -295,7 +295,8 @@ def add_input(self, accumulator: Dict[int, float],
model_name=self._key.model_name,
output_name=self._key.output_name,
flatten=False,
allow_none=True)):
allow_none=True,
require_single_example_weight=True)):
if example_weight is None:
example_weight = 1.0
else:
Expand Down
54 changes: 31 additions & 23 deletions tensorflow_model_analysis/metrics/metric_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -254,6 +254,7 @@ def to_label_prediction_example_weight(
flatten: bool = True,
squeeze: bool = True,
allow_none: bool = False,
require_single_example_weight: bool = False
) -> Iterator[Tuple[np.ndarray, np.ndarray, np.ndarray]]:
"""Yields label, prediction, and example weights for use in calculations.
Expand Down Expand Up @@ -363,12 +364,16 @@ def to_label_prediction_example_weight(
yielded values are always arrays of size 1. For example, multi-class /
multi-label outputs would be converted into label and prediction pairs
that could then be processed by a binary classification metric in order to
compute a micro average over all classes.
compute a micro average over all classes. If the example weight is not a
scalar, then it will be flattened as well; otherwise the same example
weight value will be output for each pair of labels and predictions.
squeeze: True to squeeze any outputs that have rank > 1. This transforms
outputs such as np.array([[1]]) to np.array([1]).
allow_none: True to allow labels or predictions with None values to be
returned. When used, the values will be returned as empty np.ndarrays. The
example weight will always be non-empty.
require_single_example_weight: True to require that the example_weight be a
single value.
Yields:
Tuple of (label, prediction, example_weight).
Expand Down Expand Up @@ -453,11 +458,7 @@ def optionally_get_by_keys(value: Any, keys: List[Text]) -> Any:
'error in the pipeline.')

example_weight = util.to_numpy(example_weight)

# Query based metrics group by a query_id which will result in the
# example_weight being replicated once for each matching example in the
# group. When this happens convert the example_weight back to a single value
if example_weight.size > 1:
if require_single_example_weight and example_weight.size > 1:
example_weight = example_weight.flatten()
if not np.all(example_weight == example_weight[0]):
raise ValueError(
Expand Down Expand Up @@ -498,9 +499,23 @@ def optionally_get_by_keys(value: Any, keys: List[Text]) -> Any:
label = label.reshape((1,))
if prediction is not None and not prediction.shape:
prediction = prediction.reshape((1,))
if example_weight is not None and not example_weight.shape:
if not example_weight.shape:
example_weight = example_weight.reshape((1,))

label = label if label is not None else np.array([])
prediction = prediction if prediction is not None else np.array([])

flatten_size = prediction.size or label.size
if flatten:
if example_weight.size == 1:
example_weight = np.array(
[float(example_weight) for i in range(flatten_size)])
elif example_weight.size != flatten_size:
raise ValueError(
'example_weight size does not match the size of labels and '
'predictions: label={}, prediction={}, example_weight={}'.format(
label, prediction, example_weight))

if class_weights:
if not flatten:
raise ValueError(
Expand All @@ -510,22 +525,14 @@ def optionally_get_by_keys(value: Any, keys: List[Text]) -> Any:
"averaging being applied to metrics that don't support micro "
'averaging')
example_weight = np.array([
float(example_weight) *
class_weights[i] if i in class_weights else 0.0
for i in range(prediction.shape[-1] or label.shape[-1])
example_weight[i] * class_weights[i] if i in class_weights else 0.0
for i in range(flatten_size)
])
elif flatten:
example_weight = np.array([
float(example_weight)
for i in range(prediction.shape[-1] or label.shape[-1])
])

label = label if label is not None else np.array([])
prediction = prediction if prediction is not None else np.array([])

def yield_results(label, prediction, example_weight):
if (not flatten or (label.size == 0 and prediction.size == 0) or
(label.size == 1 and prediction.size == 1)):
(label.size == 1 and prediction.size == 1 and
example_weight.size == 1)):
if squeeze:
yield _squeeze(label), _squeeze(prediction), _squeeze(example_weight)
else:
Expand All @@ -536,19 +543,20 @@ def yield_results(label, prediction, example_weight):
elif prediction.size == 0:
for l, w in zip(label.flatten(), example_weight.flatten()):
yield np.array([l]), prediction, np.array([w])
elif label.size == prediction.size:
elif label.size == prediction.size and label.size == example_weight.size:
for l, p, w in zip(label.flatten(), prediction.flatten(),
example_weight.flatten()):
yield np.array([l]), np.array([p]), np.array([w])
elif label.shape[-1] == 1:
elif label.shape[-1] == 1 and prediction.size == example_weight.size:
label = one_hot(label, prediction)
for l, p, w in zip(label.flatten(), prediction.flatten(),
example_weight.flatten()):
yield np.array([l]), np.array([p]), np.array([w])
else:
raise ValueError(
f'unable to pair labels with predictions: label={label}, '
f'prediction={prediction}\n\n'
'unable to pair labels, predictions, and example weights: '
f'label={label}, prediction={prediction}, '
f'example_weight={example_weight}\n\n'
'This is most likely a configuration error.')

for result in yield_results(label, prediction, example_weight):
Expand Down
47 changes: 47 additions & 0 deletions tensorflow_model_analysis/metrics/metric_util_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -432,6 +432,53 @@ def testStandardMetricInputsWithMissingExampleWeightKeyRaisesError(self):
metric_util.to_label_prediction_example_weight(
example, output_name='output1'))

def testStandardMetricInputsWithNonScalarWeights(self):
  """Checks that flattening pairs every prediction with its own weight.

  A single class-id label (2) is supplied together with four predictions and
  four example weights. The expected output is four (label, prediction,
  weight) triples: the labels are the one-hot form of class 2, and the weights
  are the per-position input weights.
  """
  inputs = metric_types.StandardMetricInputs(
      label={'output_name': np.array([2])},
      prediction={'output_name': np.array([0, 0.5, 0.3, 0.9])},
      example_weight={'output_name': np.array([1.0, 0.0, 1.0, 1.0])})
  results = metric_util.to_label_prediction_example_weight(
      inputs, output_name='output_name', require_single_example_weight=False)

  want_labels = (0.0, 0.0, 1.0, 0.0)
  want_predictions = (0.0, 0.5, 0.3, 0.9)
  want_weights = (1.0, 0.0, 1.0, 1.0)
  for want_label, want_prediction, want_weight in zip(
      want_labels, want_predictions, want_weights):
    # Pull exactly one result per expected triple so a short iterator fails.
    label, prediction, weight = next(results)
    self.assertAllClose(label, np.array([want_label]))
    self.assertAllEqual(prediction, np.array([want_prediction]))
    self.assertAllClose(weight, np.array([want_weight]))

def testStandardMetricInputsWithNonScalarWeightsNoFlatten(self):
  """Checks that array-valued weights pass through untouched without flatten.

  With flatten=False the first yielded triple is expected to hold the label,
  predictions, and example weights exactly as they were supplied.
  """
  inputs = metric_types.StandardMetricInputs(
      label=np.array([2]),
      prediction=np.array([0, 0.5, 0.3, 0.9]),
      example_weight=np.array([1.0, 0.0, 1.0, 1.0]))
  results = metric_util.to_label_prediction_example_weight(
      inputs, flatten=False, require_single_example_weight=False)
  label, prediction, weight = next(results)

  self.assertAllClose(label, np.array([2]))
  self.assertAllEqual(prediction, np.array([0, 0.5, 0.3, 0.9]))
  self.assertAllClose(weight, np.array([1.0, 0.0, 1.0, 1.0]))

def testStandardMetricInputsWithMismatchedExampleWeightsRaisesError(self):
  """Checks that a weight/prediction size mismatch is rejected when flattening.

  Two example weights are supplied for four predictions; with flatten=True a
  ValueError is expected.
  """
  with self.assertRaises(ValueError):
    # Construction stays inside the context so any ValueError counts.
    inputs = metric_types.StandardMetricInputs(
        labels=np.array([2]),
        predictions=np.array([0, 0.5, 0.3, 0.9]),
        example_weights=np.array([1.0, 0.0]))
    results = metric_util.to_label_prediction_example_weight(
        inputs, flatten=True, require_single_example_weight=False)
    next(results)

def testStandardMetricInputsRequiringSingleExampleWeightRaisesError(self):
  """Checks that differing weights are rejected when one weight is required.

  The example weights 1.0 and 0.0 are not all equal, so requesting
  require_single_example_weight=True is expected to raise ValueError.
  """
  with self.assertRaises(ValueError):
    # Construction stays inside the context so any ValueError counts.
    inputs = metric_types.StandardMetricInputs(
        labels=np.array([2]),
        predictions=np.array([0, 0.5, 0.3, 0.9]),
        example_weights=np.array([1.0, 0.0]))
    results = metric_util.to_label_prediction_example_weight(
        inputs, require_single_example_weight=True)
    next(results)

def testPrepareLabelsAndPredictions(self):
labels = [0]
preds = {
Expand Down
3 changes: 2 additions & 1 deletion tensorflow_model_analysis/metrics/min_label_position.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,7 +125,8 @@ def add_input(
model_name=self._key.model_name,
output_name=self._key.output_name,
flatten=False,
allow_none=True)) # pytype: disable=wrong-arg-types
allow_none=True,
require_single_example_weight=True)) # pytype: disable=wrong-arg-types
if self._label_key:
labels = util.get_by_keys(element.features, [self._label_key])
if labels is not None:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -240,7 +240,8 @@ def add_input(self, accumulator: Matrices,
eval_config=self._eval_config,
model_name=self._key.model_name,
output_name=self._key.output_name,
flatten=False)) # pytype: disable=wrong-arg-types
flatten=False,
require_single_example_weight=True)) # pytype: disable=wrong-arg-types
if not label.shape:
raise ValueError(
'Label missing from example: StandardMetricInputs={}'.format(element))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -177,7 +177,8 @@ def add_input(self, accumulator: _Matrices,
eval_config=self._eval_config,
model_name=self._key.model_name,
output_name=self._key.output_name,
flatten=False))
flatten=False,
require_single_example_weight=True))
if not labels.shape:
raise ValueError(
'Labels missing from example: StandardMetricInputs={}'.format(
Expand Down
3 changes: 2 additions & 1 deletion tensorflow_model_analysis/metrics/ndcg.py
Original file line number Diff line number Diff line change
Expand Up @@ -169,7 +169,8 @@ def _to_gains_example_weight(
eval_config=self._eval_config,
model_name=self._model_name,
output_name=self._output_name,
flatten=False)) # pytype: disable=wrong-arg-types
flatten=False,
require_single_example_weight=True)) # pytype: disable=wrong-arg-types
gains = util.get_by_keys(element.features, [self._gain_key])
if gains.size != predictions.size:
raise ValueError('expected {} to be same size as predictions {} != {}: '
Expand Down
3 changes: 2 additions & 1 deletion tensorflow_model_analysis/metrics/tf_metric_wrapper_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -356,7 +356,8 @@ def testMultiClassMetricsUsingConfusionMatrix(self, metric_name, top_k,
example4 = {
'labels': np.array([1]),
'predictions': np.array([0.3, 0.2, 0.05, 0.4, 0.05]),
'example_weights': np.array([0.3]),
# This tests that non-scalar (array-valued) example weights are allowed.
'example_weights': np.array([0.3, 0.3, 0.3, 0.3, 0.3]),
}

with beam.Pipeline() as pipeline:
Expand Down

0 comments on commit d9ca830

Please sign in to comment.