forked from openvinotoolkit/openvino
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[Eye-9] Extend MO with Eye-9 op (openvinotoolkit#11555)
- Loading branch information
Showing
9 changed files
with
416 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,64 @@ | ||
# Copyright (C) 2018-2022 Intel Corporation | ||
# SPDX-License-Identifier: Apache-2.0 | ||
|
||
import numpy as np | ||
import pytest | ||
import tensorflow as tf | ||
from common.tf_layer_test_class import CommonTFLayerTest | ||
from common.utils.tf_utils import permute_nchw_to_nhwc | ||
|
||
|
||
class TestTFEye(CommonTFLayerTest): | ||
eye_output_type_param = np.float32 | ||
|
||
# Overload inputs generation to fill dummy Add input with 0 | ||
def _prepare_input(self, inputs_dict): | ||
for input in inputs_dict.keys(): | ||
inputs_dict[input] = np.zeros(inputs_dict[input]).astype(self.eye_output_type_param) | ||
return inputs_dict | ||
|
||
|
||
def create_tf_eye_net(self, num_rows, num_columns, batch_shape, output_type): | ||
tf.compat.v1.reset_default_graph() | ||
|
||
# Create the graph and model | ||
with tf.compat.v1.Session() as sess: | ||
tf.compat.v1.global_variables_initializer() | ||
# batch_shape_input = tf.constant(constant_value) | ||
if output_type is None: | ||
eye = tf.eye(num_rows=num_rows, num_columns=num_columns, batch_shape=batch_shape) | ||
else: | ||
self.eye_output_type_param = output_type | ||
eye = tf.eye(num_rows=num_rows, num_columns=num_columns, batch_shape=batch_shape, dtype=tf.as_dtype(output_type)) | ||
|
||
# Dummy Add layer to prevent fully const network | ||
input_zero = tf.compat.v1.placeholder(tf.as_dtype(self.eye_output_type_param), [1], 'Input') | ||
add = tf.add(eye, input_zero) | ||
|
||
tf_net = sess.graph_def | ||
|
||
ref_net = None | ||
return tf_net, ref_net | ||
|
||
test_data = [dict(num_rows=5, num_columns=None, batch_shape=None, output_type=None), | ||
dict(num_rows=5, num_columns=5, batch_shape=[2, 3], output_type=np.float32), | ||
dict(num_rows=5, num_columns=5, batch_shape=[2, 3], output_type=np.float32), | ||
dict(num_rows=5, num_columns=5, batch_shape=[2, 3], output_type=np.float16), | ||
dict(num_rows=5, num_columns=5, batch_shape=[2, 3], output_type=np.int32), | ||
dict(num_rows=5, num_columns=5, batch_shape=[2, 3], output_type=np.int8), | ||
dict(num_rows=8, num_columns=5, batch_shape=None, output_type=np.float32), | ||
dict(num_rows=5, num_columns=8, batch_shape=None, output_type=np.float32), | ||
dict(num_rows=2, num_columns=2, batch_shape=None, output_type=np.float32), | ||
dict(num_rows=6, num_columns=6, batch_shape=[2], output_type=np.float32)] | ||
|
||
@pytest.mark.parametrize("params", test_data) | ||
@pytest.mark.nightly | ||
@pytest.mark.precommit | ||
def test_tf_eye(self, params, ie_device, precision, ir_version, temp_dir, use_new_frontend, | ||
api_2=True): | ||
if ie_device == 'GPU': | ||
pytest.skip("Roll is not supported on GPU") | ||
self._test(*self.create_tf_eye_net(**params), ie_device, | ||
precision, | ||
temp_dir=temp_dir, ir_version=ir_version, use_new_frontend=use_new_frontend, | ||
api_2=api_2, **params) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,26 @@ | ||
# Copyright (C) 2018-2022 Intel Corporation | ||
# SPDX-License-Identifier: Apache-2.0 | ||
|
||
import numpy as np | ||
|
||
from openvino.tools.mo.ops.eye import MXEye | ||
from openvino.tools.mo.front.extractor import FrontExtractorOp | ||
from openvino.tools.mo.front.mxnet.extractors.utils import get_mxnet_layer_attrs | ||
|
||
|
||
class EyeExtractor(FrontExtractorOp): | ||
op = '_npi_eye' | ||
enabled = True | ||
|
||
@classmethod | ||
def extract(cls, node): | ||
attrs = get_mxnet_layer_attrs(node.symbol_dict) | ||
num_rows = attrs.int("N") | ||
num_columns = attrs.int("M", num_rows) | ||
if num_columns is None: | ||
num_columns = num_rows | ||
diagonal_index = attrs.int("k", 0) | ||
out_type = attrs.dtype("dtype", np.float32) | ||
new_attrs = {'num_rows': num_rows, 'num_columns': num_columns, 'diagonal_index': diagonal_index, 'output_type': out_type} | ||
MXEye.update_node_stat(node, new_attrs) | ||
return cls.enabled |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,47 @@ | ||
# Copyright (C) 2018-2022 Intel Corporation | ||
# SPDX-License-Identifier: Apache-2.0 | ||
|
||
from openvino.tools.mo.front.common.replacement import FrontReplacementPattern | ||
from openvino.tools.mo.front.common.partial_infer.utils import int64_array | ||
from openvino.tools.mo.front.tf.graph_utils import create_op_with_const_inputs | ||
from openvino.tools.mo.graph.graph import Graph | ||
from openvino.tools.mo.ops.eye import Eye | ||
from openvino.tools.mo.utils.error import Error | ||
|
||
|
||
class EyeMXToEye(FrontReplacementPattern): | ||
""" | ||
This transformation converts MXEye operation (MXNet semantic) to Eye operation (OpenVINO semantic). | ||
Refer to the Op implementation for the operations semantics description. | ||
""" | ||
enabled = True | ||
|
||
def find_and_replace_pattern(self, graph: Graph): | ||
for mxeye in graph.get_op_nodes(op='MXEye'): | ||
# save the original node name to use it in the new Eye op instance | ||
original_name = mxeye.soft_get('name', mxeye.id) | ||
mxeye['name'] = original_name + '/to_be_removed' | ||
|
||
if not mxeye.has_valid('num_rows'): | ||
raise Error("MXEye should have valid ''num_rows'' attribute.") | ||
num_rows = mxeye.soft_get('num_rows') | ||
|
||
if not mxeye.has_valid('num_columns'): | ||
raise Error("MXEye should have valid ''num_columns'' attribute.") | ||
num_columns = mxeye.soft_get('num_columns') | ||
|
||
if not mxeye.has_valid('diagonal_index'): | ||
raise Error("MXEye should have valid ''diagonal_index'' attribute.") | ||
diagonal_index = mxeye.soft_get('diagonal_index') | ||
|
||
if not mxeye.has_valid('output_type'): | ||
raise Error("MXEye should have valid ''output_type'' attribute.") | ||
output_type = mxeye.soft_get('output_type') | ||
|
||
new_eye = create_op_with_const_inputs(graph, Eye, {0: int64_array(num_rows), | ||
1: int64_array(num_columns), | ||
2: int64_array(diagonal_index)}, | ||
{'name': original_name + '/Gathered', | ||
'output_type': output_type}) | ||
mxeye.out_port(0).get_connection().set_source(new_eye.out_port(0)) | ||
graph.remove_node(mxeye.id) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,19 @@ | ||
# Copyright (C) 2018-2022 Intel Corporation | ||
# SPDX-License-Identifier: Apache-2.0 | ||
|
||
from openvino.tools.mo.ops.eye import TFEye | ||
from openvino.tools.mo.front.extractor import FrontExtractorOp | ||
from openvino.tools.mo.front.tf.extractors.utils import tf_dtype_extractor | ||
|
||
|
||
class EyeExtractor(FrontExtractorOp): | ||
op = 'Eye' | ||
enabled = True | ||
|
||
@classmethod | ||
def extract(cls, node): | ||
attrs = { | ||
'output_type': tf_dtype_extractor(node.pb.attr["dtype"].type), | ||
} | ||
TFEye.update_node_stat(node, attrs) | ||
return cls.enabled |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,45 @@ | ||
# Copyright (C) 2018-2022 Intel Corporation | ||
# SPDX-License-Identifier: Apache-2.0 | ||
|
||
from openvino.tools.mo.ops.const import Const | ||
from openvino.tools.mo.front.common.replacement import FrontReplacementPattern | ||
from openvino.tools.mo.graph.graph import Graph, rename_node | ||
from openvino.tools.mo.ops.eye import Eye | ||
from openvino.tools.mo.utils.error import Error | ||
|
||
|
||
class EyeTFToEye(FrontReplacementPattern): | ||
""" | ||
This transformation converts TFEye operation (TensorFlow semantic) to Eye operation (OpenVINO semantic). | ||
Refer to the Op implementation for the operations semantics description. | ||
""" | ||
enabled = True | ||
|
||
def find_and_replace_pattern(self, graph: Graph): | ||
for tfeye in graph.get_op_nodes(op='TFEye'): | ||
# save the original node name to use it in the new Eye op instance | ||
original_name = tfeye.soft_get('name', tfeye.id) | ||
tfeye['name'] = original_name + '/to_be_removed' | ||
|
||
if not tfeye.has_valid('output_type'): | ||
raise Error("TFEye should have valid ''output_type'' attribute.") | ||
output_type = tfeye.soft_get('output_type') | ||
|
||
new_eye = Eye(graph, {'output_type': output_type}).create_node() | ||
rename_node(new_eye, original_name) | ||
|
||
# num_rows | ||
tfeye.in_port(0).get_connection().set_destination(new_eye.in_port(0)) | ||
# num_columns | ||
if not tfeye.in_port(1).disconnected: | ||
tfeye.in_port(1).get_connection().set_destination(new_eye.in_port(1)) | ||
# batch_shape | ||
if not tfeye.in_port(2).disconnected: | ||
tfeye.in_port(2).get_connection().set_destination(new_eye.in_port(3)) | ||
|
||
diagonal_index = Const(graph, {'name': original_name + '/diagonal_index', | ||
'value': 0}).create_node() | ||
diagonal_index.out_port(0).connect(new_eye.in_port(2)) | ||
|
||
tfeye.out_port(0).get_connection().set_source(new_eye.out_port(0)) | ||
graph.remove_node(tfeye.id) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,133 @@ | ||
# Copyright (C) 2018-2022 Intel Corporation | ||
# SPDX-License-Identifier: Apache-2.0 | ||
|
||
import numpy as np | ||
|
||
from openvino.tools.mo.graph.graph import Graph, Node | ||
from openvino.tools.mo.middle.passes.convert_data_type import np_data_type_to_destination_type | ||
from openvino.tools.mo.ops.op import Op | ||
from openvino.tools.mo.front.common.partial_infer.utils import dynamic_dimension, shape_array | ||
from openvino.tools.mo.front.common.partial_infer.utils import is_fully_defined | ||
|
||
|
||
class Eye(Op): | ||
""" | ||
Eye operation that generates shift matrix or a batch of matrices. | ||
""" | ||
op = 'Eye' | ||
enabled = False | ||
in_ports_count = 4 | ||
|
||
def __init__(self, graph: Graph, attrs: dict): | ||
super().__init__(graph, { | ||
'type': self.op, | ||
'op': self.op, | ||
'version': 'opset9', | ||
'infer': self.infer, | ||
'in_ports_count': 4, | ||
'out_ports_count': 1, | ||
'type_infer': self.type_infer, | ||
'output_type': np.float32, | ||
}, attrs) | ||
|
||
def backend_attrs(self): | ||
return [('output_type', lambda node: np_data_type_to_destination_type(node.output_type))] | ||
|
||
@staticmethod | ||
def type_infer(node: Node): | ||
node.out_port(0).set_data_type(node['output_type']) | ||
|
||
@staticmethod | ||
def infer(node: Node): | ||
assert node.has_valid('output_type') | ||
|
||
connected_in_ports = {idx: port for idx, port in node.in_ports().items() if not port.disconnected()} | ||
assert len(connected_in_ports) >= 3 and all(p in connected_in_ports for p in [0, 1, 2]), \ | ||
"Eye should have at least 3 connected input port." \ | ||
"Got ports: `{}`.".format(connected_in_ports) | ||
|
||
num_rows_port = 0 | ||
num_columns_port = 1 | ||
diagonal_index_port = 2 | ||
batch_shape_port = 3 | ||
|
||
num_rows_shape = node.in_port(num_rows_port).data.get_shape() | ||
assert len(num_rows_shape) <= 1, \ | ||
'"num_rows" should be 1D tensor or scalar. Got: '.format(len(num_rows_shape)) | ||
num_rows = node.in_port(num_rows_port).data.get_value() | ||
if num_rows is None: | ||
num_rows = dynamic_dimension | ||
else: | ||
num_rows = np.array(num_rows).item() | ||
|
||
num_columns_shape = node.in_port(num_columns_port).data.get_shape() | ||
assert len(num_columns_shape) <= 1, \ | ||
'"num_columns" should be 1D tensor or scalar. Got: '.format(len(num_columns_shape)) | ||
num_columns = node.in_port(num_columns_port).data.get_value() | ||
if num_columns is None: | ||
num_columns = dynamic_dimension | ||
else: | ||
num_columns = np.array(num_columns).item() | ||
|
||
diagonal_index_shape = node.in_port(diagonal_index_port).data.get_shape() | ||
assert len(diagonal_index_shape) <= 1, \ | ||
'"diagonal_index" should be 1D tensor or scalar. Got: '.format(len(diagonal_index_shape)) | ||
diagonal_index = node.in_port(diagonal_index_port).data.get_value() | ||
|
||
if batch_shape_port in connected_in_ports: | ||
batch_shape_shape = node.in_port(batch_shape_port).data.get_shape() | ||
assert len(batch_shape_shape) == 1, \ | ||
'"batch_shape" should be 1D tensor. Got: '.format(len(batch_shape_shape)) | ||
batch_shape = node.in_port(batch_shape_port).data.get_value() | ||
if batch_shape is None: | ||
batch_shape = [dynamic_dimension] * batch_shape_shape[0] | ||
else: | ||
batch_shape = [] | ||
|
||
output_shape = [*batch_shape, num_rows, num_columns] | ||
node.out_port(0).data.set_shape(output_shape) | ||
|
||
if is_fully_defined(output_shape) and diagonal_index is not None: | ||
tile_shape = [*batch_shape, 1, 1] | ||
one_matrix = np.eye(num_rows, M=num_columns, k=diagonal_index, dtype=node.output_type) | ||
output_value = np.tile(one_matrix, tile_shape) | ||
node.out_port(0).data.set_value(shape_array(output_value)) | ||
|
||
|
||
class TFEye(Op): | ||
""" Eye operation that that generates shift matrix or a batch of matrices. | ||
Eye operation from TensorFlow has three inputs: row number, column number and batch shape | ||
""" | ||
op = 'TFEye' | ||
enabled = False | ||
|
||
def __init__(self, graph: Graph, attrs: dict): | ||
super().__init__(graph, { | ||
'type': None, | ||
'op': self.op, | ||
'infer': None, | ||
'in_ports_count': 3, | ||
'out_ports_count': 1, | ||
'output_type': np.float32, | ||
}, attrs) | ||
|
||
|
||
class MXEye(Op): | ||
""" Eye operation that that generates shift matrix or a batch of matrices. | ||
Eye operation from MXNet doesn't have inputs. Only attributes: row number, column number and diagonal index | ||
""" | ||
op = 'MXEye' | ||
enabled = False | ||
|
||
def __init__(self, graph: Graph, attrs: dict): | ||
super().__init__(graph, { | ||
'type': None, | ||
'op': self.op, | ||
'infer': None, | ||
'in_ports_count': 0, | ||
'out_ports_count': 1, | ||
'num_rows': 1, | ||
'num_columns': 1, | ||
'diagonal_index': 0, | ||
'output_type': np.float32, | ||
}, attrs) |
Oops, something went wrong.