Update query_statistics to support weighted examples.
PiperOrigin-RevId: 407219066
mdreves authored and tfx-copybara committed Nov 3, 2021
1 parent 4389611 commit 4180df6
Showing 3 changed files with 86 additions and 42 deletions.
2 changes: 2 additions & 0 deletions RELEASE.md
@@ -6,6 +6,8 @@

## Bug fixes and other Changes

+* Updated QueryStatistics to support weighted examples.

## Breaking Changes

## Deprecations
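For context on the release note above: a QueryStatistics metric can now be requested with weighted example counts instead of raising NotImplementedError. A minimal sketch of the call, mirroring the pattern used in query_statistics_test.py further down (the 'query' key name is purely illustrative):

```python
from tensorflow_model_analysis.metrics import query_statistics

# Request the query-statistics computations with example weighting enabled.
# Prior to this commit, example_weighted=True raised NotImplementedError.
computations = query_statistics.QueryStatistics().computations(
    query_key='query', example_weighted=True)

# A single MetricComputation is returned; it carries the MetricKeys for
# total_queries, total_documents, min_documents and max_documents, plus the
# Beam combiner that produces their weighted values.
computation = computations[0]
```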
94 changes: 63 additions & 31 deletions tensorflow_model_analysis/metrics/query_statistics.py
@@ -19,10 +19,12 @@
# Standard __future__ imports
from __future__ import print_function

-from typing import Dict, Iterable, Text
+from typing import Dict, Iterable, Optional, Text

import apache_beam as beam
from tensorflow_model_analysis.metrics import metric_types
+from tensorflow_model_analysis.metrics import metric_util
+from tensorflow_model_analysis.proto import config_pb2

TOTAL_QUERIES_NAME = 'total_queries'
TOTAL_DOCUMENTS_NAME = 'total_documents'
@@ -38,10 +40,10 @@ class QueryStatistics(metric_types.Metric):
"""

def __init__(self,
-total_queries_name=TOTAL_QUERIES_NAME,
-total_documents_name=TOTAL_DOCUMENTS_NAME,
-min_documents_name=MIN_DOCUMENTS_NAME,
-max_documents_name=MAX_DOCUMENTS_NAME):
+total_queries_name: Text = TOTAL_QUERIES_NAME,
+total_documents_name: Text = TOTAL_DOCUMENTS_NAME,
+min_documents_name: Text = MIN_DOCUMENTS_NAME,
+max_documents_name: Text = MAX_DOCUMENTS_NAME):
"""Initializes query statistics metrics.
Args:
@@ -62,24 +64,39 @@ def __init__(self,


def _query_statistics(
-total_queries_name=TOTAL_QUERIES_NAME,
-total_documents_name=TOTAL_DOCUMENTS_NAME,
-min_documents_name=MIN_DOCUMENTS_NAME,
-max_documents_name=MAX_DOCUMENTS_NAME,
+total_queries_name: Text = TOTAL_QUERIES_NAME,
+total_documents_name: Text = TOTAL_DOCUMENTS_NAME,
+min_documents_name: Text = MIN_DOCUMENTS_NAME,
+max_documents_name: Text = MAX_DOCUMENTS_NAME,
+eval_config: Optional[config_pb2.EvalConfig] = None,
+model_name: Text = '',
+output_name: Text = '',
query_key: Text = '',
example_weighted: bool = False) -> metric_types.MetricComputations:
"""Returns metric computations for query statistics."""
if not query_key:
raise ValueError('a query_key is required to use QueryStatistics metrics')
-if example_weighted:
-raise NotImplementedError(
-'QueryStatistics cannot be used with weighted metrics. It can only be '
-'used with metrics_spec.example_weights.unweighted set to true')

-total_queries_key = metric_types.MetricKey(name=total_queries_name)
-total_documents_key = metric_types.MetricKey(name=total_documents_name)
-min_documents_key = metric_types.MetricKey(name=min_documents_name)
-max_documents_key = metric_types.MetricKey(name=max_documents_name)
+total_queries_key = metric_types.MetricKey(
+name=total_queries_name,
+model_name=model_name,
+output_name=output_name,
+example_weighted=example_weighted)
+total_documents_key = metric_types.MetricKey(
+name=total_documents_name,
+model_name=model_name,
+output_name=output_name,
+example_weighted=example_weighted)
+min_documents_key = metric_types.MetricKey(
+name=min_documents_name,
+model_name=model_name,
+output_name=output_name,
+example_weighted=example_weighted)
+max_documents_key = metric_types.MetricKey(
+name=max_documents_name,
+model_name=model_name,
+output_name=output_name,
+example_weighted=example_weighted)

return [
metric_types.MetricComputation(
@@ -91,7 +108,9 @@ def _query_statistics(
combiner=_QueryStatisticsCombiner(total_queries_key,
total_documents_key,
min_documents_key,
-max_documents_key))
+max_documents_key, eval_config,
+model_name, output_name,
+example_weighted))
]


@@ -101,13 +120,11 @@ class _QueryStatisticsAccumulator(object):
'total_queries', 'total_documents', 'min_documents', 'max_documents'
]

-LARGE_INT = 1000000000

def __init__(self):
-self.total_queries = 0
-self.total_documents = 0
-self.min_documents = self.LARGE_INT
-self.max_documents = 0
+self.total_queries = 0.0
+self.total_documents = 0.0
+self.min_documents = float('inf')
+self.max_documents = 0.0


class _QueryStatisticsCombiner(beam.CombineFn):
@@ -116,11 +133,17 @@ class _QueryStatisticsCombiner(beam.CombineFn):
def __init__(self, total_queries_key: metric_types.MetricKey,
total_documents_key: metric_types.MetricKey,
min_documents_key: metric_types.MetricKey,
-max_documents_key: metric_types.MetricKey):
+max_documents_key: metric_types.MetricKey,
+eval_config: config_pb2.EvalConfig, model_name: Text,
+output_name: Text, example_weighted: bool):
self._total_queries_key = total_queries_key
self._total_documents_key = total_documents_key
self._min_documents_key = min_documents_key
self._max_documents_key = max_documents_key
+self._eval_config = eval_config
+self._model_name = model_name
+self._output_name = output_name
+self._example_weighted = example_weighted

def create_accumulator(self) -> _QueryStatisticsAccumulator:
return _QueryStatisticsAccumulator()
@@ -129,11 +152,20 @@ def add_input(
self, accumulator: _QueryStatisticsAccumulator,
element: metric_types.StandardMetricInputs
) -> _QueryStatisticsAccumulator:
-accumulator.total_queries += 1
-num_documents = len(element.prediction)
-accumulator.total_documents += num_documents
-accumulator.min_documents = min(accumulator.min_documents, num_documents)
-accumulator.max_documents = max(accumulator.max_documents, num_documents)
+for _, _, example_weight in (metric_util.to_label_prediction_example_weight(
+element,
+eval_config=self._eval_config,
+model_name=self._model_name,
+output_name=self._output_name,
+example_weighted=self._example_weighted,
+flatten=False,
+require_single_example_weight=True)):
+example_weight = float(example_weight)
+accumulator.total_queries += example_weight
+num_documents = len(element.prediction) * example_weight
+accumulator.total_documents += num_documents
+accumulator.min_documents = min(accumulator.min_documents, num_documents)
+accumulator.max_documents = max(accumulator.max_documents, num_documents)
return accumulator

def merge_accumulators(
@@ -152,7 +184,7 @@ def merge_accumulators(

def extract_output(
self, accumulator: _QueryStatisticsAccumulator
-) -> Dict[metric_types.MetricKey, int]:
+) -> Dict[metric_types.MetricKey, float]:
return {
self._total_queries_key: accumulator.total_queries,
self._total_documents_key: accumulator.total_documents,
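To make the behavior change concrete, here is a small standalone sketch of the accumulation the new add_input performs. It does not use the TFMA API; the query weights and document counts are made up for illustration. Each query contributes its example weight to total_queries, and its document count scaled by that weight to the document statistics.

```python
# Standalone sketch of the weighted accumulation in _QueryStatisticsCombiner.
class Acc:
  def __init__(self):
    self.total_queries = 0.0
    self.total_documents = 0.0
    self.min_documents = float('inf')
    self.max_documents = 0.0


def add_query(acc, num_documents, example_weight=1.0):
  # The query counts as `example_weight` queries, and its document count is
  # scaled by the same weight before updating the totals and the min/max.
  weighted_documents = num_documents * example_weight
  acc.total_queries += example_weight
  acc.total_documents += weighted_documents
  acc.min_documents = min(acc.min_documents, weighted_documents)
  acc.max_documents = max(acc.max_documents, weighted_documents)
  return acc


acc = Acc()
for num_docs, weight in [(2, 1.0), (3, 2.0), (1, 3.0)]:
  add_query(acc, num_docs, weight)
# total_queries   = 1.0 + 2.0 + 3.0       = 6.0
# total_documents = 2*1.0 + 3*2.0 + 1*3.0 = 11.0
# min_documents   = 2.0, max_documents = 6.0
# With all weights equal to 1.0 these reduce to the old unweighted counts.
```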
32 changes: 21 additions & 11 deletions tensorflow_model_analysis/metrics/query_statistics_test.py
@@ -19,6 +19,7 @@
# Standard __future__ imports
from __future__ import print_function

+from absl.testing import parameterized
import apache_beam as beam
from apache_beam.testing import util
import numpy as np
@@ -30,11 +31,16 @@
from tensorflow_model_analysis.utils import util as tfma_util


-class QueryStatisticsTest(testutil.TensorflowModelAnalysisTest):
+class QueryStatisticsTest(testutil.TensorflowModelAnalysisTest,
+parameterized.TestCase):

-def testQueryStatistics(self):
+@parameterized.named_parameters(('weighted', True, 1.0 + 2.0 + 3.0,
+1 * 2.0 + 3 * 2.0 + 1 * 3.0, 2.0, 3 * 2.0),
+('unweighted', False, 3.0, 6.0, 1.0, 3.0))
+def testQueryStatistics(self, example_weighted, total_queries,
+total_documents, min_documents, max_documents):
metrics = query_statistics.QueryStatistics().computations(
-query_key='query', example_weighted=False)[0]
+query_key='query', example_weighted=example_weighted)[0]

query1_example1 = {
'labels': np.array([1.0]),
@@ -113,16 +119,20 @@ def check_result(got):
self.assertLen(got, 1)
got_slice_key, got_metrics = got[0]
self.assertEqual(got_slice_key, ())
-total_queries_key = metric_types.MetricKey(name='total_queries')
-total_documents_key = metric_types.MetricKey(name='total_documents')
-min_documents_key = metric_types.MetricKey(name='min_documents')
-max_documents_key = metric_types.MetricKey(name='max_documents')
+total_queries_key = metric_types.MetricKey(
+name='total_queries', example_weighted=example_weighted)
+total_documents_key = metric_types.MetricKey(
+name='total_documents', example_weighted=example_weighted)
+min_documents_key = metric_types.MetricKey(
+name='min_documents', example_weighted=example_weighted)
+max_documents_key = metric_types.MetricKey(
+name='max_documents', example_weighted=example_weighted)
self.assertDictElementsAlmostEqual(
got_metrics, {
-total_queries_key: 3,
-total_documents_key: 6,
-min_documents_key: 1,
-max_documents_key: 3
+total_queries_key: total_queries,
+total_documents_key: total_documents,
+min_documents_key: min_documents,
+max_documents_key: max_documents
},
places=5)

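Reading the new parameterized expectations: in the unweighted case each query counts once and each document once, giving 3 queries, 6 documents, a minimum of 1 and a maximum of 3 documents per query. In the weighted case each query's contribution is scaled by its example weight, so total_queries = 1.0 + 2.0 + 3.0 = 6.0, total_documents = 1 * 2.0 + 3 * 2.0 + 1 * 3.0 = 11.0, min_documents = 2.0 and max_documents = 3 * 2.0 = 6.0, matching the decorator arguments above.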
