TFJS predict extractor.
PiperOrigin-RevId: 327457230
tf-model-analysis-team committed Aug 19, 2020
1 parent 9d68707 commit 3506655
Showing 11 changed files with 594 additions and 16 deletions.
1 change: 1 addition & 0 deletions RELEASE.md
@@ -28,6 +28,7 @@
* Added `tfma.metrics.MultiClassConfusionMatrixAtThresholds`.
* Refactoring code to compute `tfma.metrics.MultiClassConfusionMatrixPlot`
using derived computations.
* Provide support for evaluating TFJS models.

## Bug fixes and other changes

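For context, a minimal usage sketch (not part of this diff) of evaluating a TFJS-converted model with the new support; the model and data paths are hypothetical, and the metrics are left to the defaults:

import tensorflow_model_analysis as tfma
from tensorflow_model_analysis import config

# Mark the model spec as 'tf_js' so TFMA routes prediction through the new
# TFJS predict extractor.
eval_config = config.EvalConfig(
    model_specs=[config.ModelSpec(model_type=tfma.TF_JS, label_key='label')],
    slicing_specs=[config.SlicingSpec()])

eval_shared_model = tfma.default_eval_shared_model(
    eval_saved_model_path='/path/to/tfjs_model_dir',  # hypothetical path
    eval_config=eval_config)

eval_result = tfma.run_model_analysis(
    eval_config=eval_config,
    eval_shared_model=eval_shared_model,
    data_location='/path/to/examples.tfrecord',  # hypothetical path
    output_path='/tmp/tfma_output')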
6 changes: 5 additions & 1 deletion setup.py
@@ -279,7 +279,11 @@ def run(self):
'six>=1.12,<2',
'tensorflow>=1.15.2,!=2.0.*,!=2.1.*,!=2.2.*,<3',
'tensorflow-metadata>=0.23,<0.24',
'tfx-bsl>=0.23,<0.24'
'tfx-bsl>=0.23,<0.24',
'tensorflowjs>=2.0.1.post1,<3',
# TODO(b/158034704): Remove the prompt-toolkit pin resulting from the
# tfjs -> PyInquirer dependency chain.
'prompt-toolkit>=2.0.10,<3',
],
'python_requires': '>=3.5,<4',
'packages': find_packages(),
1 change: 1 addition & 0 deletions tensorflow_model_analysis/__init__.py
@@ -87,6 +87,7 @@
from tensorflow_model_analysis.constants import SLICE_KEY_TYPES_KEY
from tensorflow_model_analysis.constants import TF_GENERIC
from tensorflow_model_analysis.constants import TF_ESTIMATOR
from tensorflow_model_analysis.constants import TF_JS
from tensorflow_model_analysis.constants import TF_LITE
from tensorflow_model_analysis.constants import TF_KERAS
from tensorflow_model_analysis.constants import VALIDATIONS_KEY
23 changes: 20 additions & 3 deletions tensorflow_model_analysis/api/model_eval_lib.py
@@ -46,6 +46,7 @@
from tensorflow_model_analysis.extractors import predict_extractor
from tensorflow_model_analysis.extractors import predict_extractor_v2
from tensorflow_model_analysis.extractors import slice_key_extractor
from tensorflow_model_analysis.extractors import tfjs_predict_extractor
from tensorflow_model_analysis.extractors import tflite_predict_extractor
from tensorflow_model_analysis.extractors import unbatch_extractor
from tensorflow_model_analysis.post_export_metrics import post_export_metrics
@@ -465,6 +466,20 @@ def default_extractors( # pylint: disable=invalid-name
'support for mixing tf_lite and non-tf_lite models is not '
'implemented: eval_config={}'.format(eval_config))

if model_types == set([constants.TF_JS]):
return [
input_extractor.InputExtractor(eval_config=eval_config),
(custom_predict_extractor or
tfjs_predict_extractor.TFJSPredictExtractor(
eval_config=eval_config, eval_shared_model=eval_shared_model)),
slice_key_extractor.SliceKeyExtractor(
eval_config=eval_config, materialize=materialize)
]
elif constants.TF_JS in model_types:
raise NotImplementedError(
'support for mixing tf_js and non-tf_js models is not '
'implemented: eval_config={}'.format(eval_config))

elif (eval_config and model_types == set([constants.TF_ESTIMATOR]) and
all(eval_constants.EVAL_TAG in m.model_loader.tags
for m in eval_shared_models)):
@@ -547,8 +562,9 @@ def default_evaluators( # pylint: disable=invalid-name
eval_config = _update_eval_config_with_defaults(eval_config,
eval_shared_model)
disabled_outputs = eval_config.options.disabled_outputs.values
if _model_types(eval_shared_model) == set([constants.TF_LITE]):
# no in-graph metrics present when tflite is used.
if (_model_types(eval_shared_model) == set([constants.TF_LITE]) or
_model_types(eval_shared_model) == set([constants.TF_JS])):
# no in-graph metrics present when tflite or tfjs is used.
if eval_shared_model:
if isinstance(eval_shared_model, dict):
eval_shared_model = {
@@ -848,7 +864,8 @@ def is_batched_input(eval_shared_model: Optional[
model_types = _model_types(eval_shared_model)
eval_shared_models = model_util.verify_and_update_eval_shared_models(
eval_shared_model)
if model_types == set([constants.TF_LITE]):
if (model_types == set([constants.TF_LITE]) or
model_types == set([constants.TF_JS])):
return False
elif (eval_config and model_types == set([constants.TF_ESTIMATOR]) and
all(eval_constants.EVAL_TAG in m.model_loader.tags
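A short sketch of what the branch above produces: when every model spec in the EvalConfig uses the tf_js model type, default_extractors returns the three-extractor pipeline (InputExtractor, TFJSPredictExtractor, SliceKeyExtractor), while mixing tf_js and non-tf_js specs raises NotImplementedError. Paths are hypothetical:

import tensorflow_model_analysis as tfma
from tensorflow_model_analysis import config
from tensorflow_model_analysis.api import model_eval_lib

eval_config = config.EvalConfig(
    model_specs=[config.ModelSpec(model_type=tfma.TF_JS)])
eval_shared_model = tfma.default_eval_shared_model(
    eval_saved_model_path='/path/to/tfjs_model_dir',  # hypothetical path
    eval_config=eval_config)

# For an all-TF_JS config this is InputExtractor -> TFJSPredictExtractor ->
# SliceKeyExtractor; a config mixing tf_js with other model types would
# raise NotImplementedError instead.
extractors = model_eval_lib.default_extractors(
    eval_config=eval_config, eval_shared_model=eval_shared_model)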
34 changes: 24 additions & 10 deletions tensorflow_model_analysis/api/model_eval_lib_test.py
@@ -49,6 +49,7 @@
from tensorflow_model_analysis.post_export_metrics import metric_keys
from tensorflow_model_analysis.post_export_metrics import post_export_metrics
from tensorflow_model_analysis.proto import validation_result_pb2
from tensorflowjs.converters import converter as tfjs_converter

from google.protobuf import text_format
from tensorflow_metadata.proto.v0 import schema_pb2
@@ -186,12 +187,14 @@ def testMixedEvalAndNonEvalSignatures(self):
data_location=data_location,
output_path=self._getTempDir())

def testMixedTFLiteAndNotTFLiteFormats(self):
@parameterized.named_parameters(('tflite', constants.TF_LITE),
('tfjs', constants.TF_JS))
def testMixedModelTypes(self, model_type):
examples = [self._makeExample(age=3.0, language='english', label=1.0)]
data_location = self._writeTFExamplesToTFRecords(examples)
eval_config = config.EvalConfig(model_specs=[
config.ModelSpec(name='model1'),
config.ModelSpec(name='model2', model_type=constants.TF_LITE)
config.ModelSpec(name='model2', model_type=model_type)
])
eval_shared_models = [
model_eval_lib.default_eval_shared_model(
@@ -204,8 +207,7 @@ def testMixedTFLiteAndNotTFLiteFormats(self):
eval_config=eval_config)
]
with self.assertRaisesRegex(
NotImplementedError,
'support for mixing tf_lite and non-tf_lite models is not implemented'):
NotImplementedError, 'support for mixing .* models is not implemented'):
model_eval_lib.run_model_analysis(
eval_config=eval_config,
eval_shared_model=eval_shared_models,
@@ -606,9 +608,10 @@ def testRunModelAnalysisWithModelAgnosticPredictions(self):
config.SlicingSpec(feature_keys=['language']))
self.assertMetricsAlmostEqual(eval_result.slicing_metrics, expected)

@parameterized.named_parameters(('without_tflite_conversion', False),
('with_tflite_conversion', True))
def testRunModelAnalysisWithKerasModel(self, convert_to_tflite):
@parameterized.named_parameters(('tf_keras', constants.TF_KERAS),
('tf_lite', constants.TF_LITE),
('tf_js', constants.TF_JS))
def testRunModelAnalysisWithKerasModel(self, model_type):

def _build_keras_model(eval_config, name='export_dir'):
input_layer = tf.keras.layers.Input(shape=(28 * 28,), name='data')
@@ -627,13 +630,24 @@ def _build_keras_model(eval_config, name='export_dir'):
dataset = dataset.shuffle(buffer_size=1).repeat().batch(1)
model.fit(dataset, steps_per_epoch=1)
model_location = os.path.join(self._getTempDir(), name)
if convert_to_tflite:
if model_type == constants.TF_LITE:
converter = tf.compat.v2.lite.TFLiteConverter.from_keras_model(model)
tflite_model = converter.convert()
tf.io.gfile.makedirs(model_location)
with tf.io.gfile.GFile(os.path.join(model_location, 'tflite'),
'wb') as f:
f.write(tflite_model)
elif model_type == constants.TF_JS:
src_model_path = tempfile.mkdtemp()
model.save(src_model_path, save_format='tf')

tfjs_converter.convert([
'--input_format=tf_saved_model',
'--saved_model_tags=serve',
'--signature_name=serving_default',
src_model_path,
model_location,
])
else:
model.save(model_location, save_format='tf')
return model_eval_lib.default_eval_shared_model(
@@ -699,9 +713,9 @@ def _build_keras_model(eval_config, name='export_dir'):
eval_config = config.EvalConfig(
model_specs=[config.ModelSpec(label_key='label')],
metrics_specs=[metrics_spec])
if convert_to_tflite:
if model_type != constants.TF_KERAS:
for s in eval_config.model_specs:
s.model_type = constants.TF_LITE
s.model_type = model_type

model, model_location = _build_keras_model(eval_config)
baseline, baseline_model_location = _build_keras_model(
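Outside of the test harness, the conversion step above amounts to saving the Keras model as a TF SavedModel and invoking the tensorflowjs converter with the same flags; a standalone sketch mirroring the test (the model architecture and directories are illustrative):

import tempfile

import tensorflow as tf
from tensorflowjs.converters import converter as tfjs_converter

# Any Keras model saved in the TF SavedModel format can be converted.
model = tf.keras.Sequential(
    [tf.keras.layers.Dense(1, input_shape=(3,), name='output')])
model.compile(optimizer='adam', loss='mse')

src_model_path = tempfile.mkdtemp()   # SavedModel location
dst_model_path = tempfile.mkdtemp()   # TFJS model location passed to TFMA
model.save(src_model_path, save_format='tf')

tfjs_converter.convert([
    '--input_format=tf_saved_model',
    '--saved_model_tags=serve',
    '--signature_name=serving_default',
    src_model_path,
    dst_model_path,
])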
3 changes: 2 additions & 1 deletion tensorflow_model_analysis/constants.py
@@ -32,7 +32,8 @@
TF_KERAS = 'tf_keras'
TF_GENERIC = 'tf_generic'
TF_LITE = 'tf_lite'
VALID_TF_MODEL_TYPES = (TF_GENERIC, TF_ESTIMATOR, TF_KERAS, TF_LITE)
TF_JS = 'tf_js'
VALID_TF_MODEL_TYPES = (TF_GENERIC, TF_ESTIMATOR, TF_KERAS, TF_LITE, TF_JS)

# LINT.IfChange
METRICS_NAMESPACE = 'tfx.ModelAnalysis'
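A small illustrative check (not part of this commit) showing that the new constant participates in the existing model-type list:

from tensorflow_model_analysis import constants

def check_model_type(model_type):
  # Illustrative helper only: 'tf_js' is now accepted alongside the
  # other model types.
  if model_type not in constants.VALID_TF_MODEL_TYPES:
    raise ValueError('unsupported model_type: {}'.format(model_type))

check_model_type(constants.TF_JS)  # passes after this change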