@@ -1022,12 +1022,13 @@ def call(self, inputs, **kwargs):
1022
1022
raise ValueError (
1023
1023
"Unexpected inputs dimensions %d, expect to be 3 dimensions" % (K .ndim (inputs )))
1024
1024
1025
+ n = len (inputs )
1025
1026
if self .bilinear_type == "all" :
1026
- p = [tf .multiply ( tf . tensordot (v_i , self .W , axes = (- 1 , 0 )), v_j )
1027
- for v_i , v_j in itertools .combinations (inputs , 2 )]
1027
+ vidots = [tf .tensordot (inputs [ i ] , self .W , axes = (- 1 , 0 )) for i in range ( n )]
1028
+ p = [ tf . multiply ( vidots [ i ], inputs [ j ]) for i , j in itertools .combinations (range ( n ) , 2 )]
1028
1029
elif self .bilinear_type == "each" :
1029
- p = [tf .multiply ( tf . tensordot (inputs [i ], self .W_list [i ], axes = (- 1 , 0 )), inputs [ j ])
1030
- for i , j in itertools .combinations (range (len ( inputs ) ), 2 )]
1030
+ vidots = [tf .tensordot (inputs [i ], self .W_list [i ], axes = (- 1 , 0 )) for i in range ( n - 1 )]
1031
+ p = [ tf . multiply ( vidots [ i ], inputs [ j ]) for i , j in itertools .combinations (range (n ), 2 )]
1031
1032
elif self .bilinear_type == "interaction" :
1032
1033
p = [tf .multiply (tf .tensordot (v [0 ], w , axes = (- 1 , 0 )), v [1 ])
1033
1034
for v , w in zip (itertools .combinations (inputs , 2 ), self .W_list )]
0 commit comments