-
Notifications
You must be signed in to change notification settings - Fork 246
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Introduce ONNX GraphConverter to get NNCFGraph from ONNX (#1070)
### Changes As stated in the title ### Reason for changes Implement ONNX Post-Training Quantization ### Related tickets 75422 ### Tests Add test comparing NNCFGraph built for two ONNX synthetic models.
- Loading branch information
Showing
24 changed files
with
1,134 additions
and
321 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,12 @@ | ||
""" | ||
Copyright (c) 2022 Intel Corporation | ||
Licensed under the Apache License, Version 2.0 (the "License"); | ||
you may not use this file except in compliance with the License. | ||
You may obtain a copy of the License at | ||
http://www.apache.org/licenses/LICENSE-2.0 | ||
Unless required by applicable law or agreed to in writing, software | ||
distributed under the License is distributed on an "AS IS" BASIS, | ||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
See the License for the specific language governing permissions and | ||
limitations under the License. | ||
""" |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,12 @@ | ||
""" | ||
Copyright (c) 2022 Intel Corporation | ||
Licensed under the Apache License, Version 2.0 (the "License"); | ||
you may not use this file except in compliance with the License. | ||
You may obtain a copy of the License at | ||
http://www.apache.org/licenses/LICENSE-2.0 | ||
Unless required by applicable law or agreed to in writing, software | ||
distributed under the License is distributed on an "AS IS" BASIS, | ||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
See the License for the specific language governing permissions and | ||
limitations under the License. | ||
""" |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,12 @@ | ||
""" | ||
Copyright (c) 2022 Intel Corporation | ||
Licensed under the Apache License, Version 2.0 (the "License"); | ||
you may not use this file except in compliance with the License. | ||
You may obtain a copy of the License at | ||
http://www.apache.org/licenses/LICENSE-2.0 | ||
Unless required by applicable law or agreed to in writing, software | ||
distributed under the License is distributed on an "AS IS" BASIS, | ||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
See the License for the specific language governing permissions and | ||
limitations under the License. | ||
""" |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,122 @@ | ||
""" | ||
Copyright (c) 2022 Intel Corporation | ||
Licensed under the Apache License, Version 2.0 (the "License"); | ||
you may not use this file except in compliance with the License. | ||
You may obtain a copy of the License at | ||
http://www.apache.org/licenses/LICENSE-2.0 | ||
Unless required by applicable law or agreed to in writing, software | ||
distributed under the License is distributed on an "AS IS" BASIS, | ||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
See the License for the specific language governing permissions and | ||
limitations under the License. | ||
""" | ||
|
||
from typing import List | ||
|
||
from nncf.common.graph.operator_metatypes import NOOP_METATYPES | ||
from nncf.common.graph.operator_metatypes import INPUT_NOOP_METATYPES | ||
|
||
from nncf.common.graph.operator_metatypes import OperatorMetatype | ||
from nncf.common.graph.operator_metatypes import OperatorMetatypeRegistry | ||
|
||
ONNX_OPERATION_METATYPES = OperatorMetatypeRegistry('onnx_operator_metatypes') | ||
|
||
|
||
class ONNXOpMetatype(OperatorMetatype): | ||
op_names = [] # type: List[str] | ||
|
||
@classmethod | ||
def get_all_aliases(cls) -> List[str]: | ||
return cls.op_names | ||
|
||
|
||
@ONNX_OPERATION_METATYPES.register() | ||
@NOOP_METATYPES.register() | ||
class ONNXLayerNoopMetatype(ONNXOpMetatype): | ||
name = 'noop' | ||
|
||
@classmethod | ||
def get_all_aliases(cls) -> List[str]: | ||
return [cls.name] | ||
|
||
|
||
@ONNX_OPERATION_METATYPES.register() | ||
@INPUT_NOOP_METATYPES.register() | ||
class ONNXInputLayerMetatype(ONNXOpMetatype): | ||
name = 'InputLayer' | ||
|
||
|
||
@ONNX_OPERATION_METATYPES.register() | ||
class ConvolutionMetatype(ONNXOpMetatype): | ||
name = 'ConvOp' | ||
op_names = ['Conv'] | ||
|
||
|
||
@ONNX_OPERATION_METATYPES.register() | ||
class LinearMetatype(ONNXOpMetatype): | ||
name = 'LinearOp' | ||
op_names = ['Gemm'] | ||
|
||
|
||
@ONNX_OPERATION_METATYPES.register() | ||
class ReluMetatype(ONNXOpMetatype): | ||
name = 'ReluOp' | ||
op_names = ['Relu', 'Clip'] | ||
|
||
|
||
@ONNX_OPERATION_METATYPES.register() | ||
class SigmoidMetatype(ONNXOpMetatype): | ||
name = 'SigmoidOp' | ||
op_names = ['Sigmoid'] | ||
|
||
|
||
@ONNX_OPERATION_METATYPES.register() | ||
class GlobalAveragePoolMetatype(ONNXOpMetatype): | ||
name = 'GlobalAveragePoolOp' | ||
op_names = ['GlobalAveragePool'] | ||
|
||
|
||
@ONNX_OPERATION_METATYPES.register() | ||
class MaxPoolMetatype(ONNXOpMetatype): | ||
name = 'MaxPoolOp' | ||
op_names = ['MaxPool'] | ||
|
||
|
||
@ONNX_OPERATION_METATYPES.register() | ||
class ConstantMetatype(ONNXOpMetatype): | ||
name = 'ConstantOp' | ||
op_names = ['Constant'] | ||
|
||
|
||
@ONNX_OPERATION_METATYPES.register() | ||
class AddLayerMetatype(ONNXOpMetatype): | ||
name = 'AddOp' | ||
op_names = ['Add'] | ||
|
||
|
||
@ONNX_OPERATION_METATYPES.register() | ||
class MulLayerMetatype(ONNXOpMetatype): | ||
name = 'MulOp' | ||
op_names = ['Mul'] | ||
|
||
|
||
@ONNX_OPERATION_METATYPES.register() | ||
class ConcatLayerMetatype(ONNXOpMetatype): | ||
name = 'ConcatOp' | ||
op_names = ['Concat'] | ||
|
||
|
||
@ONNX_OPERATION_METATYPES.register() | ||
class BatchNormMetatype(ONNXOpMetatype): | ||
name = 'BatchNormalizationOp' | ||
op_names = ['BatchNormalization'] | ||
|
||
|
||
@ONNX_OPERATION_METATYPES.register() | ||
class ResizeMetatype(ONNXOpMetatype): | ||
name = 'ResizeOp' | ||
op_names = ['Resize'] | ||
|
||
|
||
GENERAL_WEIGHT_LAYER_METATYPES = [ConvolutionMetatype, | ||
LinearMetatype] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,138 @@ | ||
""" | ||
Copyright (c) 2022 Intel Corporation | ||
Licensed under the Apache License, Version 2.0 (the "License"); | ||
you may not use this file except in compliance with the License. | ||
You may obtain a copy of the License at | ||
http://www.apache.org/licenses/LICENSE-2.0 | ||
Unless required by applicable law or agreed to in writing, software | ||
distributed under the License is distributed on an "AS IS" BASIS, | ||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
See the License for the specific language governing permissions and | ||
limitations under the License. | ||
""" | ||
|
||
from collections import defaultdict | ||
|
||
from onnx import ModelProto # pylint: disable=no-name-in-module | ||
|
||
from nncf.common.graph import NNCFGraph | ||
from nncf.common.graph.definitions import NNCFGraphNodeType | ||
from nncf.common.graph.layer_attributes import Dtype | ||
from nncf.common.graph.definitions import MODEL_INPUT_OP_NAME | ||
from nncf.common.graph.definitions import MODEL_OUTPUT_OP_NAME | ||
|
||
from nncf.experimental.onnx.graph.onnx_graph import ONNXGraph | ||
from nncf.experimental.onnx.graph.metatypes.onnx_ops import ONNX_OPERATION_METATYPES | ||
from nncf.experimental.onnx.graph.metatypes.onnx_ops import ConstantMetatype | ||
|
||
|
||
class GraphConverter: | ||
""" | ||
Builds the NNCFGraph from an ONNX model | ||
""" | ||
|
||
@staticmethod | ||
def create_nncf_graph(onnx_model: ModelProto) -> NNCFGraph: | ||
""" | ||
Adds all ONNX nodes from 'onnx_model' and then adds thr special input_nodes and output_nodes. | ||
""" | ||
|
||
nncf_graph = NNCFGraph() | ||
onnx_graph = ONNXGraph(onnx_model) | ||
for node in onnx_graph.get_all_nodes(): | ||
node_name = node.name | ||
node_type = node.op_type | ||
metatype = ONNX_OPERATION_METATYPES.get_operator_metatype_by_op_name(node_type) | ||
if metatype == ConstantMetatype: # We don't need to quantize Constants | ||
continue | ||
nncf_graph.add_nncf_node(node_name=node_name, | ||
node_type=node_type, | ||
node_metatype=metatype, | ||
layer_attributes=None) | ||
input_counter = defaultdict(int) | ||
output_counter = defaultdict(int) | ||
for output_node in nncf_graph.get_all_nodes(): | ||
output_node_id = output_node.node_id | ||
outputs = onnx_graph.get_node_edges(output_node.node_name)['output'] | ||
for output in outputs: | ||
nodes = onnx_graph.get_nodes_by_input(output) | ||
shape = onnx_graph.get_edge_shape(output) | ||
onnx_dtype = onnx_graph.get_edge_dtype(output) | ||
nncf_dtype = GraphConverter.convert_onnx_dtype_to_nncf_dtype(onnx_dtype) | ||
for in_node in nodes: | ||
in_node_id = nncf_graph.get_node_by_name(in_node.name).node_id | ||
input_counter[in_node_id] += 1 | ||
output_counter[output_node_id] += 1 | ||
nncf_graph.add_edge_between_nncf_nodes( | ||
from_node_id=output_node_id, | ||
to_node_id=in_node_id, | ||
tensor_shape=shape, | ||
input_port_id=input_counter[in_node_id], | ||
output_port_id=output_counter[output_node_id], | ||
dtype=Dtype(nncf_dtype) | ||
) | ||
# Add Input Nodes | ||
for i, _input in enumerate(onnx_graph.get_model_inputs()): | ||
input_shape = onnx_graph.get_tensor_shape(_input) | ||
input_node = nncf_graph.add_nncf_node(node_name=MODEL_INPUT_OP_NAME + '_' + str(i), | ||
node_type=NNCFGraphNodeType.INPUT_NODE, | ||
node_metatype=ONNX_OPERATION_METATYPES. | ||
get_operator_metatype_by_op_name( | ||
NNCFGraphNodeType.INPUT_NODE), | ||
layer_attributes=None) | ||
input_name = _input.name | ||
to_nodes = onnx_graph.get_nodes_by_input(input_name) | ||
for node in to_nodes: | ||
in_node_id = input_node.node_id | ||
to_node_id = nncf_graph.get_node_by_name(node.name).node_id | ||
input_counter[in_node_id] += 1 | ||
output_counter[to_node_id] += 1 | ||
onnx_dtype = onnx_graph.get_edge_dtype(input_name) | ||
nncf_dtype = GraphConverter.convert_onnx_dtype_to_nncf_dtype(onnx_dtype) | ||
nncf_graph.add_edge_between_nncf_nodes( | ||
from_node_id=input_node.node_id, | ||
to_node_id=to_node_id, | ||
tensor_shape=input_shape, | ||
input_port_id=input_counter[in_node_id], | ||
output_port_id=output_counter[to_node_id], | ||
dtype=nncf_dtype | ||
) | ||
# Add Output Nodes | ||
for i, _output in enumerate(onnx_graph.get_model_outputs()): | ||
output_shape = onnx_graph.get_tensor_shape(_output) | ||
output_node = nncf_graph.add_nncf_node(node_name=MODEL_OUTPUT_OP_NAME + '_' + str(i), | ||
node_type=NNCFGraphNodeType.OUTPUT_NODE, | ||
node_metatype=ONNX_OPERATION_METATYPES. | ||
get_operator_metatype_by_op_name( | ||
NNCFGraphNodeType.OUTPUT_NODE), | ||
layer_attributes=None) | ||
|
||
output_name = _output.name | ||
to_nodes = onnx_graph.get_nodes_by_output(output_name) | ||
for node in to_nodes: | ||
out_node_id = output_node.node_id | ||
to_node_id = nncf_graph.get_node_by_name(node.name).node_id | ||
input_counter[out_node_id] += 1 | ||
output_counter[to_node_id] += 1 | ||
onnx_dtype = onnx_graph.get_edge_dtype(output_name) | ||
nncf_dtype = GraphConverter.convert_onnx_dtype_to_nncf_dtype(onnx_dtype) | ||
nncf_graph.add_edge_between_nncf_nodes( | ||
from_node_id=to_node_id, | ||
to_node_id=output_node.node_id, | ||
tensor_shape=output_shape, | ||
input_port_id=input_counter[out_node_id], | ||
output_port_id=output_counter[to_node_id], | ||
dtype=nncf_dtype | ||
) | ||
|
||
return nncf_graph | ||
|
||
@staticmethod | ||
def convert_onnx_dtype_to_nncf_dtype(onnx_dtype: str) -> Dtype: | ||
conversation_map = { | ||
"FLOAT": "float", | ||
"FLOAT16": "float", | ||
"BFLOAT16": "float", | ||
"DOUBLE": "float", | ||
} | ||
return Dtype(conversation_map.get(onnx_dtype, 'int')) |
Oops, something went wrong.