From d186dd729c719f41ee6b8bd5f60197004ec0c617 Mon Sep 17 00:00:00 2001 From: embr Date: Fri, 11 Sep 2020 14:48:18 -0700 Subject: [PATCH] Add support for writing and reading metrics in parquet format. PiperOrigin-RevId: 331227203 --- RELEASE.md | 3 + .../slicer/slicer_lib.py | 21 ++ .../slicer/slicer_test.py | 60 +++-- .../metrics_plots_and_validations_writer.py | 222 +++++++++++++++--- ...trics_plots_and_validations_writer_test.py | 175 +++++++++++++- 5 files changed, 422 insertions(+), 59 deletions(-) diff --git a/RELEASE.md b/RELEASE.md index 583fa1d5d1..184ee507a2 100644 --- a/RELEASE.md +++ b/RELEASE.md @@ -4,6 +4,9 @@ ## Major Features and Improvements +* Added support for reading and writing metrics, plots and validation results + using Apache Parquet. + ## Bug fixes and other changes ## Breaking changes diff --git a/tensorflow_model_analysis/slicer/slicer_lib.py b/tensorflow_model_analysis/slicer/slicer_lib.py index 23d5b440e0..4fb6f96919 100644 --- a/tensorflow_model_analysis/slicer/slicer_lib.py +++ b/tensorflow_model_analysis/slicer/slicer_lib.py @@ -498,6 +498,27 @@ def _is_multi_dim_keys(slice_keys: SliceKeyType) -> bool: return False +def slice_key_matches_slice_specs( + slice_key: SliceKeyType, slice_specs: Iterable[SingleSliceSpec]) -> bool: + """Checks whether a slice key matches any slice spec. + + In this setting, a slice key matches a slice spec if it could have been + generated by that spec. + + Args: + slice_key: The slice key to check for applicability against slice specs. + slice_specs: Slice specs against which to check applicability of a slice + key. + + Returns: + True if the slice_key matches any slice specs, False otherwise. + """ + for slice_spec in slice_specs: + if slice_spec.is_slice_applicable(slice_key): + return True + return False + + @beam.typehints.with_input_types(types.Extracts) @beam.typehints.with_output_types(Tuple[SliceKeyType, types.Extracts]) class _FanoutSlicesDoFn(beam.DoFn): diff --git a/tensorflow_model_analysis/slicer/slicer_test.py b/tensorflow_model_analysis/slicer/slicer_test.py index fdc4df557a..48cc2454f3 100644 --- a/tensorflow_model_analysis/slicer/slicer_test.py +++ b/tensorflow_model_analysis/slicer/slicer_test.py @@ -18,6 +18,7 @@ from __future__ import division from __future__ import print_function +from absl.testing import parameterized import apache_beam as beam from apache_beam.testing import util import numpy as np @@ -75,7 +76,7 @@ def wrap_fpl(fpl): } -class SlicerTest(testutil.TensorflowModelAnalysisTest): +class SlicerTest(testutil.TensorflowModelAnalysisTest, parameterized.TestCase): def setUp(self): super(SlicerTest, self).setUp() @@ -431,7 +432,7 @@ def testSliceDefaultSlice(self): def check_result(got): try: - self.assertEqual(2, len(got), 'got: %s' % got) + self.assertLen(got, 2) expected_result = [ ((), wrap_fpl(fpls[0])), ((), wrap_fpl(fpls[1])), @@ -460,16 +461,14 @@ def testSliceOneSlice(self): def check_result(got): try: - self.assertEqual(4, len(got), 'got: %s' % got) + self.assertLen(got, 4) expected_result = [ ((), wrap_fpl(fpls[0])), ((), wrap_fpl(fpls[1])), ((('gender', 'f'),), wrap_fpl(fpls[0])), ((('gender', 'm'),), wrap_fpl(fpls[1])), ] - self.assertCountEqual( - sorted(got, key=lambda x: x[0]), - sorted(expected_result, key=lambda x: x[0])) + self.assertCountEqual(got, expected_result) except AssertionError as err: raise util.BeamAssertException(err) @@ -506,7 +505,7 @@ def testMultidimSlices(self): def check_result(got): try: - self.assertEqual(5, len(got), 'got: %s' % got) + self.assertLen(got, 5) 
del data[0][constants.SLICE_KEY_TYPES_KEY] del data[1][constants.SLICE_KEY_TYPES_KEY] expected_result = [ @@ -516,9 +515,7 @@ def check_result(got): ((('gender', 'f'),), data[1]), ((('gender', 'm'),), data[1]), ] - self.assertCountEqual( - sorted(got, key=lambda x: x[0]), - sorted(expected_result, key=lambda x: x[0])) + self.assertCountEqual(got, expected_result) except AssertionError as err: raise util.BeamAssertException(err) @@ -539,16 +536,14 @@ def testMultidimOverallSlices(self): def check_result(got): try: - self.assertEqual(2, len(got), 'got: %s' % got) + self.assertLen(got, 2) del data[0][constants.SLICE_KEY_TYPES_KEY] del data[1][constants.SLICE_KEY_TYPES_KEY] expected_result = [ ((), data[0]), ((), data[1]), ] - self.assertCountEqual( - sorted(got, key=lambda x: x[0]), - sorted(expected_result, key=lambda x: x[0])) + self.assertCountEqual(got, expected_result) except AssertionError as err: raise util.BeamAssertException(err) @@ -568,7 +563,7 @@ def testFilterOutSlices(self): def check_output(got): try: - self.assertEqual(2, len(got), 'got: %s' % got) + self.assertLen(got, 2) slices = {} for (k, v) in got: slices[k] = v @@ -590,6 +585,41 @@ def check_output(got): error_metric_key=metric_keys.ERROR_METRIC)) util.assert_that(output_dict, check_output) + @parameterized.named_parameters( + { + 'testcase_name': 'matching_single_spec', + 'slice_key': (('f1', 1),), + 'slice_specs': [slicer.SingleSliceSpec(features=[('f1', 1)])], + 'expected_result': True + }, + { + 'testcase_name': 'non_matching_single_spec', + 'slice_key': (('f1', 1),), + 'slice_specs': [slicer.SingleSliceSpec(columns=['f2'])], + 'expected_result': False + }, + { + 'testcase_name': 'matching_multiple_specs', + 'slice_key': (('f1', 1),), + 'slice_specs': [ + slicer.SingleSliceSpec(columns=['f1']), + slicer.SingleSliceSpec(columns=['f2']) + ], + 'expected_result': True + }, + { + 'testcase_name': 'empty_specs', + 'slice_key': (('f1', 1),), + 'slice_specs': [], + 'expected_result': False + }, + ) + def testSliceKeyMatchesSliceSpecs(self, slice_key, slice_specs, + expected_result): + self.assertEqual( + expected_result, + slicer.slice_key_matches_slice_specs(slice_key, slice_specs)) + if __name__ == '__main__': tf.test.main() diff --git a/tensorflow_model_analysis/writers/metrics_plots_and_validations_writer.py b/tensorflow_model_analysis/writers/metrics_plots_and_validations_writer.py index 1e20b5e380..7cb2d074f9 100644 --- a/tensorflow_model_analysis/writers/metrics_plots_and_validations_writer.py +++ b/tensorflow_model_analysis/writers/metrics_plots_and_validations_writer.py @@ -19,13 +19,15 @@ # Standard __future__ imports from __future__ import print_function +import itertools import os -from typing import Any, Dict, Iterator, List, Optional, Text, Tuple +from typing import Any, Dict, Iterable, Iterator, List, Optional, Text, Tuple, Union from absl import logging import apache_beam as beam import numpy as np +import pyarrow as pa import six import tensorflow as tf from tensorflow_model_analysis import config @@ -42,14 +44,81 @@ from tensorflow_model_analysis.writers import writer +_PARQUET_FORMAT = 'parquet' +_TFRECORD_FORMAT = 'tfrecord' +_SUPPORTED_FORMATS = (_PARQUET_FORMAT, _TFRECORD_FORMAT) +_SLICE_KEY_PARQUET_COLUMN_NAME = 'slice_key' +_SERIALIZED_VALUE_PARQUET_COLUMN_NAME = 'serialized_value' +_SINGLE_SLICE_KEYS_PARQUET_FIELD_NAME = 'single_slice_specs' +_SLICE_KEY_ARROW_TYPE = pa.struct([(pa.field( + _SINGLE_SLICE_KEYS_PARQUET_FIELD_NAME, + pa.list_( + pa.struct([ + pa.field('column', pa.string()), + 
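+            # Exactly one of the value fields below is populated per single
+            # slice key, mirroring the 'kind' oneof in
+            # metrics_for_slice_pb2.SliceKey.SingleSliceKey.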
pa.field('bytes_value', pa.binary()), + pa.field('float_value', pa.float32()), + pa.field('int64_value', pa.int64()) + ]))))]) +_SLICED_PARQUET_SCHEMA = pa.schema([ + pa.field(_SLICE_KEY_PARQUET_COLUMN_NAME, _SLICE_KEY_ARROW_TYPE), + pa.field(_SERIALIZED_VALUE_PARQUET_COLUMN_NAME, pa.binary()) +]) +_UNSLICED_PARQUET_SCHEMA = pa.schema( + [pa.field(_SERIALIZED_VALUE_PARQUET_COLUMN_NAME, pa.binary())]) + +_SliceKeyDictPythonType = Dict[Text, List[Dict[Text, Union[bytes, float, int]]]] + + def _match_all_files(file_path: Text) -> Text: """Return expression to match all files at given path.""" return file_path + '*' +def _parquet_column_iterator(paths: Iterable[str], + column_name: str) -> Iterator[pa.Buffer]: + """Yields values from a bytes column in a set of parquet file partitions.""" + dataset = pa.parquet.ParquetDataset(paths) + table = dataset.read(columns=[column_name]) + for record_batch in table.to_batches(): + # always read index 0 because we filter to one column + value_array = record_batch.column(0) + for value in value_array: + yield value.as_buffer() + + +def _raw_value_iterator( + paths: Iterable[Text], + output_file_format: Text) -> Iterator[Union[pa.Buffer, bytes]]: + """Returns an iterator of raw per-record values from supported file formats. + + When reading parquet format files, values from the column with name + _SERIALIZED_VALUE_PARQUET_COLUMN_NAME will be read. + + Args: + paths: The paths from which to read records + output_file_format: The format of the files from which to read records. + + Returns: + An iterator which yields serialized values. + + Raises: + ValueError when the output_file_format is unknown. + """ + if output_file_format == _PARQUET_FORMAT: + return _parquet_column_iterator(paths, + _SERIALIZED_VALUE_PARQUET_COLUMN_NAME) + elif not output_file_format or output_file_format == _TFRECORD_FORMAT: + return itertools.chain(*(tf.compat.v1.python_io.tf_record_iterator(path) + for path in paths)) + raise ValueError('Formats "{}" are currently supported but got ' + 'output_file_format={}'.format(_SUPPORTED_FORMATS, + output_file_format)) + + def load_and_deserialize_metrics( output_path: Text, - output_file_format: Text = '' + output_file_format: Text = '', + slice_specs: Optional[Iterable[slicer.SingleSliceSpec]] = None ) -> Iterator[metrics_for_slice_pb2.MetricsForSlice]: """Read and deserialize the MetricsForSlice records. @@ -57,6 +126,9 @@ def load_and_deserialize_metrics( output_path: Path or pattern to search for metrics files under. If a directory is passed, files matching 'metrics*' will be searched for. output_file_format: Optional file extension to filter files by. + slice_specs: A set of SingleSliceSpecs to use for filtering returned + metrics. The metrics for a given slice key will be returned if that slice + key matches any of the slice_specs. Yields: MetricsForSlice protos found in matching files. @@ -66,14 +138,19 @@ def load_and_deserialize_metrics( pattern = _match_all_files(output_path) if output_file_format: pattern = pattern + '.' 
+ output_file_format - for path in tf.io.gfile.glob(pattern): - for record in tf.compat.v1.python_io.tf_record_iterator(path): - yield metrics_for_slice_pb2.MetricsForSlice.FromString(record) + paths = tf.io.gfile.glob(pattern) + for value in _raw_value_iterator(paths, output_file_format): + metrics = metrics_for_slice_pb2.MetricsForSlice.FromString(value) + if slice_specs and not slicer.slice_key_matches_slice_specs( + slicer.deserialize_slice_key(metrics.slice_key), slice_specs): + continue + yield metrics def load_and_deserialize_plots( output_path: Text, - output_file_format: Text = '' + output_file_format: Text = '', + slice_specs: Optional[Iterable[slicer.SingleSliceSpec]] = None ) -> Iterator[metrics_for_slice_pb2.PlotsForSlice]: """Read and deserialize the PlotsForSlice records. @@ -81,6 +158,9 @@ def load_and_deserialize_plots( output_path: Path or pattern to search for plots files under. If a directory is passed, files matching 'plots*' will be searched for. output_file_format: Optional file extension to filter files by. + slice_specs: A set of SingleSliceSpecs to use for filtering returned plots. + The plots for a given slice key will be returned if that slice key matches + any of the slice_specs. Yields: PlotsForSlice protos found in matching files. @@ -90,9 +170,13 @@ def load_and_deserialize_plots( pattern = _match_all_files(output_path) if output_file_format: pattern = pattern + '.' + output_file_format - for path in tf.io.gfile.glob(pattern): - for record in tf.compat.v1.python_io.tf_record_iterator(path): - yield metrics_for_slice_pb2.PlotsForSlice.FromString(record) + paths = tf.io.gfile.glob(pattern) + for value in _raw_value_iterator(paths, output_file_format): + plots = metrics_for_slice_pb2.PlotsForSlice.FromString(value) + if slice_specs and not slicer.slice_key_matches_slice_specs( + slicer.deserialize_slice_key(plots.slice_key), slice_specs): + continue + yield plots def load_and_deserialize_validation_result( @@ -114,10 +198,10 @@ def load_and_deserialize_validation_result( if output_file_format: pattern = pattern + '.' + output_file_format validation_records = [] - for path in tf.io.gfile.glob(pattern): - for record in tf.compat.v1.python_io.tf_record_iterator(path): - validation_records.append( - validation_result_pb2.ValidationResult.FromString(record)) + paths = tf.io.gfile.glob(pattern) + for value in _raw_value_iterator(paths, output_file_format): + validation_records.append( + validation_result_pb2.ValidationResult.FromString(value)) assert len(validation_records) == 1 return validation_records[0] @@ -348,8 +432,15 @@ def MetricsPlotsAndValidationsWriter( # pylint: disable=invalid-name metrics_key: Name to use for metrics key in Evaluation output. plots_key: Name to use for plots key in Evaluation output. validations_key: Name to use for validations key in Evaluation output. - output_file_format: File format to use when saving files. Currently only - 'tfrecord' is supported. + output_file_format: File format to use when saving files. Currently + 'tfrecord' and 'parquet' are supported. If using parquet, the output + metrics and plots files will contain two columns: 'slice_key' and + 'serialized_value'. The 'slice_key' column will be a structured column + matching the metrics_for_slice_pb2.SliceKey proto. the 'serialized_value' + column will contain a serialized MetricsForSlice or PlotsForSlice + proto. The validation result file will contain a single column + 'serialized_value' which will contain a single serialized ValidationResult + proto. 
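+      Outputs written in either format can be read back with
+      load_and_deserialize_metrics, load_and_deserialize_plots and
+      load_and_deserialize_validation_result by passing the same
+      output_file_format.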
""" return writer.Writer( stage_name='WriteMetricsAndPlots', @@ -431,53 +522,110 @@ def _WriteMetricsPlotsAndValidations( # pylint: disable=invalid-name output_file_format: Text) -> beam.pvalue.PDone: """PTransform to write metrics and plots.""" - if output_file_format and output_file_format != 'tfrecord': - raise ValueError( - 'only "{}" format is currently supported: output_file_format={}'.format( - 'tfrecord', output_file_format)) - - if metrics_key in evaluation: + if output_file_format and output_file_format not in _SUPPORTED_FORMATS: + raise ValueError('only "{}" formats are currently supported but got ' + 'output_file_format={}'.format(_SUPPORTED_FORMATS, + output_file_format)) + + def convert_slice_key_to_parquet_dict( + slice_key: metrics_for_slice_pb2.SliceKey) -> _SliceKeyDictPythonType: + single_slice_key_dicts = [] + for single_slice_key in slice_key.single_slice_keys: + kind = single_slice_key.WhichOneof('kind') + if not kind: + continue + single_slice_key_dicts.append({kind: getattr(single_slice_key, kind)}) + return {_SINGLE_SLICE_KEYS_PARQUET_FIELD_NAME: single_slice_key_dicts} + + def convert_to_parquet_columns( + value: Union[metrics_for_slice_pb2.MetricsForSlice, + metrics_for_slice_pb2.PlotsForSlice] + ) -> Dict[Text, Union[_SliceKeyDictPythonType, bytes]]: + return { + _SLICE_KEY_PARQUET_COLUMN_NAME: + convert_slice_key_to_parquet_dict(value.slice_key), + _SERIALIZED_VALUE_PARQUET_COLUMN_NAME: + value.SerializeToString() + } + + if metrics_key in evaluation and constants.METRICS_KEY in output_paths: metrics = ( evaluation[metrics_key] | 'ConvertSliceMetricsToProto' >> beam.Map( convert_slice_metrics_to_proto, add_metrics_callbacks=add_metrics_callbacks)) - if constants.METRICS_KEY in output_paths: + file_path_prefix = output_paths[constants.METRICS_KEY] + if output_file_format == _PARQUET_FORMAT: + _ = ( + metrics + | 'ConvertToParquetColumns' >> beam.Map(convert_to_parquet_columns) + | 'WriteMetricsToParquet' >> beam.io.WriteToParquet( + file_path_prefix=file_path_prefix, + schema=_SLICED_PARQUET_SCHEMA, + file_name_suffix='.' + output_file_format)) + elif not output_file_format or output_file_format == _TFRECORD_FORMAT: _ = metrics | 'WriteMetrics' >> beam.io.WriteToTFRecord( - file_path_prefix=output_paths[constants.METRICS_KEY], + file_path_prefix=file_path_prefix, shard_name_template=None if output_file_format else '', file_name_suffix=('.' + output_file_format if output_file_format else ''), coder=beam.coders.ProtoCoder(metrics_for_slice_pb2.MetricsForSlice)) - if plots_key in evaluation: + if plots_key in evaluation and constants.PLOTS_KEY in output_paths: plots = ( evaluation[plots_key] | 'ConvertSlicePlotsToProto' >> beam.Map( convert_slice_plots_to_proto, add_metrics_callbacks=add_metrics_callbacks)) - if constants.PLOTS_KEY in output_paths: - _ = plots | 'WritePlots' >> beam.io.WriteToTFRecord( - file_path_prefix=output_paths[constants.PLOTS_KEY], + file_path_prefix = output_paths[constants.PLOTS_KEY] + if output_file_format == _PARQUET_FORMAT: + _ = ( + plots + | + 'ConvertPlotsToParquetColumns' >> beam.Map(convert_to_parquet_columns) + | 'WritePlotsToParquet' >> beam.io.WriteToParquet( + file_path_prefix=file_path_prefix, + schema=_SLICED_PARQUET_SCHEMA, + file_name_suffix='.' + output_file_format)) + elif not output_file_format or output_file_format == _TFRECORD_FORMAT: + _ = plots | 'WritePlotsToTFRecord' >> beam.io.WriteToTFRecord( + file_path_prefix=file_path_prefix, shard_name_template=None if output_file_format else '', file_name_suffix=('.' 
+ output_file_format if output_file_format else ''), coder=beam.coders.ProtoCoder(metrics_for_slice_pb2.PlotsForSlice)) - if validations_key in evaluation: + if (validations_key in evaluation and + constants.VALIDATIONS_KEY in output_paths): validations = ( evaluation[validations_key] | 'MergeValidationResults' >> beam.CombineGlobally( _CombineValidations(eval_config))) - if constants.VALIDATIONS_KEY in output_paths: - # We only use a single shard here because validations are usually single - # values. - _ = validations | 'WriteValidations' >> beam.io.WriteToTFRecord( - file_path_prefix=output_paths[constants.VALIDATIONS_KEY], - shard_name_template='', - file_name_suffix=('.' + - output_file_format if output_file_format else ''), - coder=beam.coders.ProtoCoder(validation_result_pb2.ValidationResult)) + file_path_prefix = output_paths[constants.VALIDATIONS_KEY] + # We only use a single shard here because validations are usually single + # values. Setting the shard_name_template to the empty string forces this. + shard_name_template = '' + if output_file_format == _PARQUET_FORMAT: + _ = ( + validations + | 'ConvertValidationsToParquetColumns' >> beam.Map( + lambda v: # pylint: disable=g-long-lambda + {_SERIALIZED_VALUE_PARQUET_COLUMN_NAME: v.SerializeToString()}) + | 'WriteValidationsToParquet' >> beam.io.WriteToParquet( + file_path_prefix=file_path_prefix, + shard_name_template=shard_name_template, + schema=_UNSLICED_PARQUET_SCHEMA, + file_name_suffix='.' + output_file_format)) + elif not output_file_format or output_file_format == _TFRECORD_FORMAT: + _ = ( + validations + | 'WriteValidationsToTFRecord' >> beam.io.WriteToTFRecord( + file_path_prefix=file_path_prefix, + shard_name_template=shard_name_template, + file_name_suffix=('.' + output_file_format + if output_file_format else ''), + coder=beam.coders.ProtoCoder( + validation_result_pb2.ValidationResult))) return beam.pvalue.PDone(list(evaluation.values())[0].pipeline) diff --git a/tensorflow_model_analysis/writers/metrics_plots_and_validations_writer_test.py b/tensorflow_model_analysis/writers/metrics_plots_and_validations_writer_test.py index 2fd2658999..b9668710f4 100644 --- a/tensorflow_model_analysis/writers/metrics_plots_and_validations_writer_test.py +++ b/tensorflow_model_analysis/writers/metrics_plots_and_validations_writer_test.py @@ -895,8 +895,11 @@ def testConvertSliceMetricsToProtoTensorValuedMetrics(self): (slice_key, slice_metrics), []) self.assertProtoEquals(expected_metrics_for_slice, got) - @parameterized.named_parameters(('without_output_file_format', ''), - ('with_output_file_format', 'tfrecord')) + _OUTPUT_FORMAT_PARAMS = [('without_output_file_format', ''), + ('tfrecord_file_format', 'tfrecord'), + ('parquet_file_format', 'parquet')] + + @parameterized.named_parameters(_OUTPUT_FORMAT_PARAMS) def testWriteValidationResults(self, output_file_format): model_dir, baseline_dir = self._getExportDir(), self._getBaselineDir() eval_shared_model = self._build_keras_model(model_dir, mul=0) @@ -1067,7 +1070,7 @@ def testWriteValidationResults(self, output_file_format): validation_result = ( metrics_plots_and_validations_writer .load_and_deserialize_validation_result( - os.path.dirname(validations_file))) + os.path.dirname(validations_file), output_file_format)) expected_validations = [ text_format.Parse( @@ -1158,8 +1161,7 @@ def testWriteValidationResults(self, output_file_format): self.assertCountEqual(expected_slicing_details, validation_result.validation_details.slicing_details) - 
@parameterized.named_parameters(('without_output_file_format', ''), - ('with_output_file_format', 'tfrecord')) + @parameterized.named_parameters(_OUTPUT_FORMAT_PARAMS) def testWriteMetricsAndPlots(self, output_file_format): metrics_file = os.path.join(self._getTempDir(), 'metrics') plots_file = os.path.join(self._getTempDir(), 'plots') @@ -1241,7 +1243,7 @@ def testWriteMetricsAndPlots(self, output_file_format): metric_records = list( metrics_plots_and_validations_writer.load_and_deserialize_metrics( - metrics_file)) + metrics_file, output_file_format)) self.assertLen(metric_records, 1, 'metrics: %s' % metric_records) self.assertProtoEquals(expected_metrics_for_slice, metric_records[0]) @@ -1296,7 +1298,166 @@ def testWriteMetricsAndPlots(self, output_file_format): plot_records = list( metrics_plots_and_validations_writer.load_and_deserialize_plots( - plots_file)) + plots_file, output_file_format)) + self.assertLen(plot_records, 1, 'plots: %s' % plot_records) + self.assertProtoEquals(expected_plots_for_slice, plot_records[0]) + + @parameterized.named_parameters(('parquet_file_format', 'parquet')) + def testLoadAndDeserializeFilteredMetricsAndPlots(self, output_file_format): + metrics_file = os.path.join(self._getTempDir(), 'metrics') + plots_file = os.path.join(self._getTempDir(), 'plots') + temp_eval_export_dir = os.path.join(self._getTempDir(), 'eval_export_dir') + + _, eval_export_dir = ( + fixed_prediction_estimator.simple_fixed_prediction_estimator( + None, temp_eval_export_dir)) + eval_config = config.EvalConfig( + model_specs=[config.ModelSpec()], + slicing_specs=[ + config.SlicingSpec(), + config.SlicingSpec(feature_keys=['prediction']) + ], + options=config.Options( + disabled_outputs={'values': ['eval_config.json']})) + eval_shared_model = self.createTestEvalSharedModel( + eval_saved_model_path=eval_export_dir, + add_metrics_callbacks=[ + post_export_metrics.example_count(), + post_export_metrics.calibration_plot_and_prediction_histogram( + num_buckets=2) + ]) + extractors = [ + predict_extractor.PredictExtractor(eval_shared_model), + slice_key_extractor.SliceKeyExtractor( + eval_config=eval_config, materialize=False) + ] + evaluators = [ + metrics_and_plots_evaluator.MetricsAndPlotsEvaluator(eval_shared_model) + ] + output_paths = { + constants.METRICS_KEY: metrics_file, + constants.PLOTS_KEY: plots_file + } + writers = [ + metrics_plots_and_validations_writer.MetricsPlotsAndValidationsWriter( + output_paths, + eval_config=eval_config, + add_metrics_callbacks=eval_shared_model.add_metrics_callbacks, + output_file_format=output_file_format) + ] + + with beam.Pipeline() as pipeline: + example1 = self._makeExample(prediction=0.0, label=1.0, country='US') + example2 = self._makeExample(prediction=1.0, label=1.0, country='CA') + + # pylint: disable=no-value-for-parameter + _ = ( + pipeline + | 'Create' >> beam.Create([ + example1.SerializeToString(), + example2.SerializeToString(), + ]) + | 'ExtractEvaluateAndWriteResults' >> + model_eval_lib.ExtractEvaluateAndWriteResults( + eval_config=eval_config, + eval_shared_model=eval_shared_model, + extractors=extractors, + evaluators=evaluators, + writers=writers)) + # pylint: enable=no-value-for-parameter + + # only read the metrics with slice keys that match the following spec + slice_keys_filter = [slicer.SingleSliceSpec(features=[('prediction', 0)])] + + expected_metrics_for_slice = text_format.Parse( + """ + slice_key { + single_slice_keys { + column: "prediction" + float_value: 0 + } + } + metrics { + key: "average_loss" + value { + 
double_value { + value: 1.0 + } + } + } + metrics { + key: "post_export_metrics/example_count" + value { + double_value { + value: 1.0 + } + } + } + """, metrics_for_slice_pb2.MetricsForSlice()) + + metric_records = list( + metrics_plots_and_validations_writer.load_and_deserialize_metrics( + metrics_file, output_file_format, slice_keys_filter)) + self.assertLen(metric_records, 1, 'metrics: %s' % metric_records) + self.assertProtoEquals(expected_metrics_for_slice, metric_records[0]) + + expected_plots_for_slice = text_format.Parse( + """ + slice_key { + single_slice_keys { + column: "prediction" + float_value: 0 + } + } + plots { + key: "post_export_metrics" + value { + calibration_histogram_buckets { + buckets { + lower_threshold_inclusive: -inf + num_weighted_examples {} + total_weighted_label {} + total_weighted_refined_prediction {} + } + buckets { + upper_threshold_exclusive: 0.5 + num_weighted_examples { + value: 1.0 + } + total_weighted_label { + value: 1.0 + } + total_weighted_refined_prediction {} + } + buckets { + lower_threshold_inclusive: 0.5 + upper_threshold_exclusive: 1.0 + num_weighted_examples { + } + total_weighted_label {} + total_weighted_refined_prediction {} + } + buckets { + lower_threshold_inclusive: 1.0 + upper_threshold_exclusive: inf + num_weighted_examples { + value: 0.0 + } + total_weighted_label { + value: 0.0 + } + total_weighted_refined_prediction { + value: 0.0 + } + } + } + } + } + """, metrics_for_slice_pb2.PlotsForSlice()) + + plot_records = list( + metrics_plots_and_validations_writer.load_and_deserialize_plots( + plots_file, output_file_format, slice_keys_filter)) self.assertLen(plot_records, 1, 'plots: %s' % plot_records) self.assertProtoEquals(expected_plots_for_slice, plot_records[0])
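
A minimal usage sketch of the read path added by this patch, assuming metrics
have already been written by MetricsPlotsAndValidationsWriter with
output_file_format='parquet'. The output directory and the slice spec below are
placeholders, and the `slicer_lib as slicer` import alias is the conventional
TFMA one; only the writer-module functions and their parameters are taken
directly from this patch.

import os

from tensorflow_model_analysis.slicer import slicer_lib as slicer
from tensorflow_model_analysis.writers import metrics_plots_and_validations_writer

# Placeholder: the prefix previously passed to MetricsPlotsAndValidationsWriter
# as output_paths[constants.METRICS_KEY], written with
# output_file_format='parquet'.
metrics_path = os.path.join('/tmp/eval_output', 'metrics')

# Only keep metrics for slice keys that could have been generated by a spec
# slicing on the 'prediction' feature (see slice_key_matches_slice_specs).
slice_specs = [slicer.SingleSliceSpec(columns=['prediction'])]

for metrics_for_slice in (
    metrics_plots_and_validations_writer.load_and_deserialize_metrics(
        metrics_path, output_file_format='parquet', slice_specs=slice_specs)):
  print(metrics_for_slice.slice_key)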