Skip to content

Commit 8c52514

Browse files
committed
v1
1 parent 4a7675c commit 8c52514

File tree

4 files changed

+221
-3
lines changed

4 files changed

+221
-3
lines changed

deepctr/layers/__init__.py

+4-2
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,8 @@
44
from .core import DNN, LocalActivationUnit, PredictionLayer
55
from .interaction import (CIN, FM, AFMLayer, BiInteractionPooling, CrossNet,
66
InnerProductLayer, InteractingLayer,
7-
OutterProductLayer, FGCNNLayer,SENETLayer,BilinearInteraction)
7+
OutterProductLayer, FGCNNLayer,SENETLayer,BilinearInteraction,
8+
FieldWiseBiInteraction)
89
from .normalization import LayerNormalization
910
from .sequence import (AttentionSequencePoolingLayer, BiasEncoding, BiLSTM,
1011
KMaxPooling, SequencePoolingLayer,WeightedSequenceLayer,
@@ -39,5 +40,6 @@
3940
'SENETLayer':SENETLayer,
4041
'BilinearInteraction':BilinearInteraction,
4142
'WeightedSequenceLayer':WeightedSequenceLayer,
42-
'Add':Add
43+
'Add':Add,
44+
'FieldWiseBiInteraction':FieldWiseBiInteraction
4345
}

deepctr/layers/interaction.py

