Skip to content

Commit cc844f3

Browse files
author
Weichen Shen
authored
Update for v0.2.1
* Add AutoInt & InteractingLayer
1 parent 1107a82 commit cc844f3

21 files changed

+346
-17
lines changed

README.md

+1
Original file line numberDiff line numberDiff line change
@@ -33,3 +33,4 @@ Through `pip install deepctr` get the package and [**Get Started!**](https://d
3333
|Neural Factorization Machine|[SIGIR 2017][Neural Factorization Machines for Sparse Predictive Analytics](https://arxiv.org/pdf/1708.05027.pdf)|
3434
|Deep Interest Network|[KDD 2018][Deep Interest Network for Click-Through Rate Prediction](https://arxiv.org/pdf/1706.06978.pdf)|
3535
|xDeepFM|[KDD 2018][xDeepFM: Combining Explicit and Implicit Feature Interactions for Recommender Systems](https://arxiv.org/pdf/1803.05170.pdf)|
36+
| AutoInt|[arxiv 2018][AutoInt: Automatic Feature Interaction Learning via Self-Attentive Neural Networks](https://arxiv.org/abs/1810.11921)|

deepctr/__init__.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -3,5 +3,5 @@
33
from .import sequence
44
from . import models
55
from .utils import check_version
6-
__version__ = '0.2.0post1'
6+
__version__ = '0.2.1'
77
check_version(__version__)

deepctr/layers.py

+83-2
Original file line numberDiff line numberDiff line change
@@ -286,7 +286,8 @@ def call(self, inputs, **kwargs):
286286

287287
def compute_output_shape(self, input_shape):
288288
if self.split_half:
289-
featuremap_num = sum(self.layer_size[:-1]) // 2 + self.layer_size[-1]
289+
featuremap_num = sum(
290+
self.layer_size[:-1]) // 2 + self.layer_size[-1]
290291
else:
291292
featuremap_num = sum(self.layer_size)
292293
return (None, featuremap_num)
@@ -480,7 +481,6 @@ def call(self, inputs, **kwargs):
480481
col.append(j)
481482
p = tf.concat([embed_list[idx]
482483
for idx in row], axis=1) # batch num_pairs k
483-
# Reshape([num_pairs, self.embedding_size])
484484
q = tf.concat([embed_list[idx] for idx in col], axis=1)
485485
inner_product = p * q
486486
if self.reduce_sum:
@@ -504,6 +504,87 @@ def get_config(self,):
504504
return dict(list(base_config.items()) + list(config.items()))
505505

506506

507+
508+
class InteractingLayer(Layer):
    """A layer used in AutoInt that models the correlations between different feature fields with a multi-head self-attention mechanism.

      Input shape
        - A 3D tensor with shape: ``(batch_size,field_size,embedding_size)``.

      Output shape
        - 3D tensor with shape: ``(batch_size,field_size,att_embedding_size * head_num)``.

      Arguments
        - **att_embedding_size**: int. The embedding size of each head in the multi-head self-attention network.
        - **head_num**: int. The number of heads in the multi-head self-attention network. Must be > 0.
        - **use_res**: bool. Whether or not to add a linear projection of the input (residual connection) before the output activation.
        - **seed**: A Python integer to use as random seed, or ``None``. Every weight matrix is initialized with its own seed derived from this value.

      References
        - [Song W, Shi C, Xiao Z, et al. AutoInt: Automatic Feature Interaction Learning via Self-Attentive Neural Networks[J]. arXiv preprint arXiv:1810.11921, 2018.](https://arxiv.org/abs/1810.11921)
    """

    def __init__(self, att_embedding_size=8, head_num=2, use_res=True, seed=1024, **kwargs):
        if head_num <= 0:
            raise ValueError('head_num must be a int > 0')
        self.att_embedding_size = att_embedding_size
        self.head_num = head_num
        self.use_res = use_res
        self.seed = seed
        super(InteractingLayer, self).__init__(**kwargs)

    def build(self, input_shape):
        if len(input_shape) != 3:
            raise ValueError(
                "Unexpected inputs dimensions %d, expect to be 3 dimensions" % (len(input_shape)))
        # int() accepts both a TF1 ``Dimension`` and a plain Python int.
        embedding_size = int(input_shape[-1])
        projection_shape = [embedding_size,
                            self.att_embedding_size * self.head_num]

        def add_projection(name, seed_offset):
            # Initializers that share one seed and one shape produce identical
            # values, so every projection gets a distinct seed; otherwise the
            # query/key/value/res matrices would all start out equal.
            weight_seed = None if self.seed is None else self.seed + seed_offset
            return self.add_weight(name=name, shape=projection_shape, dtype=tf.float32,
                                   initializer=tf.keras.initializers.glorot_uniform(seed=weight_seed))

        self.W_Query = add_projection('query', 0)
        self.W_key = add_projection('key', 1)
        self.W_Value = add_projection('value', 2)
        if self.use_res:
            self.W_Res = add_projection('res', 3)

        # Be sure to call this somewhere!
        super(InteractingLayer, self).build(input_shape)

    def call(self, inputs, **kwargs):
        if K.ndim(inputs) != 3:
            raise ValueError(
                "Unexpected inputs dimensions %d, expect to be 3 dimensions" % (K.ndim(inputs)))

        # Linear projections of every field: None F D*head_num
        querys = tf.tensordot(inputs, self.W_Query, axes=(-1, 0))
        keys = tf.tensordot(inputs, self.W_key, axes=(-1, 0))
        values = tf.tensordot(inputs, self.W_Value, axes=(-1, 0))

        # Split the projected axis into heads and stack the heads on a new
        # leading axis: head_num None F D
        querys = tf.stack(tf.split(querys, self.head_num, axis=2))
        keys = tf.stack(tf.split(keys, self.head_num, axis=2))
        values = tf.stack(tf.split(values, self.head_num, axis=2))

        # Field-to-field attention scores per head: head_num None F F
        inner_product = tf.matmul(querys, keys, transpose_b=True)
        # Kept on the instance so the attention scores can be read after the call.
        self.normalized_att_scores = tf.nn.softmax(inner_product)

        # Attention-weighted sum of the values: head_num None F D
        result = tf.matmul(self.normalized_att_scores, values)
        # Move the heads from the leading axis back onto the last axis.
        result = tf.concat(tf.split(result, self.head_num), axis=-1)
        result = tf.squeeze(result, axis=0)  # None F D*head_num

        if self.use_res:
            result += tf.tensordot(inputs, self.W_Res, axes=(-1, 0))
        result = tf.nn.relu(result)

        return result

    def compute_output_shape(self, input_shape):
        return (None, input_shape[1], self.att_embedding_size * self.head_num)

    def get_config(self, ):
        config = {'att_embedding_size': self.att_embedding_size, 'head_num': self.head_num,
                  'use_res': self.use_res, 'seed': self.seed}
        base_config = super(InteractingLayer, self).get_config()
        return dict(list(base_config.items()) + list(config.items()))
586+
587+
507588
class LocalActivationUnit(Layer):
508589
"""The LocalActivationUnit used in DIN with which the representation of
509590
user interests varies adaptively given different candidate items.

deepctr/models/__init__.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
from .pnn import PNN
99
from .wdl import WDL
1010
from .xdeepfm import xDeepFM
11+
from .autoint import AutoInt
1112

1213
__all__ = ["AFM", "DCN", "MLR", "DeepFM",
13-
"MLR", "NFM", "DIN", "FNN", "PNN", "WDL", "xDeepFM"]
14+
"MLR", "NFM", "DIN", "FNN", "PNN", "WDL", "xDeepFM", "AutoInt"]

deepctr/models/autoint.py

+103
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,103 @@
1+
# -*- coding:utf-8 -*-
2+
"""
3+
4+
Author:
5+
Weichen Shen,[email protected]
6+
7+
Reference:
8+
[1] Song W, Shi C, Xiao Z, et al. AutoInt: Automatic Feature Interaction Learning via Self-Attentive Neural Networks[J]. arXiv preprint arXiv:1810.11921, 2018.(https://arxiv.org/abs/1810.11921)
9+
10+
"""
11+
12+
from tensorflow.python.keras.layers import Dense, Embedding, Concatenate
13+
from tensorflow.python.keras.models import Model
14+
from tensorflow.python.keras.initializers import RandomNormal
15+
from tensorflow.python.keras.regularizers import l2
16+
import tensorflow as tf
17+
18+
from ..utils import get_input
19+
from ..layers import PredictionLayer, MLP, InteractingLayer
20+
21+
22+
def AutoInt(feature_dim_dict, embedding_size=8, att_layer_num=3, att_embedding_size=8, att_head_num=2, att_res=True, hidden_size=(256, 256), activation='relu',
            l2_reg_deep=0, l2_reg_embedding=1e-5, use_bn=False, keep_prob=1.0, init_std=0.0001, seed=1024,
            final_activation='sigmoid',):
    """Instantiates the AutoInt Network architecture.

    :param feature_dim_dict: dict,to indicate sparse field and dense field like {'sparse':{'field_1':4,'field_2':3,'field_3':2},'dense':['field_4','field_5']}
    :param embedding_size: positive integer,sparse feature embedding_size
    :param att_layer_num: int.The number of stacked InteractingLayer to be used.
    :param att_embedding_size: int.The embedding size in multi-head self-attention network.
    :param att_head_num: int.The head number in multi-head self-attention network.
    :param att_res: bool.Whether or not to use standard residual connections before output.
    :param hidden_size: list,list of positive integer or empty list, the layer number and units in each layer of deep net
    :param activation: Activation function to use in deep net
    :param l2_reg_deep: float. L2 regularizer strength applied to deep net
    :param l2_reg_embedding: float. L2 regularizer strength applied to embedding vector
    :param use_bn: bool. Whether to use BatchNormalization before activation or not in deep net
    :param keep_prob: float in (0,1]. keep_prob used in deep net
    :param init_std: float,to use as the initialize std of embedding vector
    :param seed: integer ,to use as random seed for the embeddings, the interacting layers and the deep net.
    :param final_activation: output activation,usually ``'sigmoid'`` or ``'linear'``
    :return: A Keras model instance.
    :raises ValueError: if both ``hidden_size`` is empty and ``att_layer_num`` is not positive, or if ``feature_dim_dict`` is not a dict with ``'sparse'`` and ``'dense'`` keys.
    """

    if len(hidden_size) <= 0 and att_layer_num <= 0:
        raise ValueError("Either hidden_layer or att_layer_num must > 0")
    if not isinstance(feature_dim_dict, dict) or "sparse" not in feature_dim_dict or "dense" not in feature_dim_dict:
        raise ValueError(
            "feature_dim must be a dict like {'sparse':{'field_1':4,'field_2':3,'field_3':2},'dense':['field_5',]}")

    sparse_input, dense_input = get_input(feature_dim_dict, None,)
    sparse_embedding = get_embeddings(
        feature_dim_dict, embedding_size, init_std, seed, l2_reg_embedding)
    embed_list = [sparse_embedding[i](field_input)
                  for i, field_input in enumerate(sparse_input)]

    # Only the sparse-field embeddings go through the interacting layers;
    # they are stacked along the field axis.
    att_input = Concatenate(axis=1)(embed_list) if len(
        embed_list) > 1 else embed_list[0]

    for _ in range(att_layer_num):
        # Forward ``seed`` so the attention weights honour the caller's seed
        # instead of always using the layer's default.
        att_input = InteractingLayer(
            att_embedding_size, att_head_num, att_res, seed=seed)(att_input)
    att_output = tf.keras.layers.Flatten()(att_input)

    deep_input = tf.keras.layers.Flatten()(Concatenate()(embed_list)
                                           if len(embed_list) > 1 else embed_list[0])
    if len(dense_input) > 0:
        # Dense (continuous) features are fed to the deep part only.
        if len(dense_input) == 1:
            continuous_list = dense_input[0]
        else:
            continuous_list = Concatenate()(dense_input)

        deep_input = Concatenate()([deep_input, continuous_list])

    if len(hidden_size) > 0 and att_layer_num > 0:  # Deep & Interacting Layer
        deep_out = MLP(hidden_size, activation, l2_reg_deep, keep_prob,
                       use_bn, seed)(deep_input)
        stack_out = Concatenate()([att_output, deep_out])
        final_logit = Dense(1, use_bias=False, activation=None)(stack_out)
    elif len(hidden_size) > 0:  # Only Deep
        deep_out = MLP(hidden_size, activation, l2_reg_deep, keep_prob,
                       use_bn, seed)(deep_input)
        final_logit = Dense(1, use_bias=False, activation=None)(deep_out)
    elif att_layer_num > 0:  # Only Interacting Layer
        final_logit = Dense(1, use_bias=False, activation=None)(att_output)
    else:  # Guarded by the first check above; kept as a defensive fallback.
        raise NotImplementedError

    output = PredictionLayer(final_activation)(final_logit)
    model = Model(inputs=sparse_input + dense_input, outputs=output)

    return model
93+
94+
95+
def get_embeddings(feature_dim_dict, embedding_size, init_std, seed, l2_rev_V):
    """Build one ``Embedding`` layer for every entry of ``feature_dim_dict["sparse"]``.

    :param feature_dim_dict: dict whose ``"sparse"`` entry maps a field name to the value passed as the first argument of ``Embedding``
    :param embedding_size: value passed as the second argument of every ``Embedding``
    :param init_std: float, stddev of the ``RandomNormal`` embeddings initializer
    :param seed: seed handed to the ``RandomNormal`` initializer
    :param l2_rev_V: value passed to ``l2`` to build the embeddings regularizer
    :return: list of ``Embedding`` layers, in the iteration order of ``feature_dim_dict["sparse"]``
    """
    sparse_fields = feature_dim_dict["sparse"]
    sparse_embedding = []
    for position, field_name in enumerate(sparse_fields):
        initializer = RandomNormal(mean=0.0, stddev=init_std, seed=seed)
        layer_name = 'sparse_emb_' + str(position) + '-' + field_name
        sparse_embedding.append(
            Embedding(sparse_fields[field_name], embedding_size,
                      embeddings_initializer=initializer,
                      embeddings_regularizer=l2(l2_rev_V),
                      name=layer_name))

    return sparse_embedding

deepctr/models/xdeepfm.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@ def xDeepFM(feature_dim_dict, embedding_size=8, hidden_size=(256, 256), cin_laye
2020
:param embedding_size: positive integer,sparse feature embedding_size
2121
:param hidden_size: list,list of positive integer or empty list, the layer number and units in each layer of deep net
2222
:param cin_layer_size: list,list of positive integer or empty list, the feature maps in each hidden layer of Compressed Interaction Network
23-
:param cin_split_half: bool.if set to False, half of the feature maps in each hidden will connect to output unit
23+
:param cin_split_half: bool.if set to True, half of the feature maps in each hidden will connect to output unit
2424
:param cin_activation: activation function used on feature maps
2525
:param l2_reg_linear: float. L2 regularizer strength applied to linear part
2626
:param l2_reg_embedding: L2 regularizer strength applied to embedding vector

deepctr/utils.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,8 @@
2727
'Dice': Dice,
2828
'SequencePoolingLayer': SequencePoolingLayer,
2929
'AttentionSequencePoolingLayer': AttentionSequencePoolingLayer,
30-
'CIN': CIN, }
30+
'CIN': CIN,
31+
'InteractingLayer': InteractingLayer}
3132

3233

3334
def get_input(feature_dim_dict, bias_feature_dim_dict=None):

docs/pics/AutoInt.png

80 KB
Loading

docs/pics/InteractingLayer.png

46.8 KB
Loading

docs/source/FAQ.rst

+24-6
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
FAQ
22
==========
3-
1. How to save or load weights/models?
4-
3+
1. Save or load weights/models
4+
----------------------------------------
55
To save/load weights,you can write codes just like any other keras models.
66

77
.. code-block:: python
@@ -22,8 +22,26 @@ To save/load models,just a little different.
2222
from deepctr.utils import custom_objects
2323
model = load_model('DeepFM.h5',custom_objects)# load_model,just add a parameter
2424
25-
2. How can I get the attentional weights of feature interactions in AFM?
25+
2. Set learning rate and use earlystopping
26+
---------------------------------------------------
27+
You can use any models in DeepCTR like a keras model object.
28+
Here is an example of how to set the learning rate and use early stopping:
29+
30+
.. code-block:: python
31+
32+
import deepctr
33+
from tensorflow.python.keras.optimizers import Adam,Adagrad
34+
from tensorflow.python.keras.callbacks import EarlyStopping
2635
36+
model = deepctr.models.DeepFM({"sparse": sparse_feature_dict, "dense": dense_feature_list})
37+
model.compile(Adagrad(0.0808),'binary_crossentropy',metrics=['binary_crossentropy'])
38+
39+
es = EarlyStopping(monitor='val_binary_crossentropy')
40+
history = model.fit(model_input, data[target].values,batch_size=256, epochs=10, verbose=2, validation_split=0.2,callbacks=[es] )
41+
42+
43+
3. Get the attentional weights of feature interactions in AFM
44+
--------------------------------------------------------------------------
2745
First,make sure that you have installed the latest version of deepctr.
2846

2947
Then,use the following code,the ``attentional_weights[:,i,0]`` is the ``feature_interactions[i]``'s attentional weight of all samples.
@@ -46,7 +64,7 @@ Then,use the following code,the ``attentional_weights[:,i,0]`` is the ``feature_
4664
4765
4866
49-
3. Does the models support multi-value input?
50-
67+
4. Do the models support multi-value input?
68+
---------------------------------------------------
5169
Now only the `DIN <Features.html#din-deep-interest-network>`_ model supports multi-value input,you can use layers in `sequence <deepctr.sequence.html>`_ to build your own models!
52-
And I will add the feature soon~
70+
And it will be supported in a future release

docs/source/Features.rst

+21-1
Original file line numberDiff line numberDiff line change
@@ -163,7 +163,8 @@ DIN use a local activation unit to get the activation score between candidate it
163163
User's interest are represented by weighted sum of user behaviors.
164164
user's interest vector and other embedding vectors are concatenated and fed into a MLP to get the prediction.
165165

166-
**DIN api** `link <./deepctr.models.din.html>`_
166+
**DIN api** `link <./deepctr.models.din.html>`_ **DIN demo** `link <https://github.com/shenweichen/DeepCTR/tree/master/examples
167+
/run_din.py>`_
167168

168169
.. image:: ../pics/DIN.png
169170
:align: center
@@ -191,6 +192,25 @@ Finally,apply sum pooling on all the feature maps :math:`H_k` to get one vector.
191192

192193
`Lian J, Zhou X, Zhang F, et al. xDeepFM: Combining Explicit and Implicit Feature Interactions for Recommender Systems[J]. arXiv preprint arXiv:1803.05170, 2018. <https://arxiv.org/pdf/1803.05170.pdf>`_
193194

195+
AutoInt(Automatic Feature Interaction)
196+
>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>
197+
198+
AutoInt uses an interacting layer to model the interactions between different features.
199+
Within each interacting layer, each feature is allowed to interact with all the other features and is able to automatically identify relevant features to form meaningful higher-order features via the multi-head attention mechanism.
200+
By stacking multiple interacting layers,AutoInt is able to model different orders of feature interactions.
201+
202+
**AutoInt api** `link <./deepctr.models.autoint.html>`_
203+
204+
.. image:: ../pics/InteractingLayer.png
205+
:align: center
206+
:scale: 70 %
207+
208+
.. image:: ../pics/AutoInt.png
209+
:align: center
210+
:scale: 70 %
211+
212+
`Song W, Shi C, Xiao Z, et al. AutoInt: Automatic Feature Interaction Learning via Self-Attentive Neural Networks[J]. arXiv preprint arXiv:1810.11921, 2018. <https://arxiv.org/abs/1810.11921>`_
213+
194214
Layers
195215
--------
196216

docs/source/History.md

+1
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
# History
2+
- 12/27/2018 : [v0.2.1](https://github.com/shenweichen/DeepCTR/releases/tag/v0.2.1) released.Add [AutoInt](./Features.html#autoint-automatic-feature-interaction) Model.
23
- 12/22/2018 : [v0.2.0](https://github.com/shenweichen/DeepCTR/releases/tag/v0.2.0) released.Add [xDeepFM](./Features.html#xdeepfm) and automatic check for new version.
34
- 12/19/2018 : [v0.1.6](https://github.com/shenweichen/DeepCTR/releases/tag/v0.1.6) released.Now DeepCTR is compatible with tensorflow from `1.4-1.12` except for `1.7` and `1.8`.
45
- 29/11/2018 : [v0.1.4](https://github.com/shenweichen/DeepCTR/releases/tag/v0.1.4) released.Add [FAQ](./FAQ.html) in docs

docs/source/Models-API.rst

+2-1
Original file line numberDiff line numberDiff line change
@@ -12,4 +12,5 @@ DeepCTR Models API
1212
AFM<deepctr.models.afm>
1313
DCN<deepctr.models.dcn>
1414
DIN<deepctr.models.din>
15-
xDeepFM<deepctr.models.xdeepfm>
15+
xDeepFM<deepctr.models.xdeepfm>
16+
AutoInt<deepctr.models.autoint>

docs/source/conf.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@
2626
# The short X.Y version
2727
version = ''
2828
# The full version, including alpha/beta/rc tags
29-
release = '0.2.0'
29+
release = '0.2.1'
3030

3131

3232
# -- General configuration ---------------------------------------------------
+7
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
deepctr.models.autoint module
2+
=============================
3+
4+
.. automodule:: deepctr.models.autoint
5+
:members:
6+
:no-undoc-members:
7+
:no-show-inheritance:

docs/source/deepctr.models.rst

+1
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ Submodules
77
.. toctree::
88

99
deepctr.models.afm
10+
deepctr.models.autoint
1011
deepctr.models.dcn
1112
deepctr.models.deepfm
1213
deepctr.models.din

docs/source/index.rst

+2-1
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@ Welcome to DeepCTR's documentation!
1515
.. _Stars: https://github.com/shenweichen/DeepCTR
1616

1717
.. |Forks| image:: https://img.shields.io/github/forks/shenweichen/deepctr.svg
18-
.. _Forks: https://github.com/shenweichen/DeepCTR
18+
.. _Forks: https://github.com/shenweichen/DeepCTR/fork
1919

2020
.. |PyPi| image:: https://img.shields.io/pypi/v/deepctr.svg
2121
.. _PyPi: https://pypi.org/project/deepctr/
@@ -35,6 +35,7 @@ You can read the latest code at https://github.com/shenweichen/DeepCTR
3535

3636
News
3737
-----
38+
12/27/2018 : Add `AutoInt <./Features.html#autoint-automatic-feature-interaction>`_ . `Changelog <https://github.com/shenweichen/DeepCTR/releases/tag/v0.2.1>`_
3839

3940
12/22/2018 : Add `xDeepFM <./Features.html#xdeepfm>`_ and automatic check for new version. `Changelog <https://github.com/shenweichen/DeepCTR/releases/tag/v0.2.0>`_
4041

0 commit comments

Comments
 (0)