Skip to content

Commit 2660905

Browse files
committed
speed up BilinearInteraction
1 parent 66d173e commit 2660905

File tree

1 file changed

+5
-4
lines changed

1 file changed

+5
-4
lines changed

deepctr/layers/interaction.py

+5-4
Original file line numberDiff line numberDiff line change
@@ -1022,12 +1022,13 @@ def call(self, inputs, **kwargs):
10221022
raise ValueError(
10231023
"Unexpected inputs dimensions %d, expect to be 3 dimensions" % (K.ndim(inputs)))
10241024

1025+
n = len(inputs)
10251026
if self.bilinear_type == "all":
1026-
p = [tf.multiply(tf.tensordot(v_i, self.W, axes=(-1, 0)), v_j)
1027-
for v_i, v_j in itertools.combinations(inputs, 2)]
1027+
vidots = [tf.tensordot(inputs[i], self.W, axes=(-1, 0)) for i in range(n)]
1028+
p = [tf.multiply(vidots[i], inputs[j]) for i, j in itertools.combinations(range(n), 2)]
10281029
elif self.bilinear_type == "each":
1029-
p = [tf.multiply(tf.tensordot(inputs[i], self.W_list[i], axes=(-1, 0)), inputs[j])
1030-
for i, j in itertools.combinations(range(len(inputs)), 2)]
1030+
vidots = [tf.tensordot(inputs[i], self.W_list[i], axes=(-1, 0)) for i in range(n - 1)]
1031+
p = [tf.multiply(vidots[i], inputs[j]) for i, j in itertools.combinations(range(n), 2)]
10311032
elif self.bilinear_type == "interaction":
10321033
p = [tf.multiply(tf.tensordot(v[0], w, axes=(-1, 0)), v[1])
10331034
for v, w in zip(itertools.combinations(inputs, 2), self.W_list)]

0 commit comments

Comments
 (0)