
Commit 3492ad3

Implement the BST model. (shenweichen#327)
1 parent 4762e85 commit 3492ad3

File tree

6 files changed (+218 -13 lines)


README.md

+1
@@ -54,6 +54,7 @@ Let's [**Get Started!**](https://deepctr-doc.readthedocs.io/en/latest/Quick-Star
 | Deep Session Interest Network | [IJCAI 2019][Deep Session Interest Network for Click-Through Rate Prediction ](https://arxiv.org/abs/1905.06482) |
 | FiBiNET | [RecSys 2019][FiBiNET: Combining Feature Importance and Bilinear feature Interaction for Click-Through Rate Prediction](https://arxiv.org/pdf/1905.09433.pdf) |
 | FLEN | [arxiv 2019][FLEN: Leveraging Field for Scalable CTR Prediction](https://arxiv.org/pdf/1911.04690.pdf) |
+| BST | [DLP-KDD 2019][Behavior sequence transformer for e-commerce recommendation in Alibaba](https://arxiv.org/pdf/1905.06874.pdf) |
 | DCN V2 | [arxiv 2020][DCN V2: Improved Deep & Cross Network and Practical Lessons for Web-scale Learning to Rank Systems](https://arxiv.org/abs/2008.13535) |

 ## Citation

deepctr/__init__.py

+4 -4
@@ -1,4 +1,4 @@
-from .utils import check_version
-
-__version__ = '0.8.3'
-check_version(__version__)
+from .utils import check_version
+
+__version__ = '0.8.3'
+check_version(__version__)

deepctr/layers/sequence.py

+30 -8
@@ -79,7 +79,7 @@ def call(self, seq_value_len_list, mask=None, **kwargs):
             mask = tf.tile(mask, [1, 1, embedding_size])

         if self.mode == "max":
-            hist = uiseq_embed_list - (1-mask) * 1e9
+            hist = uiseq_embed_list - (1 - mask) * 1e9
             return reduce_max(hist, 1, keep_dims=True)

         hist = reduce_sum(uiseq_embed_list * mask, 1, keep_dims=False)
@@ -436,14 +436,16 @@ class Transformer(Layer):
         - **blinding**: bool. Whether or not use blinding.
         - **seed**: A Python integer to use as random seed.
         - **supports_masking**:bool. Whether or not support masking.
+        - **attention_type**: str, Type of attention, the value must be one of ["scaled_dot_product","additive"].
+        - **output_type**: str or None. Whether or not use average/sum pooling for output.

       References
         - [Vaswani, Ashish, et al. "Attention is all you need." Advances in Neural Information Processing Systems. 2017.](https://papers.nips.cc/paper/7181-attention-is-all-you-need.pdf)
     """

     def __init__(self, att_embedding_size=1, head_num=8, dropout_rate=0.0, use_positional_encoding=True, use_res=True,
                  use_feed_forward=True, use_layer_norm=False, blinding=True, seed=1024, supports_masking=False,
-                 **kwargs):
+                 attention_type="scaled_dot_product", output_type="mean", **kwargs):
         if head_num <= 0:
             raise ValueError('head_num must be a int > 0')
         self.att_embedding_size = att_embedding_size
@@ -456,6 +458,8 @@ def __init__(self, att_embedding_size=1, head_num=8, dropout_rate=0.0, use_posit
         self.dropout_rate = dropout_rate
         self.use_layer_norm = use_layer_norm
         self.blinding = blinding
+        self.attention_type = attention_type
+        self.output_type = output_type
         super(Transformer, self).__init__(**kwargs)
         self.supports_masking = supports_masking

@@ -464,7 +468,7 @@ def build(self, input_shape):
         if self.num_units != embedding_size:
             raise ValueError(
                 "att_embedding_size * head_num must equal the last dimension size of inputs,got %d * %d != %d" % (
-                self.att_embedding_size, self.head_num, embedding_size))
+                    self.att_embedding_size, self.head_num, embedding_size))
         self.seq_len_max = int(input_shape[0][-2])
         self.W_Query = self.add_weight(name='query', shape=[embedding_size, self.att_embedding_size * self.head_num],
                                        dtype=tf.float32,
@@ -475,6 +479,11 @@ def build(self, input_shape):
         self.W_Value = self.add_weight(name='value', shape=[embedding_size, self.att_embedding_size * self.head_num],
                                        dtype=tf.float32,
                                        initializer=tf.keras.initializers.TruncatedNormal(seed=self.seed + 2))
+        if self.attention_type == "additive":
+            self.b = self.add_weight('b', shape=[self.att_embedding_size], dtype=tf.float32,
+                                     initializer=tf.keras.initializers.glorot_uniform(seed=self.seed))
+            self.v = self.add_weight('v', shape=[self.att_embedding_size], dtype=tf.float32,
+                                     initializer=tf.keras.initializers.glorot_uniform(seed=self.seed))
         # if self.use_res:
         #     self.W_Res = self.add_weight(name='res', shape=[embedding_size, self.att_embedding_size * self.head_num], dtype=tf.float32,
         #                                  initializer=tf.keras.initializers.TruncatedNormal(seed=self.seed))
@@ -525,10 +534,18 @@ def call(self, inputs, mask=None, training=None, **kwargs):
         keys = tf.concat(tf.split(keys, self.head_num, axis=2), axis=0)
         values = tf.concat(tf.split(values, self.head_num, axis=2), axis=0)

-        # head_num*None T_q T_k
-        outputs = tf.matmul(querys, keys, transpose_b=True)
+        if self.attention_type == "scaled_dot_product":
+            # head_num*None T_q T_k
+            outputs = tf.matmul(querys, keys, transpose_b=True)

-        outputs = outputs / (keys.get_shape().as_list()[-1] ** 0.5)
+            outputs = outputs / (keys.get_shape().as_list()[-1] ** 0.5)
+        elif self.attention_type == "additive":
+            querys_reshaped = tf.expand_dims(querys, axis=-2)
+            keys_reshaped = tf.expand_dims(keys, axis=-3)
+            outputs = tf.tanh(tf.nn.bias_add(querys_reshaped + keys_reshaped, self.b))
+            outputs = tf.squeeze(tf.tensordot(outputs, tf.expand_dims(self.v, axis=-1), axes=[-1, 0]), axis=-1)
+        else:
+            raise NotImplementedError

         key_masks = tf.tile(key_masks, [self.head_num, 1])

@@ -579,7 +596,12 @@ def call(self, inputs, mask=None, training=None, **kwargs):
         if self.use_layer_norm:
             result = self.ln(result)

-        return reduce_mean(result, axis=1, keep_dims=True)
+        if self.output_type == "mean":
+            return reduce_mean(result, axis=1, keep_dims=True)
+        elif self.output_type == "sum":
+            return reduce_sum(result, axis=1, keep_dims=True)
+        else:
+            return result

     def compute_output_shape(self, input_shape):

@@ -593,7 +615,7 @@ def get_config(self, ):
                   'dropout_rate': self.dropout_rate, 'use_res': self.use_res,
                   'use_positional_encoding': self.use_positional_encoding, 'use_feed_forward': self.use_feed_forward,
                   'use_layer_norm': self.use_layer_norm, 'seed': self.seed, 'supports_masking': self.supports_masking,
-                  'blinding': self.blinding}
+                  'blinding': self.blinding, 'attention_type': self.attention_type, 'output_type': self.output_type}
         base_config = super(Transformer, self).get_config()
         return dict(list(base_config.items()) + list(config.items()))

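The "additive" option added above scores each query-key pair as v . tanh(q + k + b) instead of the scaled dot product q . k / sqrt(d). A minimal NumPy sketch of the two scoring rules for comparison; the shapes and helper names here are illustrative only and are not part of the library:

import numpy as np

def additive_scores(querys, keys, b, v):
    # Bahdanau-style scores as in the new "additive" branch:
    # score(q, k) = v . tanh(q + k + b)
    q = querys[:, :, None, :]   # (batch, T_q, 1, d)
    k = keys[:, None, :, :]     # (batch, 1, T_k, d)
    return np.tanh(q + k + b) @ v   # -> (batch, T_q, T_k)

def scaled_dot_product_scores(querys, keys):
    # The pre-existing "scaled_dot_product" branch: q . k / sqrt(d)
    d = keys.shape[-1]
    return querys @ keys.transpose(0, 2, 1) / np.sqrt(d)

rng = np.random.default_rng(0)
batch, T_q, T_k, d = 2, 3, 4, 8
querys = rng.normal(size=(batch, T_q, d))
keys = rng.normal(size=(batch, T_k, d))
b, v = np.zeros(d), rng.normal(size=d)
print(additive_scores(querys, keys, b, v).shape)      # (2, 3, 4)
print(scaled_dot_product_scores(querys, keys).shape)  # (2, 3, 4)

Both rules produce a (batch, T_q, T_k) score matrix, which the layer then masks and softmaxes exactly as before.
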
deepctr/models/__init__.py

+2 -1
@@ -18,6 +18,7 @@
 from .fibinet import FiBiNET
 from .flen import FLEN
 from .fwfm import FwFM
+from .bst import BST

 __all__ = ["AFM", "CCPM", "DCN", "DCNMix", "MLR", "DeepFM", "MLR", "NFM", "DIN", "DIEN", "FNN", "PNN",
-           "WDL", "xDeepFM", "AutoInt", "ONN", "FGCNN", "DSIN", "FiBiNET", 'FLEN', "FwFM"]
+           "WDL", "xDeepFM", "AutoInt", "ONN", "FGCNN", "DSIN", "FiBiNET", 'FLEN', "FwFM", "BST"]

deepctr/models/bst.py

+106
@@ -0,0 +1,106 @@
# -*- coding:utf-8 -*-
"""
Author:
    Zichao Li, [email protected]

Reference:
    Qiwei Chen, Huan Zhao, Wei Li, Pipei Huang, and Wenwu Ou. 2019. Behavior sequence transformer for e-commerce recommendation in Alibaba. In Proceedings of the 1st International Workshop on Deep Learning Practice for High-Dimensional Sparse Data (DLP-KDD '19). Association for Computing Machinery, New York, NY, USA, Article 12, 1-4. DOI:https://doi.org/10.1145/3326937.3341261
"""

import tensorflow as tf
from tensorflow.python.keras.layers import (Dense, LeakyReLU, Flatten)
from ..feature_column import SparseFeat, VarLenSparseFeat, DenseFeat, build_input_features
from ..inputs import get_varlen_pooling_list, create_embedding_matrix, embedding_lookup, varlen_embedding_lookup, \
    get_dense_input
from ..layers.core import DNN, PredictionLayer
from ..layers.sequence import Transformer, AttentionSequencePoolingLayer
from ..layers.utils import concat_func, combined_dnn_input


def BST(dnn_feature_columns, history_feature_list, transformer_num=1, att_head_num=8,
        use_bn=False, dnn_hidden_units=(1024, 512, 256), dnn_activation='relu', l2_reg_dnn=0,
        l2_reg_embedding=1e-6, dnn_dropout=0.0, seed=1024, task='binary'):
    """Instantiates the BST architecture.

    :param dnn_feature_columns: An iterable containing all the features used by deep part of the model.
    :param history_feature_list: list, to indicate sequence sparse field.
    :param transformer_num: int, the number of transformer layer.
    :param att_head_num: int, the number of heads in multi-head self attention.
    :param use_bn: bool. Whether use BatchNormalization before activation or not in deep net
    :param dnn_hidden_units: list,list of positive integer or empty list, the layer number and units in each layer of DNN
    :param dnn_activation: Activation function to use in DNN
    :param l2_reg_dnn: float. L2 regularizer strength applied to DNN
    :param l2_reg_embedding: float. L2 regularizer strength applied to embedding vector
    :param dnn_dropout: float in [0,1), the probability we will drop out a given DNN coordinate.
    :param seed: integer ,to use as random seed.
    :param task: str, ``"binary"`` for binary logloss or ``"regression"`` for regression loss
    :return: A Keras model instance.

    """

    features = build_input_features(dnn_feature_columns)
    inputs_list = list(features.values())

    user_behavior_length = features["seq_length"]

    sparse_feature_columns = list(
        filter(lambda x: isinstance(x, SparseFeat), dnn_feature_columns)) if dnn_feature_columns else []
    dense_feature_columns = list(
        filter(lambda x: isinstance(x, DenseFeat), dnn_feature_columns)) if dnn_feature_columns else []
    varlen_sparse_feature_columns = list(
        filter(lambda x: isinstance(x, VarLenSparseFeat), dnn_feature_columns)) if dnn_feature_columns else []

    history_feature_columns = []
    sparse_varlen_feature_columns = []
    history_fc_names = list(map(lambda x: "hist_" + x, history_feature_list))

    for fc in varlen_sparse_feature_columns:
        feature_name = fc.name
        if feature_name in history_fc_names:
            history_feature_columns.append(fc)
        else:
            sparse_varlen_feature_columns.append(fc)

    embedding_dict = create_embedding_matrix(dnn_feature_columns, l2_reg_embedding, seed, prefix="",
                                             seq_mask_zero=True)

    query_emb_list = embedding_lookup(embedding_dict, features, sparse_feature_columns,
                                      return_feat_list=history_feature_list, to_list=True)
    hist_emb_list = embedding_lookup(embedding_dict, features, history_feature_columns,
                                     return_feat_list=history_fc_names, to_list=True)
    dnn_input_emb_list = embedding_lookup(embedding_dict, features, sparse_feature_columns,
                                          mask_feat_list=history_feature_list, to_list=True)
    dense_value_list = get_dense_input(features, dense_feature_columns)
    sequence_embed_dict = varlen_embedding_lookup(embedding_dict, features, sparse_varlen_feature_columns)
    sequence_embed_list = get_varlen_pooling_list(sequence_embed_dict, features, sparse_varlen_feature_columns,
                                                  to_list=True)

    dnn_input_emb_list += sequence_embed_list
    query_emb = concat_func(query_emb_list)
    deep_input_emb = concat_func(dnn_input_emb_list)
    hist_emb = concat_func(hist_emb_list)

    transformer_output = hist_emb
    for i in range(transformer_num):
        att_embedding_size = transformer_output.get_shape().as_list()[-1] // att_head_num
        transformer_layer = Transformer(att_embedding_size=att_embedding_size, head_num=att_head_num,
                                        dropout_rate=dnn_dropout, use_positional_encoding=True, use_res=True,
                                        use_feed_forward=True, use_layer_norm=True, blinding=False, seed=seed,
                                        supports_masking=False, output_type=None)
        transformer_output = transformer_layer([transformer_output, transformer_output,
                                                user_behavior_length, user_behavior_length])

    attn_output = AttentionSequencePoolingLayer(att_hidden_units=(64, 16), weight_normalization=True,
                                                supports_masking=False)([query_emb, transformer_output,
                                                                         user_behavior_length])
    deep_input_emb = concat_func([deep_input_emb, attn_output], axis=-1)
    deep_input_emb = Flatten()(deep_input_emb)

    dnn_input = combined_dnn_input([deep_input_emb], dense_value_list)
    output = DNN(dnn_hidden_units, dnn_activation, l2_reg_dnn, dnn_dropout, use_bn, seed=seed)(dnn_input)
    final_logit = Dense(1, use_bias=False, kernel_initializer=tf.keras.initializers.glorot_normal(seed))(output)
    output = PredictionLayer(task)(final_logit)

    model = tf.keras.models.Model(inputs=inputs_list, outputs=output)

    return model

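For orientation, a minimal sketch of how the new BST function might be called. The feature names, vocabulary sizes, and embedding dimensions below are illustrative assumptions; the repository's own usage is the test file that follows. Note that the length feature of the behavior sequence must be named "seq_length", because BST reads features["seq_length"] directly, and att_head_num must divide the behavior-embedding size.

import tensorflow as tf
from packaging import version
from deepctr.feature_column import SparseFeat, VarLenSparseFeat, DenseFeat
from deepctr.models import BST

if version.parse(tf.__version__) >= version.parse('2.0.0'):
    tf.compat.v1.disable_eager_execution()  # mirrors the accompanying test

# Illustrative columns: 'hist_item_id' shares the 'item_id' embedding table.
feature_columns = [
    SparseFeat('user', vocabulary_size=10, embedding_dim=8),
    SparseFeat('item_id', vocabulary_size=20, embedding_dim=8),
    DenseFeat('pay_score', 1),
    VarLenSparseFeat(SparseFeat('hist_item_id', vocabulary_size=20, embedding_dim=8,
                                embedding_name='item_id'),
                     maxlen=4, length_name="seq_length"),
]

# att_embedding_size * head_num must equal the behavior-embedding size (8 = 2 * 4 here).
model = BST(dnn_feature_columns=feature_columns,
            history_feature_list=['item_id'],
            att_head_num=4)
model.summary()
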
tests/models/BST_test.py

+75
@@ -0,0 +1,75 @@
import numpy as np
import pytest
import tensorflow as tf
from packaging import version

from deepctr.feature_column import SparseFeat, VarLenSparseFeat, DenseFeat, get_feature_names
from deepctr.models import BST
from ..utils import check_model


def get_xy_fd(use_neg=False, hash_flag=False):
    feature_columns = [SparseFeat('user', 3, embedding_dim=12, use_hash=hash_flag),
                       SparseFeat('gender', 2, embedding_dim=4, use_hash=hash_flag),
                       SparseFeat('item_id', 3 + 1, embedding_dim=8, use_hash=hash_flag),
                       SparseFeat('cate_id', 2 + 1, embedding_dim=4, use_hash=hash_flag),
                       DenseFeat('pay_score', 1)]

    feature_columns += [
        VarLenSparseFeat(SparseFeat('hist_item_id', vocabulary_size=3 + 1, embedding_dim=8, embedding_name='item_id'),
                         maxlen=4, length_name="seq_length"),
        VarLenSparseFeat(SparseFeat('hist_cate_id', 2 + 1, embedding_dim=4, embedding_name='cate_id'), maxlen=4,
                         length_name="seq_length")]

    behavior_feature_list = ["item_id", "cate_id"]
    uid = np.array([0, 1, 2])
    ugender = np.array([0, 1, 0])
    iid = np.array([1, 2, 3])  # 0 is mask value
    cate_id = np.array([1, 2, 2])  # 0 is mask value
    score = np.array([0.1, 0.2, 0.3])

    hist_iid = np.array([[1, 2, 3, 0], [1, 2, 3, 0], [1, 2, 0, 0]])
    hist_cate_id = np.array([[1, 2, 2, 0], [1, 2, 2, 0], [1, 2, 0, 0]])

    behavior_length = np.array([3, 3, 2])

    feature_dict = {'user': uid, 'gender': ugender, 'item_id': iid, 'cate_id': cate_id,
                    'hist_item_id': hist_iid, 'hist_cate_id': hist_cate_id,
                    'pay_score': score, "seq_length": behavior_length}

    if use_neg:
        feature_dict['neg_hist_item_id'] = np.array([[1, 2, 3, 0], [1, 2, 3, 0], [1, 2, 0, 0]])
        feature_dict['neg_hist_cate_id'] = np.array([[1, 2, 2, 0], [1, 2, 2, 0], [1, 2, 0, 0]])
        feature_columns += [
            VarLenSparseFeat(
                SparseFeat('neg_hist_item_id', vocabulary_size=3 + 1, embedding_dim=8, embedding_name='item_id'),
                maxlen=4, length_name="seq_length"),
            VarLenSparseFeat(SparseFeat('neg_hist_cate_id', 2 + 1, embedding_dim=4, embedding_name='cate_id'),
                             maxlen=4, length_name="seq_length")]

    x = {name: feature_dict[name] for name in get_feature_names(feature_columns)}
    y = np.array([1, 0, 1])
    x["position_hist"] = np.array([[0, 1, 2, 3], [0, 1, 2, 3], [0, 1, 2, 3]])
    return x, y, feature_columns, behavior_feature_list


# @pytest.mark.xfail(reason="There is a bug when save model use Dice")
# @pytest.mark.skip(reason="misunderstood the API")

def test_BST():
    if version.parse(tf.__version__) >= version.parse('2.0.0'):
        tf.compat.v1.disable_eager_execution()
    model_name = "BST"

    x, y, feature_columns, behavior_feature_list = get_xy_fd(hash_flag=True)

    model = BST(dnn_feature_columns=feature_columns,
                history_feature_list=behavior_feature_list,
                att_head_num=4)

    check_model(model, model_name, x, y,
                check_model_io=True)


if __name__ == "__main__":
    pass

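check_model lives in tests/utils.py and is not part of this diff. As a rough, illustrative sketch of the compile/fit/predict cycle it presumably wraps (reusing get_xy_fd and the imports from the test above), where optimizer, loss, batch size, and epoch count are assumptions rather than the repository's settings:

import tensorflow as tf
from packaging import version

if version.parse(tf.__version__) >= version.parse('2.0.0'):
    tf.compat.v1.disable_eager_execution()  # same setup as test_BST

x, y, feature_columns, behavior_feature_list = get_xy_fd()
x.pop("position_hist", None)  # not a declared model input; dropped for a plain Keras fit

model = BST(dnn_feature_columns=feature_columns,
            history_feature_list=behavior_feature_list,
            att_head_num=4)
model.compile(optimizer='adam', loss='binary_crossentropy')
model.fit(x, y, batch_size=3, epochs=1, verbose=2)
print(model.predict(x, batch_size=3).shape)  # expected: (3, 1), one CTR score per sample
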