+120
Original file line numberDiff line numberDiff line change
@@ -1042,3 +1042,123 @@ def get_config(self, ):
10421042
config = {'bilinear_type': self.bilinear_type, 'seed': self.seed}
10431043
base_config = super(BilinearInteraction, self).get_config()
10441044
return dict(list(base_config.items()) + list(config.items()))
1045+
1046+
1047+
class FieldWiseBiInteraction(Layer):
    """Field-Wise Bi-Interaction Layer used in FLEN, compress the
    pairwise element-wise product of features into one single vector.

      Input shape
        - A list of 3D tensor with shape:``(batch_size,field_size,embedding_size)``.
          Each list element is treated as one field; at least 2 are required.

      Output shape
        - 2D tensor with shape: ``(batch_size,embedding_size)``.

      Arguments
        - **l2_reg** : Float, l2 regularization coefficient applied to the
          inter-field and intra-field kernels.
        - **seed** : A Python integer to use as random seed.

      References
        - [1] Chen W, Zhan L, Ci Y, Lin C https://arxiv.org/pdf/1911.04690
    """

    def __init__(self, l2_reg=1e-5, seed=1024, **kwargs):
        self.l2_reg = l2_reg
        self.seed = seed

        super(FieldWiseBiInteraction, self).__init__(**kwargs)

    def build(self, input_shape):
        """Create one weight per field pair (inter) and one per field (intra),
        plus a bias vector of length ``embedding_size`` for each module.

        Raises ``ValueError`` when ``input_shape`` is not a list of at least
        two shapes.
        """
        if not isinstance(input_shape, list) or len(input_shape) < 2:
            raise ValueError(
                'A `Field-Wise Bi-Interaction` layer should be called '
                'on a list of at least 2 inputs')

        self.num_fields = len(input_shape)
        embedding_size = int(input_shape[0][-1])
        # Number of unordered field pairs (i, j) with i < j.
        num_pairs = self.num_fields * (self.num_fields - 1) // 2

        self.kernel_inter = self.add_weight(
            name='kernel_inter',
            shape=(num_pairs, 1),
            initializer=glorot_normal(seed=self.seed),
            regularizer=l2(self.l2_reg),
            trainable=True)
        # Shapes are written as real 1-tuples; ``(embedding_size)`` is just
        # a parenthesised scalar.
        self.bias_inter = self.add_weight(name='bias_inter',
                                          shape=(embedding_size,),
                                          initializer=Zeros(),
                                          trainable=True)
        self.kernel_intra = self.add_weight(
            name='kernel_intra',
            shape=(self.num_fields, 1),
            initializer=glorot_normal(seed=self.seed),
            regularizer=l2(self.l2_reg),
            trainable=True)
        self.bias_intra = self.add_weight(name='bias_intra',
                                          shape=(embedding_size,),
                                          initializer=Zeros(),
                                          trainable=True)

        super(FieldWiseBiInteraction,
              self).build(input_shape)  # Be sure to call this somewhere!

    def call(self, inputs, **kwargs):
        """Return ``h_mf + h_fm``: the weighted sum of pairwise products of
        the per-field summed embeddings (MF module) plus the weighted sum of
        per-field ``square_of_sum - sum_of_square`` terms (FM module).

        Raises ``ValueError`` when the first input is not a 3D tensor.
        """
        # ``inputs`` is a list, so the rank must be read from an element;
        # calling K.ndim on the list itself would fail before the intended
        # ValueError could be raised.
        if K.ndim(inputs[0]) != 3:
            raise ValueError(
                "Unexpected inputs dimensions %d, expect to be 3 dimensions" %
                (K.ndim(inputs[0])))

        field_wise_embeds_list = inputs

        # Sum of the embeddings inside each field, kept 3D so the fields can
        # be concatenated on axis 1. Computed once and shared by both modules.
        field_sums = [
            reduce_sum(field_i_vectors, axis=1, keep_dims=True)
            for field_i_vectors in field_wise_embeds_list
        ]

        # MF module
        field_wise_vectors = tf.concat(field_sums, 1)

        # Enumerate every unordered pair of fields (i < j); the order matches
        # the rows of ``kernel_inter``.
        pairs = [(i, j)
                 for i in range(self.num_fields)
                 for j in range(i + 1, self.num_fields)]
        left = [i for i, _ in pairs]
        right = [j for _, j in pairs]

        embeddings_left = tf.gather(params=field_wise_vectors,
                                    indices=left,
                                    axis=1)
        embeddings_right = tf.gather(params=field_wise_vectors,
                                     indices=right,
                                     axis=1)

        embeddings_prod = embeddings_left * embeddings_right
        # kernel_inter has shape (num_pairs, 1) and broadcasts over the
        # embedding axis.
        field_weighted_embedding = embeddings_prod * self.kernel_inter
        h_mf = reduce_sum(field_weighted_embedding, axis=1)
        h_mf = tf.nn.bias_add(h_mf, self.bias_inter)

        # FM module
        field_fm = tf.concat([
            tf.square(field_sum) - reduce_sum(
                field_i_vectors * field_i_vectors, axis=1, keep_dims=True)
            for field_sum, field_i_vectors in zip(field_sums,
                                                  field_wise_embeds_list)
        ], 1)

        h_fm = reduce_sum(field_fm * self.kernel_intra, axis=1)
        h_fm = tf.nn.bias_add(h_fm, self.bias_intra)

        return h_mf + h_fm

    def compute_output_shape(self, input_shape):
        return (None, input_shape[0][-1])

    def get_config(self, ):
        # Expose constructor arguments so a saved model can rebuild the layer
        # with the same regularisation strength and seed.
        config = {'l2_reg': self.l2_reg, 'seed': self.seed}
        base_config = super(FieldWiseBiInteraction, self).get_config()
        return dict(list(base_config.items()) + list(config.items()))

deepctr/models/__init__.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
from .fgcnn import FGCNN
1717
from .dsin import DSIN
1818
from .fibinet import FiBiNET
19+
from .flen import FLEN
1920

2021
__all__ = ["AFM", "CCPM","DCN", "MLR", "DeepFM",
21-
"MLR", "NFM", "DIN", "DIEN", "FNN", "PNN", "WDL", "xDeepFM", "AutoInt", "ONN", "FGCNN", "DSIN", "FiBiNET"]
22+
"MLR", "NFM", "DIN", "DIEN", "FNN", "PNN", "WDL", "xDeepFM", "AutoInt", "ONN", "FGCNN", "DSIN", "FiBiNET", 'FLEN']

deepctr/models/flen.py

