Skip to content

Commit

Permalink
Separate the PredictionsExtractor into two extractors.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 462483630
  • Loading branch information
tf-model-analysis-team authored and tfx-copybara committed Jul 21, 2022
1 parent eaed36b commit 47a6bda
Show file tree
Hide file tree
Showing 9 changed files with 249 additions and 151 deletions.
2 changes: 2 additions & 0 deletions RELEASE.md
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@

## Bug fixes and other Changes

* Separate the PredictionsExtractor into two extractors.

## Breaking Changes

## Deprecations
Expand Down
4 changes: 3 additions & 1 deletion tensorflow_model_analysis/api/model_eval_lib.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@
from tensorflow_model_analysis.extractors import features_extractor
from tensorflow_model_analysis.extractors import labels_extractor
from tensorflow_model_analysis.extractors import legacy_predict_extractor
from tensorflow_model_analysis.extractors import materialized_predictions_extractor
from tensorflow_model_analysis.extractors import predictions_extractor
from tensorflow_model_analysis.extractors import slice_key_extractor
from tensorflow_model_analysis.extractors import sql_slice_key_extractor
Expand Down Expand Up @@ -633,7 +634,8 @@ def default_extractors( # pylint: disable=invalid-name
labels_extractor.LabelsExtractor(eval_config=eval_config),
example_weights_extractor.ExampleWeightsExtractor(
eval_config=eval_config),
predictions_extractor.PredictionsExtractor(eval_config=eval_config)
materialized_predictions_extractor.MaterializedPredictionsExtractor(
eval_config),
] + slicing_extractors


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
from tensorflow_model_analysis.extractors import features_extractor
from tensorflow_model_analysis.extractors import labels_extractor
from tensorflow_model_analysis.extractors import legacy_predict_extractor
from tensorflow_model_analysis.extractors import materialized_predictions_extractor
from tensorflow_model_analysis.extractors import predictions_extractor
from tensorflow_model_analysis.extractors import slice_key_extractor
from tensorflow_model_analysis.extractors import unbatch_extractor
Expand Down Expand Up @@ -2784,7 +2785,8 @@ def testMetricsSpecsCountersInModelAgnosticMode(self):
features_extractor.FeaturesExtractor(eval_config),
labels_extractor.LabelsExtractor(eval_config),
example_weights_extractor.ExampleWeightsExtractor(eval_config),
predictions_extractor.PredictionsExtractor(eval_config),
materialized_predictions_extractor.MaterializedPredictionsExtractor(
eval_config),
unbatch_extractor.UnbatchExtractor(),
slice_key_extractor.SliceKeyExtractor(eval_config=eval_config)
]
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,78 @@
# Copyright 2020 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Batched materialized predictions extractor."""

import copy

import apache_beam as beam
from tensorflow_model_analysis import constants
from tensorflow_model_analysis import types
from tensorflow_model_analysis.extractors import extractor
from tensorflow_model_analysis.proto import config_pb2
from tensorflow_model_analysis.utils import model_util

# Stage name passed to the Extractor returned by
# MaterializedPredictionsExtractor.
_MATERIALIZED_PREDICTIONS_EXTRACTOR_STAGE_NAME = 'ExtractMaterializedPredictions'


def MaterializedPredictionsExtractor(
    eval_config: config_pb2.EvalConfig) -> extractor.Extractor:
  """Builds an extractor that rekeys predictions already present in extracts.

  No model is run. The wrapped PTransform reads the
  ModelSpec.prediction_key(s) from the eval config, looks up the matching
  prediction values that were stored as features under tfma.FEATURES_KEY, and
  writes them back into the extracts under tfma.PREDICTIONS_KEY.

  Args:
    eval_config: Eval config whose model specs name the prediction key(s).

  Returns:
    Extractor whose PTransform rekeys the preexisting predictions.
  """
  # pylint: disable=no-value-for-parameter
  rekey_ptransform = _ExtractMaterializedPredictions(eval_config=eval_config)
  return extractor.Extractor(
      stage_name=_MATERIALIZED_PREDICTIONS_EXTRACTOR_STAGE_NAME,
      ptransform=rekey_ptransform)


