@@ -1042,3 +1042,123 @@ def get_config(self, ):
1042
1042
config = {'bilinear_type' : self .bilinear_type , 'seed' : self .seed }
1043
1043
base_config = super (BilinearInteraction , self ).get_config ()
1044
1044
return dict (list (base_config .items ()) + list (config .items ()))
1045
+
1046
+
1047
class FieldWiseBiInteraction(Layer):
    """Field-Wise Bi-Interaction Layer used in FLEN, compress the
    pairwise element-wise product of features into one single vector.

      Input shape
        - A list of 3D tensor with shape:``(batch_size,field_size,embedding_size)``.
          The list must contain at least 2 tensors (one per field group).

      Output shape
        - 2D tensor with shape: ``(batch_size,embedding_size)``.

      Arguments
        - **l2_reg** : Float, l2 regularization coefficient applied to the
          inter-field and intra-field kernels.
        - **seed** : A Python integer to use as random seed.

      References
        - [1] Chen W, Zhan L, Ci Y, Lin C. FLEN: Leveraging Field for
          Scalable CTR Prediction. https://arxiv.org/pdf/1911.04690
    """

    def __init__(self, l2_reg=1e-5, seed=1024, **kwargs):
        self.l2_reg = l2_reg
        self.seed = seed

        super(FieldWiseBiInteraction, self).__init__(**kwargs)

    def build(self, input_shape):
        """Create the inter-field (MF) and intra-field (FM) kernels and biases.

        Raises:
            ValueError: if ``input_shape`` is not a list of at least 2 shapes.
        """
        if not isinstance(input_shape, list) or len(input_shape) < 2:
            raise ValueError(
                'A `Field-Wise Bi-Interaction` layer should be called '
                'on a list of at least 2 inputs')

        self.num_fields = len(input_shape)
        embedding_size = input_shape[0][-1]

        # One weight per unordered pair of fields (i < j).
        num_pairs = self.num_fields * (self.num_fields - 1) // 2

        self.kernel_inter = self.add_weight(
            name='kernel_inter',
            shape=(num_pairs, 1),
            initializer=glorot_normal(seed=self.seed),
            regularizer=l2(self.l2_reg),
            trainable=True)
        # NOTE: the trailing comma matters -- ``(embedding_size)`` is a bare
        # scalar, not a 1-D shape tuple.
        self.bias_inter = self.add_weight(name='bias_inter',
                                          shape=(embedding_size,),
                                          initializer=Zeros(),
                                          trainable=True)
        # One weight per field.
        self.kernel_intra = self.add_weight(
            name='kernel_intra',
            shape=(self.num_fields, 1),
            initializer=glorot_normal(seed=self.seed),
            regularizer=l2(self.l2_reg),
            trainable=True)
        self.bias_intra = self.add_weight(name='bias_intra',
                                          shape=(embedding_size,),
                                          initializer=Zeros(),
                                          trainable=True)

        super(FieldWiseBiInteraction,
              self).build(input_shape)  # Be sure to call this somewhere!

    def call(self, inputs, **kwargs):
        """Return the sum of the inter-field (MF) and intra-field (FM) terms.

        Raises:
            ValueError: if the first input tensor is not 3-dimensional.
        """
        if K.ndim(inputs[0]) != 3:
            # Report the rank of the offending tensor; ``inputs`` itself is
            # a list and has no rank.
            raise ValueError(
                "Unexpected inputs dimensions %d, expect to be 3 dimensions" %
                (K.ndim(inputs[0])))

        field_wise_embeds_list = inputs

        # Sum each field group over its own features (axis 1) and stack the
        # per-field vectors along axis 1. Both modules below reuse this.
        field_wise_vectors = tf.concat([
            reduce_sum(field_i_vectors, axis=1, keep_dims=True)
            for field_i_vectors in field_wise_embeds_list
        ], 1)

        # MF module: weighted element-wise products of every field pair i < j.
        pairs = [(i, j)
                 for i in range(self.num_fields)
                 for j in range(i + 1, self.num_fields)]
        left = [i for i, _ in pairs]
        right = [j for _, j in pairs]

        embeddings_left = tf.gather(params=field_wise_vectors,
                                    indices=left,
                                    axis=1)
        embeddings_right = tf.gather(params=field_wise_vectors,
                                     indices=right,
                                     axis=1)

        embeddings_prod = embeddings_left * embeddings_right
        field_weighted_embedding = embeddings_prod * self.kernel_inter
        h_mf = reduce_sum(field_weighted_embedding, axis=1)
        h_mf = tf.nn.bias_add(h_mf, self.bias_inter)

        # FM module: per field, (sum of vectors)^2 - sum of (vectors^2).
        # The per-field sums are already available in field_wise_vectors,
        # so squaring the stacked tensor avoids recomputing them.
        square_of_sum = tf.square(field_wise_vectors)
        sum_of_square = tf.concat([
            reduce_sum(field_i_vectors * field_i_vectors,
                       axis=1,
                       keep_dims=True)
            for field_i_vectors in field_wise_embeds_list
        ], 1)
        field_fm = square_of_sum - sum_of_square

        h_fm = reduce_sum(field_fm * self.kernel_intra, axis=1)
        h_fm = tf.nn.bias_add(h_fm, self.bias_intra)

        return h_mf + h_fm

    def compute_output_shape(self, input_shape):
        """Output is ``(batch_size, embedding_size)``."""
        return (None, input_shape[0][-1])

    def get_config(self):
        """Return the constructor arguments merged with the base config."""
        config = {'l2_reg': self.l2_reg, 'seed': self.seed}
        base_config = super(FieldWiseBiInteraction, self).get_config()
        return dict(list(base_config.items()) + list(config.items()))
0 commit comments