Part IV of batched tensor updates.
Update Keras utils to use already parsed features and transformed features instead of re-parsing the inputs.

PiperOrigin-RevId: 441952555
mdreves authored and tfx-copybara committed Apr 15, 2022
1 parent 2b0b6da commit f88dd53
Showing 5 changed files with 84 additions and 122 deletions.
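The gist of the change: metric computations for Keras saved models now receive already parsed feature dicts (and, when present, transformed features) from the extracts, instead of serialized examples that had to be decoded a second time. A minimal sketch of the new input shape, modeled on the updated tests below; the numpy values and the 'output' feature name are illustrative, not part of the API:

import numpy as np
from tensorflow_model_analysis.metrics import metric_types

# Parsed features are passed directly; no serialized tf.Example to re-decode.
metric_input = metric_types.StandardMetricInputs(
    labels=np.array([0.0]),
    predictions=np.array([1.0]),
    example_weights=np.array([0.5]),
    features={'output': np.array([1.0])})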
13 changes: 3 additions & 10 deletions tensorflow_model_analysis/api/model_eval_lib.py
@@ -645,9 +645,7 @@ def default_evaluators( # pylint: disable=invalid-name
min_slice_size: int = 1,
serialize: bool = False,
random_seed_for_testing: Optional[int] = None,
config_version: Optional[int] = None,
tensor_adapter_config: Optional[tensor_adapter.TensorAdapterConfig] = None
) -> List[evaluator.Evaluator]:
config_version: Optional[int] = None) -> List[evaluator.Evaluator]:
"""Returns the default evaluators for use in ExtractAndEvaluate.
Args:
@@ -664,9 +662,6 @@ def default_evaluators( # pylint: disable=invalid-name
be explicitly set by users. It is only intended to be used in cases where
the provided eval_config was generated internally, and thus not a reliable
indicator of user intent.
tensor_adapter_config: Tensor adapter config which specifies how to obtain
tensors from the Arrow RecordBatch. If None, an attempt will be made to
create the tensors using default TensorRepresentations.
"""
disabled_outputs = []
if eval_config:
@@ -717,8 +712,7 @@ def default_evaluators( # pylint: disable=invalid-name
eval_config=eval_config,
eval_shared_model=eval_shared_model,
schema=schema,
random_seed_for_testing=random_seed_for_testing,
tensor_adapter_config=tensor_adapter_config)
random_seed_for_testing=random_seed_for_testing)
]


@@ -1141,8 +1135,7 @@ def ExtractEvaluateAndWriteResults( # pylint: disable=invalid-name
eval_shared_model=eval_shared_model,
random_seed_for_testing=random_seed_for_testing,
schema=schema,
config_version=config_version,
tensor_adapter_config=tensor_adapter_config)
config_version=config_version)

for v in evaluators:
evaluator.verify_evaluator(v, extractors)
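For callers, the visible effect of the model_eval_lib.py change is that default_evaluators (and, further down, ExtractEvaluateAndWriteResults) no longer accept a tensor_adapter_config argument. A hedged call-site sketch using the public tfma API; the model path and EvalConfig contents are placeholders:

import tensorflow_model_analysis as tfma

eval_config = tfma.EvalConfig(model_specs=[tfma.ModelSpec()])
eval_shared_model = tfma.default_eval_shared_model(
    eval_saved_model_path='/path/to/keras/saved_model',  # placeholder path
    eval_config=eval_config)
# No tensor_adapter_config keyword after this change.
evaluators = tfma.default_evaluators(
    eval_config=eval_config,
    eval_shared_model=eval_shared_model)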
72 changes: 26 additions & 46 deletions tensorflow_model_analysis/evaluators/keras_util.py
@@ -29,15 +29,12 @@
from tensorflow_model_analysis.proto import config_pb2
from tensorflow_model_analysis.utils import model_util
from tensorflow_model_analysis.utils import util
from tfx_bsl.coders import example_coder
from tfx_bsl.tfxio import tensor_adapter


def metric_computations_using_keras_saved_model(
model_name: str,
model_loader: types.ModelLoader,
eval_config: Optional[config_pb2.EvalConfig],
tensor_adapter_config: Optional[tensor_adapter.TensorAdapterConfig] = None,
batch_size: Optional[int] = None) -> metric_types.MetricComputations:
"""Returns computations for computing metrics natively using keras.
@@ -46,8 +43,6 @@ def metric_computations_using_keras_saved_model(
model_loader: Loader for shared model containing keras saved model to use
for metric computations.
eval_config: Eval config.
tensor_adapter_config: Tensor adapter config which specifies how to obtain
tensors from the Arrow RecordBatch.
batch_size: Batch size to use during evaluation (testing only).
"""
model = model_loader.load()
@@ -82,16 +77,19 @@
else:
output_names = []
keys = _metric_keys(model.metrics, model_name, output_names)
specs = model_util.get_input_specs(model_name, signature_name=None)
feature_keys = list(specs.keys()) if specs else []
return [
metric_types.MetricComputation(
keys=keys,
# TODO(b/178158073): By using inputs instead of batched features we
# incur the cost of having to parse the inputs a second time. In
# addition, transformed features (i.e. TFT, KPL) are not supported.
preprocessor=metric_types.InputPreprocessor(),
preprocessor=metric_types.StandardMetricInputsPreprocessorList([
metric_types.FeaturePreprocessor(
feature_keys=feature_keys, model_names=[model_name]),
metric_types.TransformedFeaturePreprocessor(
feature_keys=feature_keys, model_names=[model_name])
]),
combiner=_KerasEvaluateCombiner(keys, model_name, model_loader,
eval_config, tensor_adapter_config,
batch_size))
eval_config, batch_size))
]


@@ -372,25 +370,9 @@ def __init__(self,
model_name: str,
model_loader: types.ModelLoader,
eval_config: Optional[config_pb2.EvalConfig],
tensor_adapter_config: Optional[
tensor_adapter.TensorAdapterConfig] = None,
desired_batch_size: Optional[int] = None):
super().__init__(keys, model_name, model_loader, eval_config,
desired_batch_size, 'keras_evaluate_combine')
self._tensor_adapter_config = tensor_adapter_config
self._tensor_adapter = None
self._decoder = None

def setup(self):
super().setup()
# TODO(b/180125126): Re-enable use of passed in TensorAdapter after bug
# requiring matching schema's is fixed.
# if self._tensor_adapter is None and
# self._tensor_adapter_config is not None:
# self._tensor_adapter = tensor_adapter.TensorAdapter(
# self._tensor_adapter_config)
if self._decoder is None:
self._decoder = example_coder.ExamplesToRecordBatchDecoder()

def _metrics(self) -> Iterable[tf.keras.metrics.Metric]:
return self._model.metrics
@@ -424,20 +406,29 @@ def _add_input(
output_name,
flatten=False,
example_weighted=True))
accumulator.add_input(i, element.inputs if i == 0 else None, labels,
example_weights)

if i == 0:
if element.transformed_features:
features = {}
features.update(element.features)
features.update(element.transformed_features)
else:
features = element.features
else:
features = None
accumulator.add_input(i, features, labels, example_weights)

return accumulator

def _update_state(self,
accumulator: tf_metric_accumulators.TFMetricsAccumulator):
serialized_examples = None
features = {}
labels = {}
example_weights = {}
for i, output_name in enumerate(self._output_names):
e, l, w = accumulator.get_inputs(i)
f, l, w = accumulator.get_inputs(i)
if i == 0:
serialized_examples = e
features = util.merge_extracts(f)
if not output_name and len(self._output_names) > 1:
# The empty output_name for multi-output models is not used for inputs.
continue
@@ -451,25 +442,14 @@ def _update_state(self,
# Single-output models don't use dicts.
labels = next(iter(labels.values()))
example_weights = next(iter(example_weights.values()))
# TODO(b/178158073): Remove record batch and use features directly.
# Serialized examples may be encoded as [np.array(['...']), ...], this will
# convert them to a list of strings.
record_batch = self._decoder.DecodeBatch(
np.array(serialized_examples, dtype=object).squeeze().tolist())
tensor_representations = None
if self._tensor_adapter:
tensor_representations = self._tensor_adapter.tensor_representations
tensor_values = util.record_batch_to_tensor_values(record_batch,
tensor_representations)
input_specs = model_util.get_input_specs(self._model, signature_name=None)
inputs = model_util.get_inputs(tensor_values, input_specs)
inputs = model_util.get_inputs(features, input_specs)
if inputs is None:
raise ValueError('unable to prepare inputs for evaluation: '
'input_specs={}, record_batch={}'.format(
input_specs, record_batch))
f'input_specs={input_specs}, features={features}')
self._model.evaluate(
x=inputs,
y=labels,
batch_size=record_batch.num_rows,
batch_size=util.batch_size(features),
verbose=0,
sample_weight=example_weights)
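The core of the keras_util.py change is the preprocessor swap: instead of an InputPreprocessor that forwards serialized examples for re-parsing, the computation now requests the parsed and transformed features named by the model's input specs. A standalone sketch of that composition; the helper name and the explicit feature_keys parameter are illustrative (the library derives the keys from model_util.get_input_specs as shown above):

from tensorflow_model_analysis.metrics import metric_types

def feature_based_preprocessor(feature_keys, model_name):
  # Pull parsed features and transformed (e.g. TFT/Keras preprocessing layer)
  # features from the extracts by key instead of re-decoding serialized inputs.
  return metric_types.StandardMetricInputsPreprocessorList([
      metric_types.FeaturePreprocessor(
          feature_keys=feature_keys, model_names=[model_name]),
      metric_types.TransformedFeaturePreprocessor(
          feature_keys=feature_keys, model_names=[model_name]),
  ])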
76 changes: 47 additions & 29 deletions tensorflow_model_analysis/evaluators/keras_util_test.py
@@ -185,17 +185,17 @@ def testWithBinaryClassification(self, sequential_model, add_custom_metrics):
labels=np.array([0.0]),
predictions=np.array([1.0]),
example_weights=np.array([0.5]),
input=self._makeExample(output=1.0).SerializeToString()),
features={'output': np.array([1.0])}),
metric_types.StandardMetricInputs(
labels=np.array([1.0]),
predictions=np.array([0.7]),
example_weights=np.array([0.7]),
input=self._makeExample(output=0.7).SerializeToString()),
features={'output': np.array([0.7])}),
metric_types.StandardMetricInputs(
labels=np.array([0.0]),
predictions=np.array([0.5]),
example_weights=np.array([0.9]),
input=self._makeExample(output=0.5).SerializeToString())
features={'output': np.array([0.5])})
]

expected_values = {
@@ -282,8 +282,10 @@ def testWithBinaryClassificationMultiOutput(self, add_custom_metrics):
'output_1': np.array([0.5]),
'output_2': np.array([0.1]),
},
input=self._makeExample(output_1=1.0,
output_2=0.0).SerializeToString()),
features={
'output_1': np.array([1.0]),
'output_2': np.array([0.0])
}),
metric_types.StandardMetricInputs(
labels={
'output_1': np.array([1.0]),
@@ -297,8 +299,10 @@ def testWithBinaryClassificationMultiOutput(self, add_custom_metrics):
'output_1': np.array([0.7]),
'output_2': np.array([0.4]),
},
input=self._makeExample(output_1=0.7,
output_2=0.3).SerializeToString()),
features={
'output_1': np.array([0.7]),
'output_2': np.array([0.3])
}),
metric_types.StandardMetricInputs(
labels={
'output_1': np.array([0.0]),
@@ -312,8 +316,10 @@ def testWithBinaryClassificationMultiOutput(self, add_custom_metrics):
'output_1': np.array([0.9]),
'output_2': np.array([0.7]),
},
input=self._makeExample(output_1=0.5,
output_2=0.8).SerializeToString())
features={
'output_1': np.array([0.5]),
'output_2': np.array([0.8])
})
]

expected_values = {
@@ -422,26 +428,30 @@ def testWithMultiClassClassification(self, sequential_model,
labels=np.array([0, 0, 1, 0, 0]),
predictions=np.array([0.1, 0.2, 0.1, 0.25, 0.35]),
example_weights=np.array([0.5]),
input=self._makeExample(
output=[0.1, 0.2, 0.1, 0.25, 0.35]).SerializeToString()),
features={
'output': np.array([0.1, 0.2, 0.1, 0.25, 0.35]),
}),
metric_types.StandardMetricInputs(
labels=np.array([0, 1, 0, 0, 0]),
predictions=np.array([0.2, 0.3, 0.05, 0.15, 0.3]),
example_weights=np.array([0.7]),
input=self._makeExample(
output=[0.2, 0.3, 0.05, 0.15, 0.3]).SerializeToString()),
features={
'output': np.array([0.2, 0.3, 0.05, 0.15, 0.3]),
}),
metric_types.StandardMetricInputs(
labels=np.array([0, 0, 0, 1, 0]),
predictions=np.array([0.01, 0.2, 0.09, 0.5, 0.2]),
example_weights=np.array([0.9]),
input=self._makeExample(
output=[0.01, 0.2, 0.09, 0.5, 0.2]).SerializeToString()),
features={
'output': np.array([0.01, 0.2, 0.09, 0.5, 0.2]),
}),
metric_types.StandardMetricInputs(
labels=np.array([0, 1, 0, 0, 0]),
predictions=np.array([0.3, 0.2, 0.05, 0.4, 0.05]),
example_weights=np.array([0.3]),
input=self._makeExample(
output=[0.3, 0.2, 0.05, 0.4, 0.05]).SerializeToString())
features={
'output': np.array([0.3, 0.2, 0.05, 0.4, 0.05]),
})
]

# Unweighted:
@@ -536,9 +546,14 @@ def testWithMultiClassClassificationMultiOutput(self, add_custom_metrics):
'output_1': np.array([0.5]),
'output_2': np.array([0.5]),
},
input=self._makeExample(
output_1=[0.1, 0.2, 0.1, 0.25, 0.35],
output_2=[0.1, 0.2, 0.1, 0.25, 0.35]).SerializeToString()),
features={
'output_1': np.array([0.0, 0.0, 0.0, 0.0, 0.0]),
'output_2': np.array([0.0, 0.0, 0.0, 0.0, 0.0])
},
transformed_features={
'output_1': np.array([0.1, 0.2, 0.1, 0.25, 0.35]),
'output_2': np.array([0.1, 0.2, 0.1, 0.25, 0.35])
}),
metric_types.StandardMetricInputs(
labels={
'output_1': np.array([0, 1, 0, 0, 0]),
@@ -552,9 +567,10 @@ def testWithMultiClassClassificationMultiOutput(self, add_custom_metrics):
'output_1': np.array([0.7]),
'output_2': np.array([0.7]),
},
input=self._makeExample(
output_1=[0.2, 0.3, 0.05, 0.15, 0.3],
output_2=[0.2, 0.3, 0.05, 0.15, 0.3]).SerializeToString()),
features={
'output_1': np.array([0.2, 0.3, 0.05, 0.15, 0.3]),
'output_2': np.array([0.2, 0.3, 0.05, 0.15, 0.3])
}),
metric_types.StandardMetricInputs(
labels={
'output_1': np.array([0, 0, 0, 1, 0]),
@@ -568,9 +584,10 @@ def testWithMultiClassClassificationMultiOutput(self, add_custom_metrics):
'output_1': np.array([0.9]),
'output_2': np.array([0.9]),
},
input=self._makeExample(
output_1=[0.01, 0.2, 0.09, 0.5, 0.2],
output_2=[0.01, 0.2, 0.09, 0.5, 0.2]).SerializeToString()),
features={
'output_1': np.array([0.01, 0.2, 0.09, 0.5, 0.2]),
'output_2': np.array([0.01, 0.2, 0.09, 0.5, 0.2])
}),
metric_types.StandardMetricInputs(
labels={
'output_1': np.array([0, 1, 0, 0, 0]),
@@ -584,9 +601,10 @@ def testWithMultiClassClassificationMultiOutput(self, add_custom_metrics):
'output_1': np.array([0.3]),
'output_2': np.array([0.3]),
},
input=self._makeExample(
output_1=[0.3, 0.2, 0.05, 0.4, 0.05],
output_2=[0.3, 0.2, 0.05, 0.4, 0.05]).SerializeToString())
features={
'output_1': np.array([0.3, 0.2, 0.05, 0.4, 0.05]),
'output_2': np.array([0.3, 0.2, 0.05, 0.4, 0.05])
})
]

# Unweighted:
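One detail the multi-output test above exercises: when both features and transformed_features are present, _add_input merges them with the transformed values taking precedence, so the model evaluates on the transformed inputs. A small illustration with hypothetical values:

import numpy as np

features = {'output_1': np.array([0.0, 0.0, 0.0, 0.0, 0.0])}
transformed_features = {'output_1': np.array([0.1, 0.2, 0.1, 0.25, 0.35])}

# Mirrors the merge in _add_input: the later update() lets transformed values
# overwrite raw features that share a key.
merged = dict(features)
merged.update(transformed_features)
assert np.allclose(merged['output_1'], transformed_features['output_1'])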