+95
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,95 @@
1+
# -*- coding:utf-8 -*-
2+
"""
3+
Author:
4+
Tingyi Tan,[email protected]
5+
6+
Reference:
7+
[1] Chen W, Zhan L, Ci Y, Lin C. FLEN: Leveraging Field for Scalable CTR Prediction. https://arxiv.org/pdf/1911.04690
8+
9+
"""
10+
11+
12+
from itertools import chain
13+
import tensorflow as tf
14+
from tensorflow.python.keras.layers import Flatten
15+
16+
from ..inputs import input_from_feature_columns, get_linear_logit, build_input_features, combined_dnn_input
17+
from ..layers.core import PredictionLayer, DNN
18+
from ..layers.utils import concat_func, add_func
19+
from ..layers.interaction import FieldWiseBiInteraction
20+
21+
22+
def FLEN(linear_feature_columns,
         dnn_feature_columns,
         l2_reg_linear=0.00001,
         l2_reg_embedding=0.00001,
         l2_reg_dnn=0.00001,
         l2_reg_fw=0.00001,
         init_std=0.0001,
         seed=1024,
         dnn_dropout=0.2,
         dnn_activation='relu',
         dnn_use_bn=True,
         task='binary',
         dnn_hidden_units=(64, 32),
         fw_hidden_units=(32,)):
    """Instantiates the FLEN Network architecture.

    :param linear_feature_columns: An iterable containing all the features used by linear part of the model.
    :param dnn_feature_columns: An iterable containing all the features used by deep part of the model.
    :param l2_reg_linear: float. L2 regularizer strength applied to linear part
    :param l2_reg_embedding: float. L2 regularizer strength applied to embedding vector
    :param l2_reg_dnn: float. L2 regularizer strength applied to DNN
    :param l2_reg_fw: float. L2 regularizer strength applied to the field-wise bi-interaction layer
    :param init_std: float,to use as the initialize std of embedding vector
    :param seed: integer ,to use as random seed.
    :param dnn_dropout: float in [0,1), the probability we will drop out a given DNN coordinate.
    :param dnn_activation: Activation function to use in DNN
    :param dnn_use_bn: bool. Whether use BatchNormalization before activation or not in DNN
    :param task: str, ``"binary"`` for binary logloss or ``"regression"`` for regression loss
    :param dnn_hidden_units: iterable of positive integers, the layer widths of the MLP branch.
        One ``DNN`` layer is stacked per width.
    :param fw_hidden_units: iterable of positive integers, the layer widths of the DNN applied
        on top of the field-wise bi-interaction output.
    :return: A Keras model instance.
    """

    # Materialise once so any iterable (not only lists) can be concatenated
    # and safely reused below.
    linear_feature_columns = list(linear_feature_columns)
    dnn_feature_columns = list(dnn_feature_columns)

    features = build_input_features(linear_feature_columns +
                                    dnn_feature_columns)

    inputs_list = list(features.values())

    group_embedding_dict, dense_value_list = input_from_feature_columns(
        features,
        dnn_feature_columns,
        l2_reg_embedding,
        init_std,
        seed,
        support_group=True)

    # Linear part
    linear_logit = get_linear_logit(features,
                                    linear_feature_columns,
                                    init_std=init_std,
                                    seed=seed,
                                    prefix='linear',
                                    l2_reg=l2_reg_linear)
    linear_logit = Flatten()(linear_logit)

    # FM + MF: one concatenated embedding tensor per feature group
    fm_mf_out = FieldWiseBiInteraction(l2_reg=l2_reg_fw, seed=seed)([
        concat_func(group_embeddings, axis=1)
        for group_embeddings in group_embedding_dict.values()
    ])
    fm_mf_out = DNN(tuple(fw_hidden_units), dnn_activation, l2_reg_dnn,
                    dnn_dropout, dnn_use_bn, seed)(fm_mf_out)

    # MLP over all embeddings (all groups flattened together) and dense values
    mlp_output = combined_dnn_input(
        list(chain.from_iterable(group_embedding_dict.values())),
        dense_value_list)
    # One single-layer DNN per width, applied in sequence.
    for units in dnn_hidden_units:
        mlp_output = DNN((units,), dnn_activation, l2_reg_dnn, dnn_dropout,
                         dnn_use_bn, seed)(mlp_output)

    # Final projection over the three branches plus the dense values
    dnn_input = combined_dnn_input([fm_mf_out, mlp_output, linear_logit],
                                   dense_value_list)
    dnn_logit = tf.keras.layers.Dense(1, use_bias=False,
                                      activation=None)(dnn_input)
    output = PredictionLayer(task)(dnn_logit)

    model = tf.keras.models.Model(inputs=inputs_list, outputs=output)
    return model

0 commit comments

Comments
 (0)