@beam.ptransform_fn
@beam.typehints.with_input_types(types.Extracts)
@beam.typehints.with_output_types(types.Extracts)
def _ExtractMaterializedPredictions(  # pylint: disable=invalid-name
    extracts: beam.pvalue.PCollection,
    eval_config: config_pb2.EvalConfig) -> beam.pvalue.PCollection:
  """A PTransform that copies materialized predictions to the predictions key.

  Args:
    extracts: PCollection of (batched) extracts that already carry the
      prediction values. They are presumably stored as features; the lookup
      itself is delegated to
      model_util.get_feature_values_for_model_spec_field.
    eval_config: Eval config; the 'prediction_key' / 'prediction_keys' fields
      of its model specs select which values to use.

  Returns:
    PCollection of Extracts where, whenever the lookup produced a value, that
    value is stored under constants.PREDICTIONS_KEY.
  """

  def rekey_predictions(  # pylint: disable=invalid-name
      batched_extracts: types.Extracts) -> types.Extracts:
    """Returns a shallow copy of the extracts with predictions rekeyed."""
    # Shallow copy so the incoming extracts dict is not mutated.
    rekeyed = copy.copy(batched_extracts)
    materialized = model_util.get_feature_values_for_model_spec_field(
        list(eval_config.model_specs), 'prediction_key', 'prediction_keys',
        rekeyed)
    # Nothing found for the configured keys: pass the copy through unchanged.
    if materialized is None:
      return rekeyed
    rekeyed[constants.PREDICTIONS_KEY] = materialized
    return rekeyed

  rekey_step = 'RekeyPredictions' >> beam.Map(rekey_predictions)
  return extracts | rekey_step
Original file line number Diff line number Diff line change
@@ -0,0 +1,129 @@
# Copyright 2020 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Test for batched materialized predictions extractor."""

import apache_beam as beam
from apache_beam.testing import util
import numpy as np
import tensorflow as tf
from tensorflow_model_analysis import constants
from tensorflow_model_analysis.api import model_eval_lib
from tensorflow_model_analysis.eval_saved_model import testutil
from tensorflow_model_analysis.extractors import features_extractor
from tensorflow_model_analysis.extractors import materialized_predictions_extractor
from tensorflow_model_analysis.proto import config_pb2
from tfx_bsl.tfxio import tensor_adapter
from tfx_bsl.tfxio import test_util

from google.protobuf import text_format
from tensorflow_metadata.proto.v0 import schema_pb2


