Skip to content

Commit

Permalink
Fixed bug in explicit_broadcast
Browse files Browse the repository at this point in the history
  • Loading branch information
PINTO0309 committed Jan 6, 2023
1 parent 0dbac04 commit 85f5007
Showing 1 changed file with 8 additions and 0 deletions.
8 changes: 8 additions & 0 deletions onnx2tf/utils/common_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -846,6 +846,14 @@ def explicit_broadcast(
graph_node_input_shape1, graph_node_input_shape2 = graph_node_input_shape2, graph_node_input_shape1
swapped += 1

# Skip subsequent processing in the following patterns.
# const_or_var_1: [1,1,5000]
# const_or_var_2: [5000]
if len(const_or_var_1.shape) >= 1 \
and len(const_or_var_2.shape) == 1 \
and const_or_var_1.shape[-1] == const_or_var_2.shape[-1]:
return const_or_var_1, const_or_var_2

"""
UnSqueeze 1 at the beginning of const_or_var_2_shape until const_or_var_1.shape
and const_or_var_2.shape have the same rank
Expand Down

0 comments on commit 85f5007

Please sign in to comment.