Add support for writing and reading metrics in parquet format.
PiperOrigin-RevId: 331227203
embr authored and tf-model-analysis-team committed Sep 11, 2020
1 parent 752b1ef commit d186dd7
Showing 5 changed files with 422 additions and 59 deletions.
3 changes: 3 additions & 0 deletions RELEASE.md
@@ -4,6 +4,9 @@

## Major Features and Improvements

* Added support for reading and writing metrics, plots and validation results
using Apache Parquet.

## Bug fixes and other changes

## Breaking changes
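To give a feel for the new output format, the following is a minimal sketch (not part of this commit) of loading Parquet-formatted metrics for ad-hoc inspection with pyarrow. The output path and record layout are assumptions for illustration only; the authoritative schema is whatever the TFMA metrics writer emits.

import pyarrow.parquet as pq

# Hypothetical location of metrics written by an evaluation configured to
# emit Parquet output.
metrics_path = '/tmp/eval_output/metrics.parquet'

# Read the Parquet file into an Arrow table and convert to pandas for
# inspection; each row is assumed to hold one serialized metrics record.
metrics_table = pq.read_table(metrics_path)
print(metrics_table.to_pandas().head())
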
21 changes: 21 additions & 0 deletions tensorflow_model_analysis/slicer/slicer_lib.py
@@ -498,6 +498,27 @@ def _is_multi_dim_keys(slice_keys: SliceKeyType) -> bool:
  return False


def slice_key_matches_slice_specs(
    slice_key: SliceKeyType, slice_specs: Iterable[SingleSliceSpec]) -> bool:
  """Checks whether a slice key matches any slice spec.

  In this setting, a slice key matches a slice spec if it could have been
  generated by that spec.

  Args:
    slice_key: The slice key to check for applicability against slice specs.
    slice_specs: Slice specs against which to check applicability of a slice
      key.

  Returns:
    True if the slice_key matches any slice specs, False otherwise.
  """
  for slice_spec in slice_specs:
    if slice_spec.is_slice_applicable(slice_key):
      return True
  return False


@beam.typehints.with_input_types(types.Extracts)
@beam.typehints.with_output_types(Tuple[SliceKeyType, types.Extracts])
class _FanoutSlicesDoFn(beam.DoFn):
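The new helper can be exercised directly; a short usage sketch follows (the feature names and values here are illustrative only, not taken from this commit).

from tensorflow_model_analysis.slicer import slicer_lib as slicer

specs = [
    slicer.SingleSliceSpec(columns=['gender']),  # slice on the 'gender' column
    slicer.SingleSliceSpec(),                    # the overall (empty) slice
]

# A key over 'gender' could have been generated by the first spec.
assert slicer.slice_key_matches_slice_specs((('gender', 'f'),), specs)

# A key over a different feature matches neither spec.
assert not slicer.slice_key_matches_slice_specs((('age', 5),), specs)
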
60 changes: 45 additions & 15 deletions tensorflow_model_analysis/slicer/slicer_test.py
@@ -18,6 +18,7 @@
from __future__ import division
from __future__ import print_function

from absl.testing import parameterized
import apache_beam as beam
from apache_beam.testing import util
import numpy as np
@@ -75,7 +76,7 @@ def wrap_fpl(fpl):
}


class SlicerTest(testutil.TensorflowModelAnalysisTest):
class SlicerTest(testutil.TensorflowModelAnalysisTest, parameterized.TestCase):

  def setUp(self):
    super(SlicerTest, self).setUp()
@@ -431,7 +432,7 @@ def testSliceDefaultSlice(self):

def check_result(got):
try:
self.assertEqual(2, len(got), 'got: %s' % got)
self.assertLen(got, 2)
expected_result = [
((), wrap_fpl(fpls[0])),
((), wrap_fpl(fpls[1])),
@@ -460,16 +461,14 @@ def testSliceOneSlice(self):

def check_result(got):
try:
self.assertEqual(4, len(got), 'got: %s' % got)
self.assertLen(got, 4)
expected_result = [
((), wrap_fpl(fpls[0])),
((), wrap_fpl(fpls[1])),
((('gender', 'f'),), wrap_fpl(fpls[0])),
((('gender', 'm'),), wrap_fpl(fpls[1])),
]
self.assertCountEqual(
sorted(got, key=lambda x: x[0]),
sorted(expected_result, key=lambda x: x[0]))
self.assertCountEqual(got, expected_result)
except AssertionError as err:
raise util.BeamAssertException(err)

@@ -506,7 +505,7 @@ def testMultidimSlices(self):

def check_result(got):
try:
self.assertEqual(5, len(got), 'got: %s' % got)
self.assertLen(got, 5)
del data[0][constants.SLICE_KEY_TYPES_KEY]
del data[1][constants.SLICE_KEY_TYPES_KEY]
expected_result = [
@@ -516,9 +515,7 @@ def check_result(got):
((('gender', 'f'),), data[1]),
((('gender', 'm'),), data[1]),
]
self.assertCountEqual(
sorted(got, key=lambda x: x[0]),
sorted(expected_result, key=lambda x: x[0]))
self.assertCountEqual(got, expected_result)
except AssertionError as err:
raise util.BeamAssertException(err)

@@ -539,16 +536,14 @@ def testMultidimOverallSlices(self):

def check_result(got):
try:
self.assertEqual(2, len(got), 'got: %s' % got)
self.assertLen(got, 2)
del data[0][constants.SLICE_KEY_TYPES_KEY]
del data[1][constants.SLICE_KEY_TYPES_KEY]
expected_result = [
((), data[0]),
((), data[1]),
]
self.assertCountEqual(
sorted(got, key=lambda x: x[0]),
sorted(expected_result, key=lambda x: x[0]))
self.assertCountEqual(got, expected_result)
except AssertionError as err:
raise util.BeamAssertException(err)

@@ -568,7 +563,7 @@ def testFilterOutSlices(self):

def check_output(got):
try:
self.assertEqual(2, len(got), 'got: %s' % got)
self.assertLen(got, 2)
slices = {}
for (k, v) in got:
slices[k] = v
@@ -590,6 +585,41 @@ def check_output(got):
error_metric_key=metric_keys.ERROR_METRIC))
util.assert_that(output_dict, check_output)

  @parameterized.named_parameters(
      {
          'testcase_name': 'matching_single_spec',
          'slice_key': (('f1', 1),),
          'slice_specs': [slicer.SingleSliceSpec(features=[('f1', 1)])],
          'expected_result': True
      },
      {
          'testcase_name': 'non_matching_single_spec',
          'slice_key': (('f1', 1),),
          'slice_specs': [slicer.SingleSliceSpec(columns=['f2'])],
          'expected_result': False
      },
      {
          'testcase_name': 'matching_multiple_specs',
          'slice_key': (('f1', 1),),
          'slice_specs': [
              slicer.SingleSliceSpec(columns=['f1']),
              slicer.SingleSliceSpec(columns=['f2'])
          ],
          'expected_result': True
      },
      {
          'testcase_name': 'empty_specs',
          'slice_key': (('f1', 1),),
          'slice_specs': [],
          'expected_result': False
      },
  )
  def testSliceKeyMatchesSliceSpecs(self, slice_key, slice_specs,
                                    expected_result):
    self.assertEqual(
        expected_result,
        slicer.slice_key_matches_slice_specs(slice_key, slice_specs))


if __name__ == '__main__':
  tf.test.main()