Skip to content

Commit

Permalink
add function for splitting all fused ops
Browse files Browse the repository at this point in the history
  • Loading branch information
patlevin committed Mar 23, 2022
1 parent 768ab72 commit 3881888
Showing 1 changed file with 32 additions and 5 deletions.
37 changes: 32 additions & 5 deletions tfjs_graph_converter/convert_prelu.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,10 @@ def _split_fused_op(node: util.NodeDef,
# FusedBatchNorm and Squeeze are not relevant for inference and thus never
# present in (optimised) frozen graphs generated by tfjs converter.
# This leaves us with Conv2D|MatMul + BiasAdd + <Activation> as the only
# remaining possible variants, since Conv2D + BiasAdd doesn't need to be
# split.
# remaining possible variants.
#
# For compatibility reasons with quantised TFLite models, we optionally
# split Conv2D + BiasAdd as well.
#
# We return [Conv2D|MatMul, BiasAdd|BiasAddV1, <Activation>].
# Unsupported <Activation>-nodes will be dealt with in a separate step
Expand All @@ -43,9 +45,14 @@ def node_name(node_index):
bias_add = util.make_op_node(fused_ops[0], [fused_op, inputs[2]],
node_name(2))
bias_add = util.copy_op_attrs(source=node, target=bias_add)
activation = util.make_op_node(fused_ops[1], [bias_add] + inputs[3:],
node_name(3))
return [fused_op, bias_add, activation]

have_activation = len(fused_ops) > 1
if have_activation:
activation = util.make_op_node(fused_ops[1], [bias_add] + inputs[3:],
node_name(3))
return [fused_op, bias_add, activation]
else:
return [fused_op, bias_add]


def _split_prelu(node: util.NodeDef,
Expand Down Expand Up @@ -114,3 +121,23 @@ def replace_prelu(input_graph_def: util.GraphDef) -> util.GraphDef:
def _predicate(node): return node.op == 'Prelu'
return util.replace_matching_nodes(input_graph_def, _predicate,
_split_prelu)


def split_all_fused_ops(input_graph_def: util.GraphDef) -> util.GraphDef:
"""
Split all fused-operation nodes in the graph into individual operations.
This enables further conversion into formats that don't support fused
operations (e.g. TFLite without Flex enabled).
Args:
input_graph_def: TF graph definition to examine
Returns:
Updated copy of the input graph with matching nodes replaced by
individual operations
"""
def _predicate(node):
return util.is_fused_conv2d(node) or util.is_fused_matmul(node)
return util.replace_matching_nodes(input_graph_def=input_graph_def,
predicate=_predicate,
transform=_split_fused_op)

0 comments on commit 3881888

Please sign in to comment.