class MaterializedPredictionsExtractorTest(
    testutil.TensorflowModelAnalysisTest):
  """Tests for the materialized predictions extractor."""

  def test_rekey_predictions_in_features(self):
    """Runs features + materialized predictions extraction on two examples.

    Verifies that the output extracts hold, under constants.PREDICTIONS_KEY,
    one entry per model name built from the prediction features named in the
    model specs.
    """
    # model1 names a single prediction feature; model2 maps two output names
    # to two different prediction features.
    model_spec1 = config_pb2.ModelSpec(
        name='model1', prediction_key='prediction')
    model_spec2 = config_pb2.ModelSpec(
        name='model2',
        prediction_keys={
            'output1': 'prediction1',
            'output2': 'prediction2'
        })
    eval_config = config_pb2.EvalConfig(model_specs=[model_spec1, model_spec2])
    # Schema declaring the three prediction features and one int feature; only
    # "fixed_int" is given an explicit tensor representation.
    schema = text_format.Parse(
        """
tensor_representation_group {
key: ""
value {
tensor_representation {
key: "fixed_int"
value {
dense_tensor {
column_name: "fixed_int"
}
}
}
}
}
feature {
name: "prediction"
type: FLOAT
}
feature {
name: "prediction1"
type: FLOAT
}
feature {
name: "prediction2"
type: FLOAT
}
feature {
name: "fixed_int"
type: INT
}
""", schema_pb2.Schema())
    tfx_io = test_util.InMemoryTFExampleRecord(
        schema=schema, raw_record_column_name=constants.ARROW_INPUT_COLUMN)
    tensor_adapter_config = tensor_adapter.TensorAdapterConfig(
        arrow_schema=tfx_io.ArrowSchema(),
        tensor_representations=tfx_io.TensorRepresentations())
    feature_extractor = features_extractor.FeaturesExtractor(
        eval_config=eval_config,
        tensor_representations=tensor_adapter_config.tensor_representations)
    prediction_extractor = materialized_predictions_extractor.MaterializedPredictionsExtractor(
        eval_config)

    # Two examples: 'prediction' and 'prediction1' are 1.0 in both, while
    # 'prediction2' is 0.0 in the first and 1.0 in the second.
    examples = [
        self._makeExample(
            prediction=1.0, prediction1=1.0, prediction2=0.0, fixed_int=1),
        self._makeExample(
            prediction=1.0, prediction1=1.0, prediction2=1.0, fixed_int=1)
    ]

    with beam.Pipeline() as pipeline:
      # pylint: disable=no-value-for-parameter
      result = (
          pipeline
          | 'Create' >> beam.Create([e.SerializeToString() for e in examples],
                                    reshuffle=False)
          | 'BatchExamples' >> tfx_io.BeamSource(batch_size=2)
          | 'InputsToExtracts' >> model_eval_lib.BatchedInputsToExtracts()
          | feature_extractor.stage_name >> feature_extractor.ptransform
          | prediction_extractor.stage_name >> prediction_extractor.ptransform)

      # pylint: enable=no-value-for-parameter

      def check_result(got):
        """Asserts on the single batched extract produced by the pipeline."""
        try:
          # Both examples are expected in one extract (batch_size=2 above).
          self.assertLen(got, 1)
          for model_name in ('model1', 'model2'):
            self.assertIn(model_name, got[0][constants.PREDICTIONS_KEY])
          # model1: values of the 'prediction' feature across the batch.
          self.assertAllClose(got[0][constants.PREDICTIONS_KEY]['model1'],
                              np.array([1.0, 1.0]))
          # model2: one array per output name, taken from 'prediction1' and
          # 'prediction2' respectively.
          self.assertAllClose(got[0][constants.PREDICTIONS_KEY]['model2'], {
              'output1': np.array([1.0, 1.0]),
              'output2': np.array([0.0, 1.0])
          })

        except AssertionError as err:
          # Re-raise as the exception type from apache_beam.testing.util.
          raise util.BeamAssertException(err)

      util.assert_that(result, check_result, label='result')


if __name__ == '__main__':
  # Enable TF v2 behavior first, then hand control to the TF test runner.
  tf.compat.v1.enable_v2_behavior()
  tf.test.main()
