Commit

Internal change
PiperOrigin-RevId: 501615454
tf-model-analysis-team authored and tfx-copybara committed Jan 12, 2023
1 parent f990553 commit 1c04726
Showing 3 changed files with 128 additions and 14 deletions.
@@ -11,7 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Predictions extractor for using Tfx-Bsl Bulk Inference."""
"""Predictions extractor for using TFX-BSL Bulk Inference."""

import copy
from typing import Dict, List, Optional, Tuple, TypeVar, Union
@@ -46,22 +46,27 @@ def __init__(self,
"""Converts TFMA config into library-specific configuration.
Args:
model_specs: TFMA ModelSpec config to be translated to ServoBeam Config.
model_specs: TFMA ModelSpec config to be translated to TFX-BSL Config.
name_to_eval_shared_model: Map of model name to associated EvalSharedModel
objed.
object.
"""
super().__init__()
model_names = []
inference_specs = []
for model_spec in model_specs:
eval_shared_model = model_util.get_eval_shared_model(
model_spec.name, name_to_eval_shared_model)
inference_spec_type = model_spec_pb2.InferenceSpecType()
inference_spec_type.saved_model_spec.model_path = eval_shared_model.model_path
inference_spec_type.saved_model_spec.tag[:] = eval_shared_model.model_loader.tags
inference_spec_type.saved_model_spec.signature_name[:] = [
model_spec.signature_name
]
inference_spec_type = model_spec_pb2.InferenceSpecType(
saved_model_spec=model_spec_pb2.SavedModelSpec(
model_path=eval_shared_model.model_path,
tag=eval_shared_model.model_loader.tags,
signature_name=[model_spec.signature_name],
),
batch_parameters=model_spec_pb2.BatchParameters(
min_batch_size=model_spec.inference_batch_size,
max_batch_size=model_spec.inference_batch_size,
),
)
model_names.append(model_spec.name)
inference_specs.append(inference_spec_type)
self._aligned_model_names = tuple(model_names)
@@ -103,8 +108,8 @@ def TfxBslPredictionsExtractor(
models (multi-model evaluation) or None (predictions obtained from
features).
output_batch_size: Sets a static output batch size for bulk inference. Note:
this is not implemented for Tfx-Bsl inference and only affects the
rebatched output batch size.
this only affects the rebatched output batch size. To set the inference
batch size, set ModelSpec.inference_batch_size.
Returns:
Extractor for extracting predictions.
@@ -132,8 +137,6 @@ def TfxBslPredictionsExtractor(
model_specs.append(model_spec)

tfx_bsl_inference_ptransform = inference_base.RunInference(
# TODO(b/260887130): TfxBsl Bulk Inference doesn't support batch_size yet,
# but will in the near future. See bug.
inference_ptransform=TfxBslInferenceWrapper(model_specs,
name_to_eval_shared_model),
output_batch_size=output_batch_size)
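For context, the sketch below (not part of this commit) shows how a caller might exercise the new plumbing: ModelSpec.inference_batch_size is forwarded by TfxBslInferenceWrapper into TFX-BSL BatchParameters, while output_batch_size still only controls the rebatched output. The import paths, the model path, the signature name, and the use of tfma.default_eval_shared_model are assumptions for illustration, not taken from this diff.

import tensorflow_model_analysis as tfma
from tensorflow_model_analysis.proto import config_pb2
from tensorflow_model_analysis.extractors import tfx_bsl_predictions_extractor

# Placeholder config; inference_batch_size feeds BatchParameters.min/max_batch_size.
eval_config = config_pb2.EvalConfig(
    model_specs=[
        config_pb2.ModelSpec(
            signature_name='serving_default',  # placeholder signature
            inference_batch_size=32,
        )
    ]
)
# Assumed public helper for building the EvalSharedModel; the path is a placeholder.
eval_shared_model = tfma.default_eval_shared_model(
    eval_saved_model_path='/path/to/saved_model',
    eval_config=eval_config,
)
prediction_extractor = tfx_bsl_predictions_extractor.TfxBslPredictionsExtractor(
    eval_config=eval_config,
    eval_shared_model=eval_shared_model,
    output_batch_size=32,  # only affects the rebatched output batch size
)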
@@ -149,6 +149,116 @@ def check_result(got):

util.assert_that(result, check_result)

def testInferenceBatchSize(self):
temp_export_dir = self._getExportDir()
# Running pyformat results in a lint error. Fixing the lint error breaks
# the pyformat presubmit. This style matches the other tests.
# pyformat: disable
export_dir, _ = (
fixed_prediction_estimator_extra_fields
.simple_fixed_prediction_estimator_extra_fields(temp_export_dir, None))
# pyformat: enable

examples = [
self._makeExample(
prediction=0.2,
label=1.0,
fixed_int=1,
fixed_float=1.0,
fixed_string='fixed_string1',
),
self._makeExample(
prediction=0.8,
label=0.0,
fixed_int=1,
fixed_float=1.0,
fixed_string='fixed_string2',
),
self._makeExample(
prediction=0.5,
label=0.0,
fixed_int=2,
fixed_float=1.0,
fixed_string='fixed_string3',
),
]
num_examples = len(examples)

eval_config = config_pb2.EvalConfig(
model_specs=[
config_pb2.ModelSpec(
inference_batch_size=num_examples,
)
]
)
eval_shared_model = self.createTestEvalSharedModel(
eval_saved_model_path=export_dir, tags=[tf.saved_model.SERVING]
)
tfx_io, feature_extractor = self._create_tfxio_and_feature_extractor(
eval_config,
text_format.Parse(
"""
feature {
name: "prediction"
type: FLOAT
}
feature {
name: "label"
type: FLOAT
}
feature {
name: "fixed_int"
type: INT
}
feature {
name: "fixed_float"
type: FLOAT
}
feature {
name: "fixed_string"
type: BYTES
}
""",
schema_pb2.Schema(),
),
)

prediction_extractor = (
tfx_bsl_predictions_extractor.TfxBslPredictionsExtractor(
eval_config=eval_config,
eval_shared_model=eval_shared_model,
output_batch_size=num_examples,
)
)

with beam.Pipeline() as pipeline:
# pylint: disable=no-value-for-parameter
result = (
pipeline
| 'Create'
>> beam.Create(
[e.SerializeToString() for e in examples], reshuffle=False
)
| 'BatchExamples' >> tfx_io.BeamSource(batch_size=num_examples)
| 'InputsToExtracts' >> model_eval_lib.BatchedInputsToExtracts()
| feature_extractor.stage_name >> feature_extractor.ptransform
| prediction_extractor.stage_name >> prediction_extractor.ptransform
)
# pylint: enable=no-value-for-parameter

def check_result(got):
try:
self.assertLen(got, 1)
self.assertIn(constants.PREDICTIONS_KEY, got[0])
self.assertAllClose(
np.array([[0.2], [0.8], [0.5]]), got[0][constants.PREDICTIONS_KEY]
)

except AssertionError as err:
raise util.BeamAssertException(err)

util.assert_that(result, check_result)

def testNoDefinedBatchSize(self):
"""Simple test to cover batch_size=None code path."""
temp_export_dir = self._getExportDir()
3 changes: 2 additions & 1 deletion tensorflow_model_analysis/proto/config.proto
@@ -140,7 +140,8 @@ message ModelSpec {
// Batch size used by the inference implementation. This batch size is only
// used for inference with this model. It does not affect the batch size of
// other models, and it does not affect the batch size used in the rest of the
// pipeline.
// pipeline. This is implemented for the ServoBeamPredictionsExtractor and
// TfxBslPredictionsExtractor.
int32 inference_batch_size = 15;

reserved 1, 4;
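As an illustration (not part of the diff), the field could also be set from a text-format EvalConfig, parsed the same way the tests build their schema protos; the signature name and batch size below are placeholders.

from google.protobuf import text_format
from tensorflow_model_analysis.proto import config_pb2

# Placeholder values; inference_batch_size only applies to this model's inference.
eval_config = text_format.Parse(
    """
    model_specs {
      signature_name: "serving_default"
      inference_batch_size: 64
    }
    """,
    config_pb2.EvalConfig(),
)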