Skip to content

Commit

Permalink
Project import generated by Copybara.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 212486359
  • Loading branch information
tf-model-analysis-team authored and xinzha623 committed Sep 11, 2018
1 parent 556302d commit 1ea7b29
Show file tree
Hide file tree
Showing 2 changed files with 54 additions and 16 deletions.
65 changes: 49 additions & 16 deletions tensorflow_model_analysis/api/impl/evaluate.py
Original file line number Diff line number Diff line change
Expand Up @@ -288,13 +288,39 @@ def _ExtractOutput( # pylint: disable=invalid-name
main=_ExtractOutputDoFn.OUTPUT_TAG_METRICS)


def PredictExtractor(eval_saved_model_path, add_metrics_callbacks,
                     shared_handle, desired_batch_size):
  """Creates the standard 'Predict' extraction stage.

  The returned Extractor loads and runs the EvalSavedModel against every
  example, yielding a types.ExampleAndExtracts carrying a
  FeaturesPredictionsLabels value under the extracts key 'fpl'.

  Args:
    eval_saved_model_path: Path to the EvalSavedModel directory.
    add_metrics_callbacks: Optional list of callbacks that add extra metrics
      to the model graph.
    shared_handle: shared.Shared handle used to share the loaded model
      across workers.
    desired_batch_size: Optional batch size for batched prediction.

  Returns:
    A types.Extractor wrapping the TFMAPredict PTransform.
  """
  predict_transform = predict_extractor.TFMAPredict(
      eval_saved_model_path=eval_saved_model_path,
      add_metrics_callbacks=add_metrics_callbacks,
      shared_handle=shared_handle,
      desired_batch_size=desired_batch_size)
  return types.Extractor(stage_name='Predict', ptransform=predict_transform)


@beam.ptransform_fn
def Extract(examples, extractors):
  """Applies each extractor's PTransform to the collection, in order.

  Args:
    examples: PCollection of types.ExampleAndExtracts to augment.
    extractors: Iterable of types.Extractor; each one's ptransform is applied
      serially, labelled with its stage_name.

  Returns:
    The PCollection produced by the final extractor.
  """
  output = examples
  for stage in extractors:
    output = output | stage.stage_name >> stage.ptransform
  return output


@beam.ptransform_fn
# No typehint for output type, since it's a multi-output DoFn result that
# Beam doesn't support typehints for yet (BEAM-3280).
def Evaluate(
# pylint: disable=invalid-name
examples,
eval_saved_model_path,
extractors = None,
add_metrics_callbacks = None,
slice_spec = None,
desired_batch_size = None,
Expand All @@ -309,6 +335,8 @@ def Evaluate(
(e.g. string containing CSV row, TensorFlow.Example, etc).
eval_saved_model_path: Path to EvalSavedModel. This directory should contain
the saved_model.pb file.
extractors: Optional list of Extractors to execute prior to slicing and
aggregating the metrics. If not provided, a default set will be run.
add_metrics_callbacks: Optional list of callbacks for adding additional
metrics to the graph. The names of the metrics added by the callbacks
should not conflict with existing metrics, or metrics added by other
Expand Down Expand Up @@ -349,24 +377,22 @@ def add_metrics_callback(features_dict, predictions_dict, labels):

shared_handle = shared.Shared()

if not extractors:
extractors = [
PredictExtractor(eval_saved_model_path, add_metrics_callbacks,
shared_handle, desired_batch_size),
]

# pylint: disable=no-value-for-parameter
return (
examples
# Our diagnostic outputs, pass types.ExampleAndExtracts throughout,
# however our aggregating functions do not use this interface.
| 'ToExampleAndExtracts' >>
beam.Map(lambda x: types.ExampleAndExtracts(example=x, extracts={}))
| Extract(extractors=extractors)

# Map function which loads and runs the eval_saved_model against every
# example, yielding an types.ExampleAndExtracts containing a
# FeaturesPredictionsLabels value (where key is 'fpl').
| 'Predict' >> predict_extractor.TFMAPredict(
eval_saved_model_path=eval_saved_model_path,
add_metrics_callbacks=add_metrics_callbacks,
shared_handle=shared_handle,
desired_batch_size=desired_batch_size)

# Input: one example fpl at a time
# Input: one example at a time
# Output: one fpl example per slice key (notice that the example turns
# into n, replicated once per applicable slice key)
| 'Slice' >> slice_api.Slice(slice_spec)
Expand Down Expand Up @@ -395,6 +421,7 @@ def BuildDiagnosticTable(
# pylint: disable=invalid-name
examples,
eval_saved_model_path,
extractors = None,
desired_batch_size = None):
"""Build diagnostics for the specified EvalSavedModel and example collection.
Expand All @@ -403,18 +430,24 @@ def BuildDiagnosticTable(
(e.g. string containing CSV row, TensorFlow.Example, etc).
eval_saved_model_path: Path to EvalSavedModel. This directory should contain
the saved_model.pb file.
extractors: Optional list of Extractors to execute prior to slicing and
aggregating the metrics. If not provided, a default set will be run.
desired_batch_size: Optional batch size for batching in Predict and
Aggregate.
Returns:
PCollection of ExampleAndExtracts
"""

if not extractors:
extractors = [
PredictExtractor(eval_saved_model_path, None, shared.Shared(),
desired_batch_size),
types.Extractor(
stage_name='ExtractFeatures',
ptransform=feature_extractor.ExtractFeatures()),
]
return (examples
| 'ToExampleAndExtracts' >>
beam.Map(lambda x: types.ExampleAndExtracts(example=x, extracts={}))
| 'Predict' >> predict_extractor.TFMAPredict(
eval_saved_model_path,
add_metrics_callbacks=None,
shared_handle=shared.Shared(),
desired_batch_size=desired_batch_size)
| 'ExtractFeatures' >> feature_extractor.ExtractFeatures())
| Extract(extractors=extractors))
5 changes: 5 additions & 0 deletions tensorflow_model_analysis/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@

import copy

import apache_beam as beam
import numpy as np
import tensorflow as tf

Expand Down Expand Up @@ -66,6 +67,10 @@ def is_tensor(obj):
DictOfExtractedValues = Dict[Text, Any]


# A single extraction stage: `stage_name` labels the Beam step and
# `ptransform` is the PTransform applied at that step (see Extract()).
# NOTE(review): stage_name is declared bytes but callers pass str labels
# (e.g. 'Predict') — equivalent under Python 2; confirm if ported to Python 3.
Extractor = NamedTuple('Extractor', [('stage_name', bytes),
                                     ('ptransform', beam.PTransform)])


class ExampleAndExtracts(
NamedTuple('ExampleAndExtracts', [('example', bytes),
('extracts', DictOfExtractedValues)])):
Expand Down

0 comments on commit 1ea7b29

Please sign in to comment.