Skip to content

Commit

Permalink
Quantizer propagation in TF (openvinotoolkit#803)
Browse files Browse the repository at this point in the history
  • Loading branch information
vshampor authored Jul 12, 2021
1 parent 1e274f6 commit 1e5c0c8
Show file tree
Hide file tree
Showing 54 changed files with 1,058 additions and 428 deletions.
19 changes: 15 additions & 4 deletions nncf/common/graph/patterns.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
See the License for the specific language governing permissions and
limitations under the License.
"""
from typing import Dict
from typing import Optional
from typing import List
from typing import Tuple
Expand Down Expand Up @@ -188,17 +189,20 @@ def join_patterns(self, other: 'GraphPattern',
last node of self's graph and first node of other's graph,
which are found by nx.lexicographical_topological_sort().
If other starts from a node with the PATTERN_INPUT_NODE_TYPE type, the input node of the other will be
discarded from the final pattern.
:param other: GraphPattern that will be added
:param edges: List of edges between self and other graphs.
Edges must begin at self and finish at other.
"""
# Unite nodes
other_graph = other.graph
other_graph_copy = copy.deepcopy(other.graph)
node_mapping = {}
for node in other_graph.nodes:
node_mapping[node] = self._node_counter
for node_key in other_graph_copy.nodes:
node_mapping[node_key] = self._node_counter
self._node_counter += 1
other_graph_copy = nx.relabel_nodes(other.graph, node_mapping, copy=True)
other_graph_copy = nx.relabel_nodes(other_graph_copy, node_mapping, copy=True)

saved_graph = copy.deepcopy(self._graph)
self._graph = nx.union(saved_graph, other_graph_copy)
Expand Down Expand Up @@ -229,3 +233,10 @@ def get_weakly_connected_subgraphs(self) -> List[nx.DiGraph]:

def dump_graph(self, path: str) -> None:
nx.drawing.nx_pydot.write_dot(self._graph, path)


def merge_two_types_of_operations(first_op: Dict, second_op: Dict, label: str) -> Dict:
res = {'type': first_op['type']}
res['type'].extend(second_op['type'])
res['label'] = label
return res
43 changes: 39 additions & 4 deletions nncf/common/insertion_point_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@

import networkx as nx

from nncf.common.graph import Dtype
from nncf.common.graph import NNCFGraph
from nncf.common.graph import NNCFNodeName
from nncf.common.graph.graph_matching import find_subgraphs_matching_pattern
Expand Down Expand Up @@ -60,6 +61,7 @@ class InsertionPointGraph(nx.DiGraph):
ASSOCIATED_IP_NODE_KEYS_NODE_ATTR = 'associated_ip_node_keys'
IS_MERGED_NODE_ATTR = 'is_merged'
MERGED_NNCF_NODE_LIST_NODE_ATTR = 'merged_node_list'
IS_INTEGER_PATH_EDGE_ATTR = 'is_integer'

PRE_HOOK_ID_PREFIX = 'PRE HOOK ' # NB: Do not use colon (':') in node keys! Causes trouble for .dot file export.
POST_HOOK_ID_PREFIX = 'POST HOOK '
Expand All @@ -80,6 +82,8 @@ def __init__(self, nncf_graph: NNCFGraph, weight_modifiable_node_names: List[NNC
If left unspecified, every node in `nncf_graph` will be allowed to have a single post-hook for its output
(post-hooking separate tensors in an operation's output is not currently supported)
"""
#pylint:disable=too-many-branches
#pylint:disable=too-many-statements
super().__init__()
self._base_nx_graph = deepcopy(nncf_graph.get_nx_graph_copy())
if weight_modifiable_node_names is None:
Expand Down Expand Up @@ -115,8 +119,10 @@ def __init__(self, nncf_graph: NNCFGraph, weight_modifiable_node_names: List[NNC
INPUT_PORT_ID = "input_port_id"
for edge in self._base_nx_graph.edges:
input_port_id = self._base_nx_graph.edges[edge][NNCFGraph.INPUT_PORT_ID_EDGE_ATTR]
dtype = self._base_nx_graph.edges[edge][NNCFGraph.DTYPE_EDGE_ATTR]
from_node, to_node = edge
attrs = {INPUT_PORT_ID: input_port_id}
attrs = {INPUT_PORT_ID: input_port_id,
self.IS_INTEGER_PATH_EDGE_ATTR: dtype is Dtype.INTEGER}
self.add_edge(from_node, to_node, **attrs)

node_keys_working_set = [deepcopy(node_key) for node_key in nx.lexicographical_topological_sort(self)]
Expand All @@ -138,6 +144,7 @@ def __init__(self, nncf_graph: NNCFGraph, weight_modifiable_node_names: List[NNC
input_port_id_vs_edge = {self.edges[edge][INPUT_PORT_ID]: edge for edge in in_edges}
for pre_hook_point in pre_hook_ips:
edge = input_port_id_vs_edge[pre_hook_point.input_port_id]
original_edge_attrs = self.edges[edge]
from_node_key, to_node_key = edge
ip_node_key = self.get_pre_hook_node_key(str(operator_node_key), pre_hook_point.input_port_id)

Expand All @@ -149,8 +156,8 @@ def __init__(self, nncf_graph: NNCFGraph, weight_modifiable_node_names: List[NNC
self.add_node(ip_node_key, **pre_hook_ip_attrs)

self.remove_edge(from_node_key, to_node_key)
self.add_edge(from_node_key, ip_node_key)
self.add_edge(ip_node_key, operator_node_key)
self.add_edge(from_node_key, ip_node_key, **original_edge_attrs)
self.add_edge(ip_node_key, operator_node_key, **original_edge_attrs)
operator_node[InsertionPointGraph.ASSOCIATED_IP_NODE_KEYS_NODE_ATTR].add(ip_node_key)

if original_node.node_name in target_node_name_vs_post_hook_ips:
Expand All @@ -164,19 +171,47 @@ def __init__(self, nncf_graph: NNCFGraph, weight_modifiable_node_names: List[NNC
ip_node_key = self.get_post_hook_node_key(str(operator_node_key))
self.add_node(ip_node_key, **post_hook_ip_attrs)
out_edges = list(self.out_edges(operator_node_key))
has_integer_outputs = False
for out_edge in out_edges:
# Need to preserve original edge attributes in order not to lose
# input port ID information
original_edge_attrs = self.edges[out_edge]
from_node_key, to_node_key = out_edge
self.remove_edge(from_node_key, to_node_key)
self.add_edge(ip_node_key, to_node_key, **original_edge_attrs)
if original_edge_attrs[self.IS_INTEGER_PATH_EDGE_ATTR]:
has_integer_outputs = True

# TODO (vshampor): introduce separate insertion points for operator outputs if
# the outputs are semantically different
self.add_edge(operator_node_key, ip_node_key)

# TODO (vshampor): in multi-output case, some outputs may be integer and some float;
# need to switch to using output ports to cover this correctly. For safety, mark
# the edge from op to post-hook as integer if at least one output edge of the op was integer
is_integer_attrs = {InsertionPointGraph.IS_INTEGER_PATH_EDGE_ATTR: has_integer_outputs}
self.add_edge(operator_node_key, ip_node_key, **is_integer_attrs)
operator_node = self.nodes[operator_node_key]
operator_node[InsertionPointGraph.ASSOCIATED_IP_NODE_KEYS_NODE_ATTR].add(ip_node_key)

for edge in self.edges:
# Mark all edges from post-hook to pre-hook as integer if at least one was integer.
# Until output ports are ready, the post-hook for output will treat op as having a single
# tensor output. In multi-output case when some of tensors are integer, need to make
# sure that the propagation won't happen from a pre-hook of the op consuming the floating part
# of the output into the post-hook of the operation that produces both int and float tensors.
from_node_key, to_node_key = edge
from_node = self.nodes[from_node_key]
to_node = self.nodes[to_node_key]
if from_node[self.NODE_TYPE_NODE_ATTR] is InsertionPointGraphNodeType.POST_HOOK and \
to_node[self.NODE_TYPE_NODE_ATTR] is InsertionPointGraphNodeType.PRE_HOOK:
post_hook_has_integer_outputs = False
for follower_node_key in self.successors(from_node_key):
if self.edges[from_node_key, follower_node_key][self.IS_INTEGER_PATH_EDGE_ATTR]:
post_hook_has_integer_outputs = True
if post_hook_has_integer_outputs:
for follower_node_key in self.successors(from_node_key):
self.edges[from_node_key, follower_node_key][self.IS_INTEGER_PATH_EDGE_ATTR] = True

@property
def weight_modifiable_node_names(self) -> List[NNCFNodeName]:
return self._weight_modifiable_node_names
Expand Down
16 changes: 11 additions & 5 deletions nncf/common/quantization/quantizer_propagation/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,7 @@ class QuantizerPropagationStateGraph(nx.DiGraph):
IS_IN_IGNORED_SCOPES = "is_ignored"
IS_MERGED_NODE_ATTR = "is_merged"
MERGED_NNCF_NODE_LIST_NODE_ATTR = "merged_node_list"
IS_INTEGER_PATH_EDGE_ATTR = "is_integer"
BARRIER_NODE_KEY_POSTFIX = "BARRIER"

def __init__(self, ip_graph: InsertionPointGraph,
Expand Down Expand Up @@ -148,6 +149,8 @@ def __init__(self, ip_graph: InsertionPointGraph,

for from_node, to_node, edge_data in ip_graph.edges(data=True):
edge_data[self.AFFECTING_PROPAGATING_QUANTIZERS_ATTR] = []
is_integer = edge_data.pop(InsertionPointGraph.IS_INTEGER_PATH_EDGE_ATTR)
edge_data[self.IS_INTEGER_PATH_EDGE_ATTR] = is_integer
self.add_edge(from_node, to_node, **edge_data)

for barred_node_key in self.ignored_node_keys + iteration_scope_node_keys:
Expand All @@ -168,11 +171,11 @@ def _add_barrier_after_node(self, node_key: str):
barrier_node_key = self.get_barrier_node_key(node_key)
self.add_node(barrier_node_key, **qpg_node_barrier)

edge_attr = {QuantizerPropagationStateGraph.AFFECTING_PROPAGATING_QUANTIZERS_ATTR: []}
next_node_keys = list(self.succ[node_key].keys())
for next_node_key in next_node_keys:
self.add_edge(node_key, barrier_node_key, **edge_attr)
self.add_edge(barrier_node_key, next_node_key, **edge_attr)
edge_attrs = self.edges[node_key, next_node_key]
self.add_edge(node_key, barrier_node_key, **edge_attrs)
self.add_edge(barrier_node_key, next_node_key, **edge_attrs)
self.remove_edge(node_key, next_node_key)

@staticmethod
Expand All @@ -187,8 +190,8 @@ def ipg_node_type_to_qpsg_node_type(ipg_node_type: InsertionPointGraphNodeType)
raise RuntimeError("Invalid insertion point graph node type.")

@staticmethod
def get_barrier_node_key(node_key: str):
return QuantizerPropagationStateGraph.BARRIER_NODE_KEY_POSTFIX + node_key
def get_barrier_node_key(node_key: str) -> str:
return f"{QuantizerPropagationStateGraph.BARRIER_NODE_KEY_POSTFIX} {node_key}"


def mark_act_quantizer_as_dependent_on_weights(self, pq: PropagatingQuantizer, operator_node_key: str):
Expand Down Expand Up @@ -787,6 +790,9 @@ def get_visualized_graph(self):
if affecting_quantizers:
label = ", ".join([str(pq.id) for pq in affecting_quantizers])
attrs = {"color": "blue", "label": label}
is_integer_path = edge[QuantizerPropagationStateGraph.IS_INTEGER_PATH_EDGE_ATTR]
if is_integer_path:
attrs = {"color": "violet", "style": "bold"}
out_graph.add_edge(u, v, **attrs)

for gid, group_pq_node_keys in unified_scale_group_vs_pq_node_id_dict.items():
Expand Down
Loading

0 comments on commit 1e5c0c8

Please sign in to comment.