# -*- coding:utf-8 -*-
"""
Author:

Reference:
    [1] Lu W, Yu Y, Chang Y, et al. A Dual Input-aware Factorization Machine for CTR Prediction[C]
    //IJCAI. 2020: 3139-3145. (https://www.ijcai.org/Proceedings/2020/0434.pdf)
"""

import tensorflow as tf

from ..feature_column import build_input_features, get_linear_logit, input_from_feature_columns, SparseFeat, \
    VarLenSparseFeat
from ..layers.core import PredictionLayer, DNN
from ..layers.interaction import FM, InteractingLayer
from ..layers.utils import concat_func, add_func, combined_dnn_input


def DIFM(linear_feature_columns, dnn_feature_columns,
         att_embedding_size=8, att_head_num=8, att_res=True, dnn_hidden_units=(128, 128),
         l2_reg_linear=0.00001, l2_reg_embedding=0.00001, l2_reg_dnn=0, seed=1024, dnn_dropout=0,
         dnn_activation='relu', dnn_use_bn=False, task='binary'):
    """Instantiates the DIFM Network architecture.

    :param linear_feature_columns: An iterable containing all the features used by the linear part of the model.
    :param dnn_feature_columns: An iterable containing all the features used by the deep part of the model.
    :param att_embedding_size: integer, the embedding size in the multi-head self-attention network.
    :param att_head_num: int. The number of heads in the multi-head self-attention network.
    :param att_res: bool. Whether or not to use standard residual connections before the output.
    :param dnn_hidden_units: list of positive integers, the layer number and units in each layer of the DNN; must be non-empty for DIFM.
    :param l2_reg_linear: float. L2 regularizer strength applied to the linear part.
    :param l2_reg_embedding: float. L2 regularizer strength applied to the embedding vectors.
    :param l2_reg_dnn: float. L2 regularizer strength applied to the DNN.
    :param seed: integer, to use as random seed.
    :param dnn_dropout: float in [0,1), the probability of dropping a given DNN coordinate.
    :param dnn_activation: Activation function to use in the DNN.
    :param dnn_use_bn: bool. Whether to use BatchNormalization before activation in the DNN.
    :param task: str, ``"binary"`` for binary logloss or ``"regression"`` for regression loss.
    :return: A Keras model instance.
    """

    if not len(dnn_hidden_units) > 0:
        raise ValueError("dnn_hidden_units is null!")

    features = build_input_features(
        linear_feature_columns + dnn_feature_columns)

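    # Count the sparse (and variable-length sparse) fields: both input-aware FENs below
    # output one scalar weight per such field.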
    sparse_feat_num = len(list(filter(lambda x: isinstance(x, (SparseFeat, VarLenSparseFeat)),
                                      dnn_feature_columns)))
    inputs_list = list(features.values())

    sparse_embedding_list, _ = input_from_feature_columns(features, dnn_feature_columns,
                                                          l2_reg_embedding, seed)

    if not len(sparse_embedding_list) > 0:
        raise ValueError("there are no sparse features")

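    # Vector-wise FEN: multi-head self-attention (InteractingLayer) over the stacked
    # field embeddings, flattened and projected to one weight per sparse field (m_vec).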
    att_input = concat_func(sparse_embedding_list, axis=1)
    att_out = InteractingLayer(att_embedding_size, att_head_num, att_res, scaling=True)(att_input)
    att_out = tf.keras.layers.Flatten()(att_out)
    m_vec = tf.keras.layers.Dense(
        sparse_feat_num, use_bias=False, kernel_initializer=tf.keras.initializers.glorot_normal(seed=seed))(att_out)

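    # Bit-wise FEN: a plain DNN over the flattened embeddings, projected to a second
    # set of per-field weights (m_bit).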
    dnn_input = combined_dnn_input(sparse_embedding_list, [])
    dnn_output = DNN(dnn_hidden_units, dnn_activation, l2_reg_dnn, dnn_dropout, dnn_use_bn, seed=seed)(dnn_input)
    m_bit = tf.keras.layers.Dense(
        sparse_feat_num, use_bias=False, kernel_initializer=tf.keras.initializers.glorot_normal(seed=seed))(dnn_output)

    input_aware_factor = add_func([m_vec, m_bit])  # the complete input-aware factor m_x

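    # Reweight the first-order (linear) terms of the sparse features by m_x.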
    linear_logit = get_linear_logit(features, linear_feature_columns, seed=seed, prefix='linear',
                                    l2_reg=l2_reg_linear, sparse_feat_refine_weight=input_aware_factor)

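    # Refine the second-order part: scale each field's embedding by its entry of m_x
    # before the FM interaction.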
    fm_input = concat_func(sparse_embedding_list, axis=1)
    refined_fm_input = tf.keras.layers.Lambda(lambda x: x[0] * tf.expand_dims(x[1], axis=-1))(
        [fm_input, input_aware_factor])
    fm_logit = FM()(refined_fm_input)

    final_logit = add_func([linear_logit, fm_logit])

    output = PredictionLayer(task)(final_logit)
    model = tf.keras.models.Model(inputs=inputs_list, outputs=output)
    return model
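

if __name__ == "__main__":
    # Minimal usage sketch (illustrative, not part of the library): builds DIFM on two
    # toy sparse fields with random data. The feature names, vocabulary sizes, embedding
    # dimension and batch sizes below are assumptions made for this demo only.
    # (In practice you would import DIFM from the installed package; run this module
    # with `python -m` so the relative imports above resolve.)
    import numpy as np

    feature_columns = [SparseFeat('user_id', vocabulary_size=100, embedding_dim=8),
                       SparseFeat('item_id', vocabulary_size=200, embedding_dim=8)]
    model = DIFM(linear_feature_columns=feature_columns, dnn_feature_columns=feature_columns)
    model.compile(optimizer='adam', loss='binary_crossentropy')

    n_samples = 32
    model_input = {'user_id': np.random.randint(0, 100, n_samples),
                   'item_id': np.random.randint(0, 200, n_samples)}
    labels = np.random.randint(0, 2, n_samples)
    model.fit(model_input, labels, batch_size=16, epochs=1, verbose=0)
    print(model.predict(model_input, batch_size=16).shape)  # expected: (32, 1)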