80 changes: 26 additions & 54 deletions tensorflow_model_analysis/extractors/predictions_extractor.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,11 +11,9 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Batched predict extractor."""
"""Batched predictions extractor."""

import copy

from typing import Dict, Optional
from typing import Dict

import apache_beam as beam
from tensorflow_model_analysis import constants
Expand All @@ -29,31 +27,19 @@

def PredictionsExtractor(
eval_config: config_pb2.EvalConfig,
eval_shared_model: Optional[types.MaybeMultipleEvalSharedModels] = None
eval_shared_model: types.MaybeMultipleEvalSharedModels
) -> extractor.Extractor:
"""Creates an extractor for performing predictions over a batch.
The extractor runs in two modes:
1) If one or more EvalSharedModels are provided
The extractor's PTransform loads and runs the serving saved_model(s) against
every extract yielding a copy of the incoming extracts with an additional
extract added for the predictions keyed by tfma.PREDICTIONS_KEY. The model
inputs are searched for under tfma.FEATURES_KEY (keras only) or tfma.INPUT_KEY
(if tfma.FEATURES_KEY is not set or the model is non-keras). If multiple
models are used the predictions will be stored in a dict keyed by model name.
2) If no EvalSharedModels are provided
The extractor's PTransform uses the config's ModelSpec.prediction_key(s)
to lookup the associated prediction values stored as features under the
tfma.FEATURES_KEY in extracts. The resulting values are then added to the
extracts under the key tfma.PREDICTIONS_KEY.
Note that the use of a prediction_key in the ModelSpecs serve two use cases:
(a) as a key into the dict of predictions output (option 1)
(b) as the key for a pre-computed prediction stored as a feature (option 2)
Note that the prediction_key in the ModelSpecs also serves as a key into the
dict of the prediction's output.
Args:
eval_config: Eval config.
Expand All @@ -66,61 +52,47 @@ def PredictionsExtractor(
"""
eval_shared_models = model_util.verify_and_update_eval_shared_models(
eval_shared_model)
if eval_shared_models:
eval_shared_models = {m.model_name: m for m in eval_shared_models}
if not eval_shared_models:
raise ValueError('No valid model(s) were provided. Please ensure that '
'EvalConfig.ModelSpec is correctly configured to enable '
'using the PredictionsExtractor.')

# pylint: disable=no-value-for-parameter
return extractor.Extractor(
stage_name=_PREDICTIONS_EXTRACTOR_STAGE_NAME,
ptransform=_ExtractPredictions(
eval_config=eval_config, eval_shared_models=eval_shared_models))
eval_config=eval_config,
eval_shared_models={m.model_name: m for m in eval_shared_models}))


@beam.ptransform_fn
@beam.typehints.with_input_types(types.Extracts)
@beam.typehints.with_output_types(types.Extracts)
def _ExtractPredictions( # pylint: disable=invalid-name
extracts: beam.pvalue.PCollection, eval_config: config_pb2.EvalConfig,
eval_shared_models: Optional[Dict[str, types.EvalSharedModel]]
) -> beam.pvalue.PCollection:
eval_shared_models: Dict[str,
types.EvalSharedModel]) -> beam.pvalue.PCollection:
"""A PTransform that adds predictions and possibly other tensors to extracts.
Args:
extracts: PCollection of extracts containing model inputs keyed by
tfma.FEATURES_KEY (if model inputs are named) or tfma.INPUTS_KEY (if model
takes raw tf.Examples as input).
eval_config: Eval config.
eval_shared_models: Shared model parameters keyed by model name or None.
eval_shared_models: Shared model parameters keyed by model name.
Returns:
PCollection of Extracts updated with the predictions.
"""

if eval_shared_models:
signature_names = {}
for spec in eval_config.model_specs:
model_name = '' if len(eval_config.model_specs) == 1 else spec.name
signature_names[model_name] = [spec.signature_name]

return (
extracts
| 'Predict' >> beam.ParDo(
model_util.ModelSignaturesDoFn(
eval_config=eval_config,
eval_shared_models=eval_shared_models,
signature_names={constants.PREDICTIONS_KEY: signature_names},
prefer_dict_outputs=False)))
else:

def extract_predictions( # pylint: disable=invalid-name
batched_extracts: types.Extracts) -> types.Extracts:
"""Extract predictions from extracts containing features."""
result = copy.copy(batched_extracts)
predictions = model_util.get_feature_values_for_model_spec_field(
list(eval_config.model_specs), 'prediction_key', 'prediction_keys',
result)
if predictions is not None:
result[constants.PREDICTIONS_KEY] = predictions
return result

return extracts | 'ExtractPredictions' >> beam.Map(extract_predictions)
signature_names = {}
for spec in eval_config.model_specs:
model_name = '' if len(eval_config.model_specs) == 1 else spec.name
signature_names[model_name] = [spec.signature_name]

return (extracts
| 'Inference' >> beam.ParDo(
model_util.ModelSignaturesDoFn(
eval_config=eval_config,
eval_shared_models=eval_shared_models,
signature_names={constants.PREDICTIONS_KEY: signature_names},
prefer_dict_outputs=False)))
Loading

0 comments on commit 47a6bda

Please sign in to comment.