[TF Hub API][TF FE] Support TF Keras Model OOB without example_input (openvinotoolkit#19892)

* [TF Hub] Cover TF Hub use cases with adoption to OpenVINO

This is necessary to demonstrate support of models programmed with the TF Hub API
through OV notebooks.
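
A minimal sketch of the notebook-style flow this change enables (the model mirrors
the test added in this commit; ov.convert_model is assumed to come from the
matching openvino package):

import tensorflow as tf
import tensorflow_hub as hub
import openvino as ov

# Keras model around a TF Hub feature extractor, with named dict inputs/outputs.
image = tf.keras.layers.Input(shape=(224, 224, 3), dtype=tf.float32, name="image")
feature_vector = hub.KerasLayer(
    "https://tfhub.dev/google/imagenet/mobilenet_v2_100_224/feature_vector/5",
    trainable=False)(image)
softmax = tf.keras.layers.Dense(20, activation="softmax")(feature_vector)
model = tf.keras.Model(inputs={"image": image}, outputs={"softmax": softmax})

# No example_input is needed: the frontend derives the signature from the
# Keras inputs and preserves the original tensor names.
ov_model = ov.convert_model(model)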

Signed-off-by: Kazantsev, Roman <[email protected]>

* Preserve original keras input and output tensor names

* Add tests with TF Hub API models

* No KerasLayer handling

* Handle specific signature

---------

Signed-off-by: Kazantsev, Roman <[email protected]>
rkazants authored Sep 18, 2023
1 parent a4cbac3 commit df19699
Showing 5 changed files with 206 additions and 21 deletions.
105 changes: 100 additions & 5 deletions src/bindings/python/src/openvino/frontend/tensorflow/utils.py
@@ -63,7 +63,7 @@ def get_imported_module_version(imported_module):
for attr in version_attrs:
installed_version = getattr(imported_module, attr, None)
if isinstance(installed_version, str):
return installed_version
else:
installed_version = None

@@ -98,7 +98,8 @@ def get_environment_setup(framework):

def trace_tf_model_if_needed(input_model, placeholder_shapes, placeholder_data_types, example_input):
import tensorflow as tf
if not isinstance(input_model, (tf.keras.layers.Layer, tf.Module, tf.keras.Model, tf.types.experimental.GenericFunction)):
if not isinstance(input_model,
(tf.keras.layers.Layer, tf.Module, tf.keras.Model, tf.types.experimental.GenericFunction)):
return input_model
return trace_tf_model(input_model, placeholder_shapes, placeholder_data_types, example_input)

@@ -175,6 +176,65 @@ def get_concrete_func(tf_function, example_input, input_needs_packing, error_mes
return concrete_func


def create_generic_function_from_keras_model(keras_model):
import tensorflow as tf
assert isinstance(keras_model, tf.keras.Model), \
"[TensorFlow Frontend] internal error: the input model must be of Keras model type"
if not hasattr(keras_model, 'input') or getattr(keras_model, 'input') is None:
return None
keras_input_signature = getattr(keras_model, 'input')
tf_input_signature = None
wrapper_function = None
if isinstance(keras_input_signature, dict):
tf_input_signature = []
for tensor_name, tensor_spec in keras_input_signature.items():
tf_input_signature.append(tf.TensorSpec(shape=tensor_spec.shape,
dtype=tensor_spec.dtype,
name=tensor_name))
elif isinstance(keras_input_signature, list):
tf_input_signature = []
for tensor_spec in keras_input_signature:
tf_input_signature.append(tf.TensorSpec(shape=tensor_spec.shape,
dtype=tensor_spec.dtype,
name=tensor_spec.name))
else:
try:
# single KerasTensor case
tf_input_signature = []
tf_input_signature.append(tf.TensorSpec(shape=keras_input_signature.shape,
dtype=keras_input_signature.dtype,
name=keras_input_signature.name))
except:
tf_input_signature = None
if tf_input_signature is not None:
@tf.function(input_signature=tf_input_signature)
def wrapper_function_dict(*args):
input_dict = {}
for ind, tensor_spec in enumerate(tf_input_signature):
input_dict[tensor_spec.name] = args[ind]
outputs = keras_model(input_dict)
# need to wrap the output into a dictionary;
# it helps to preserve the original Keras tensor names
post_outputs = {}
if isinstance(outputs, dict):
for output_name, output_value in outputs.items():
post_outputs[output_name] = output_value
else:
try:
if isinstance(outputs, list) and isinstance(keras_model.outputs, list) and \
len(outputs) == len(keras_model.outputs):
for output_value, output_tensor in zip(outputs, keras_model.outputs):
post_outputs[output_tensor.name] = output_value
else:
post_outputs[keras_model.output.name] = outputs
except:
post_outputs = outputs
return post_outputs

wrapper_function = wrapper_function_dict
return wrapper_function


def trace_tf_model(model, input_shapes, input_types, example_input):
import tensorflow as tf
if isinstance(model.__call__, tf.types.experimental.GenericFunction):
@@ -183,12 +243,25 @@ def trace_tf_model(model, input_shapes, input_types, example_input):
elif isinstance(model, tf.types.experimental.GenericFunction):
tf_function = model
input_needs_packing = False
elif isinstance(model, tf.keras.Model):
tf_function = create_generic_function_from_keras_model(model)
if tf_function is not None:
input_needs_packing = False
else:
# Wrap the model into tf.function.
# In this case we lose input/output tensor names.
@tf.function
def tf_function(args):
return model(*args)

input_needs_packing = True
else:
# Wrap the model into tf.function.
# In this case we lose input/output tensor names.
@tf.function
def tf_function(args):
return model(*args)

input_needs_packing = True

if example_input is not None:
@@ -216,7 +289,8 @@ def tf_function(args):
def type_supported_by_tf_fe(input_model):
import tensorflow as tf
# Types that require tracing
if isinstance(input_model, (tf.keras.layers.Layer, tf.Module, tf.keras.Model, tf.types.experimental.GenericFunction)):
if isinstance(input_model,
(tf.keras.layers.Layer, tf.Module, tf.keras.Model, tf.types.experimental.GenericFunction)):
return True
# Types that do not require tracing
if isinstance(input_model, (tf.Graph, tf.types.experimental.ConcreteFunction)):
@@ -246,7 +320,15 @@ def create_tf_graph_iterator(input_model, placeholder_shapes, placeholder_data_t
if func_input.dtype == tf.resource:
continue
internal_tensor_names.append(func_input.name)
if len(input_model.structured_input_signature) > 1 and \
if len(input_model.structured_input_signature) > 0 and \
len(internal_tensor_names) == len(input_model.structured_input_signature[0]):
for internal_name, tensor_spec in zip(internal_tensor_names, input_model.structured_input_signature[0]):
input_names_map = input_names_map or {}
if not isinstance(tensor_spec, tf.TensorSpec):
input_names_map = None
break
input_names_map[internal_name] = tensor_spec.name
elif len(input_model.structured_input_signature) > 1 and \
len(internal_tensor_names) == len(input_model.structured_input_signature[1]):
external_tensor_names = sorted(input_model.structured_input_signature[1].keys())
for internal_name, external_name in zip(internal_tensor_names, external_tensor_names):
@@ -262,6 +344,19 @@ def create_tf_graph_iterator(input_model, placeholder_shapes, placeholder_data_t
for external_name, internal_name in zip(external_names, internal_names):
output_names_map = output_names_map or {}
output_names_map[internal_name] = external_name
else:
for external_name, internal_tensor in input_model.structured_outputs.items():
internal_tf_tensor = None
if isinstance(internal_tensor, tf.Tensor):
internal_tf_tensor = internal_tensor
if isinstance(internal_tensor, list) and len(internal_tensor) > 0 and \
isinstance(internal_tensor[0], tf.Tensor):
internal_tf_tensor = internal_tensor[0]
if internal_tf_tensor is None:
output_names_map = None
break
output_names_map = output_names_map or {}
output_names_map[internal_tf_tensor.name] = external_name
return GraphIteratorTFGraph(input_model.graph, share_weights, False, input_names_map, output_names_map)
raise Exception("Could not wrap model of type {} to GraphIteratorTFGraph.".format(type(input_model)))
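
For reference on the mapping above, a small sketch with a hypothetical traced
function: structured_input_signature is a (positional, keyword) pair, and the new
branch reads user-visible input names from positional TensorSpec entries:

import tensorflow as tf

@tf.function(input_signature=[tf.TensorSpec([1, 3], tf.float32, name="image")])
def f(image):
    return {"softmax": tf.nn.softmax(image)}

concrete = f.get_concrete_function()
args, kwargs = concrete.structured_input_signature
print(args)    # (TensorSpec(shape=(1, 3), dtype=tf.float32, name='image'),)
print(kwargs)  # {}
# The new branch maps internal graph names (e.g. 'image:0') back to 'image'.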

@@ -271,7 +366,7 @@ def extract_model_graph(argv):
import tensorflow as tf
trackable_is_imported = False
try:
from tensorflow.python.training.tracking.base import Trackable  # pylint: disable=no-name-in-module,import-error
trackable_is_imported = True
except:
log.warning("Could not import tensorflow.python.training.tracking.base.Trackable type.")
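A rough usage sketch for the new wrapper (assuming it is importable from the
module path in this diff, and a TF2 version where KerasTensor exposes
shape/dtype/name):

import tensorflow as tf
from openvino.frontend.tensorflow.utils import create_generic_function_from_keras_model

inp = tf.keras.layers.Input(shape=(4,), dtype=tf.float32, name="features")
out = tf.keras.layers.Dense(2, name="logits")(inp)
model = tf.keras.Model(inputs=[inp], outputs=[out])

wrapper = create_generic_function_from_keras_model(model)
concrete = wrapper.get_concrete_function()
# Inputs and outputs keep the original Keras tensor names, which is what
# lets conversion work without example_input.
print(concrete.structured_input_signature)
print(concrete.structured_outputs)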
12 changes: 9 additions & 3 deletions tests/model_hub_tests/models_hub_common/test_convert_model.py
@@ -35,9 +35,15 @@ def prepare_input(self, input_shape, input_type):
assert False, "Unsupported type {}".format(input_type)

def prepare_inputs(self, inputs_info):
inputs = {}
for input_name, input_shape, input_type in inputs_info:
inputs[input_name] = self.prepare_input(input_shape, input_type)
if len(inputs_info) > 0 and inputs_info[0] == 'list':
inputs = []
inputs_info = inputs_info[1:]
for input_name, input_shape, input_type in inputs_info:
inputs.append(self.prepare_input(input_shape, input_type))
else:
inputs = {}
for input_name, input_shape, input_type in inputs_info:
inputs[input_name] = self.prepare_input(input_shape, input_type)
return inputs

def convert_model(self, model_obj):
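For clarity, the two inputs_info layouts prepare_inputs now accepts (names and
shapes are illustrative): a plain list of (name, shape, type) tuples yields a dict
of arrays, while a leading 'list' sentinel yields a positional list for models
that take their inputs as a list:

import numpy as np

inputs_info_dict = [("image", [1, 224, 224, 3], np.float32)]          # -> {"image": array}
inputs_info_list = ["list", ("image", [1, 224, 224, 3], np.float32)]  # -> [array]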
63 changes: 63 additions & 0 deletions tests/model_hub_tests/tf_hub_tests/test_tf_hub_api_notebooks.py
@@ -0,0 +1,63 @@
# Copyright (C) 2018-2023 Intel Corporation
# SPDX-License-Identifier: Apache-2.0

import pytest
import tensorflow as tf
import tensorflow_hub as hub
from models_hub_common.test_convert_model import TestConvertModel
from tf_hub_tests.utils import get_input_info


class TestTFHubApiNotebooks(TestConvertModel):
def load_model(self, model_name, model_link):
if model_name == 'mobilenet_v2_100_224_dict':
image = tf.keras.layers.Input(shape=(224, 224, 3), dtype=tf.float32, name="image")
feature_vector = hub.KerasLayer("https://tfhub.dev/google/imagenet/mobilenet_v2_100_224/feature_vector/5",
trainable=False)(image)
softmax = tf.keras.layers.Dense(20, activation='softmax')(feature_vector)
classification_model = tf.keras.Model(inputs={'image': image}, outputs={'softmax': softmax})
return classification_model
elif model_name == 'mobilenet_v2_100_224_list':
image = tf.keras.layers.Input(shape=(224, 224, 3), dtype=tf.float32, name="image")
feature_vector = hub.KerasLayer("https://tfhub.dev/google/imagenet/mobilenet_v2_100_224/feature_vector/5",
trainable=False)(image)
softmax = tf.keras.layers.Dense(20, activation='softmax')(feature_vector)
classification_model = tf.keras.Model(inputs=[image], outputs=[softmax])
return classification_model
else:
raise "Unknown input model: {}".format(model_name)

def get_inputs_info(self, keras_model):
inputs_info = []
if isinstance(keras_model.input, dict):
for input_name, input_tensor in keras_model.input.items():
inputs_info.append(get_input_info(input_tensor, input_name))
elif isinstance(keras_model.input, list):
inputs_info.append('list')
for input_tensor in keras_model.input:
inputs_info.append(get_input_info(input_tensor, input_tensor.name))
else:
inputs_info.append('list')
input_tensor = keras_model.input
inputs_info.append(get_input_info(input_tensor, input_tensor.name))
return inputs_info

def infer_fw_model(self, model_obj, inputs):
outputs = model_obj(inputs)
if isinstance(outputs, dict):
post_outputs = {}
for out_name, out_value in outputs.items():
post_outputs[out_name] = out_value.numpy()
elif isinstance(outputs, list):
post_outputs = []
for out_value in outputs:
post_outputs.append(out_value.numpy())
else:
post_outputs = [outputs.numpy()]

return post_outputs

@pytest.mark.precommit
@pytest.mark.parametrize("model_name", ['mobilenet_v2_100_224_dict', 'mobilenet_v2_100_224_list'])
def test_tf_hub_api_notebook1(self, model_name, ie_device):
self.run(model_name, '', ie_device)
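
A self-contained stand-in for the dict-in/dict-out pattern these tests exercise
(a small local model instead of the TF Hub download):

import numpy as np
import tensorflow as tf

image = tf.keras.layers.Input(shape=(224, 224, 3), dtype=tf.float32, name="image")
pooled = tf.keras.layers.GlobalAveragePooling2D()(image)
softmax = tf.keras.layers.Dense(20, activation="softmax")(pooled)
model = tf.keras.Model(inputs={"image": image}, outputs={"softmax": softmax})

outputs = model({"image": np.zeros((1, 224, 224, 3), dtype=np.float32)})
print({name: value.numpy().shape for name, value in outputs.items()})  # {'softmax': (1, 20)}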
14 changes: 1 addition & 13 deletions tests/model_hub_tests/tf_hub_tests/test_tf_hub_convert_model.py
@@ -5,7 +5,6 @@
import os
import shutil

import numpy as np
import pytest
import tensorflow as tf
import tensorflow_hub as hub
@@ -14,6 +13,7 @@
from models_hub_common.constants import tf_hub_cache_dir
from models_hub_common.test_convert_model import TestConvertModel
from models_hub_common.utils import get_models_list
from tf_hub_tests.utils import type_map


class TestTFHubConvertModel(TestConvertModel):
@@ -49,18 +49,6 @@ def get_inputs_info(self, model_obj):
except ValueError:
# unknown rank case
pass
type_map = {
tf.float64: np.float64,
tf.float32: np.float32,
tf.int8: np.int8,
tf.int16: np.int16,
tf.int32: np.int32,
tf.int64: np.int64,
tf.uint8: np.uint8,
tf.uint16: np.uint16,
tf.string: str,
tf.bool: bool,
}
if input_info.dtype == tf.resource:
# skip inputs corresponding to variables
continue
33 changes: 33 additions & 0 deletions tests/model_hub_tests/tf_hub_tests/utils.py
@@ -0,0 +1,33 @@
# Copyright (C) 2018-2023 Intel Corporation
# SPDX-License-Identifier: Apache-2.0

import numpy as np
import tensorflow as tf

type_map = {
tf.float64: np.float64,
tf.float32: np.float32,
tf.int8: np.int8,
tf.int16: np.int16,
tf.int32: np.int32,
tf.int64: np.int64,
tf.uint8: np.uint8,
tf.uint16: np.uint16,
tf.string: str,
tf.bool: bool,
}


def get_input_info(input_tensor, input_name):
input_shape = []
try:
for dim in input_tensor.shape.as_list():
if dim is None:
input_shape.append(1)
else:
input_shape.append(dim)
except ValueError:
# unknown rank case
pass
assert input_tensor.dtype in type_map, "Unsupported input type: {}".format(input_tensor.dtype)
return input_name, input_shape, type_map[input_tensor.dtype]
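
Example use of the helper above (assuming tf_hub_tests.utils is importable, as in
the tests): unknown dimensions are replaced with 1 and the TF dtype is mapped
through type_map:

import tensorflow as tf
from tf_hub_tests.utils import get_input_info

inp = tf.keras.layers.Input(shape=(None, 3), dtype=tf.float32, name="points")
print(get_input_info(inp, "points"))  # ('points', [1, 1, 3], <class 'numpy.float32'>)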
