# -*- coding:utf-8 -*-
"""
Author:
    Harshit Pande

Reference:
    [1] Field-Embedded Factorization Machines for Click-through Rate Prediction
        (https://arxiv.org/abs/2009.09931)

"""

import tensorflow as tf

from ..feature_column import get_linear_logit, input_from_feature_columns
from ..utils import DNN_SCOPE_NAME, deepctr_model_fn, variable_scope
from ...layers.core import DNN
from ...layers.interaction import FEFMLayer
from ...layers.utils import concat_func, add_func, combined_dnn_input, reduce_sum


def DeepFEFMEstimator(linear_feature_columns, dnn_feature_columns,
                      dnn_hidden_units=(128, 128), l2_reg_linear=0.00001, l2_reg_embedding_feat=0.00001,
                      l2_reg_embedding_field=0.00001, l2_reg_dnn=0, seed=1024, dnn_dropout=0.0,
                      dnn_activation='relu', dnn_use_bn=False, task='binary', model_dir=None,
                      config=None, linear_optimizer='Ftrl', dnn_optimizer='Adagrad', training_chief_hooks=None):
| 26 | + """Instantiates the DeepFEFM Network architecture or the shallow FEFM architecture (Ablation support not provided |
| 27 | + as estimator is meant for production, Ablation support provided in DeepFEFM implementation in models |
| 28 | +
|
| 29 | + :param linear_feature_columns: An iterable containing all the features used by linear part of the model. |
| 30 | + :param dnn_feature_columns: An iterable containing all the features used by deep part of the model. |
| 31 | + :param dnn_hidden_units: list,list of positive integer or empty list, the layer number and units in each layer of DNN |
| 32 | + :param l2_reg_linear: float. L2 regularizer strength applied to linear part |
| 33 | + :param l2_reg_embedding_feat: float. L2 regularizer strength applied to embedding vector of features |
| 34 | + :param l2_reg_embedding_field: float, L2 regularizer to field embeddings |
| 35 | + :param l2_reg_dnn: float. L2 regularizer strength applied to DNN |
| 36 | + :param seed: integer ,to use as random seed. |
| 37 | + :param dnn_dropout: float in [0,1), the probability we will drop out a given DNN coordinate. |
| 38 | + :param dnn_activation: Activation function to use in DNN |
| 39 | + :param dnn_use_bn: bool. Whether use BatchNormalization before activation or not in DNN |
| 40 | + :param task: str, ``"binary"`` for binary logloss or ``"regression"`` for regression loss |
| 41 | + :param model_dir: Directory to save model parameters, graph and etc. This can |
| 42 | + also be used to load checkpoints from the directory into a estimator |
| 43 | + to continue training a previously saved model. |
| 44 | + :param config: tf.RunConfig object to configure the runtime settings. |
| 45 | + :param linear_optimizer: An instance of `tf.Optimizer` used to apply gradients to |
| 46 | + the linear part of the model. Defaults to FTRL optimizer. |
| 47 | + :param dnn_optimizer: An instance of `tf.Optimizer` used to apply gradients to |
| 48 | + the deep part of the model. Defaults to Adagrad optimizer. |
| 49 | + :param training_chief_hooks: Iterable of `tf.train.SessionRunHook` objects to |
| 50 | + run on the chief worker during training. |
| 51 | + :return: A Tensorflow Estimator instance. |
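
    A minimal usage sketch (the feature name, vocabulary size, and embedding
    dimension below are illustrative assumptions, not part of this module)::

        import tensorflow as tf

        # hypothetical categorical feature with 10000 distinct ids
        user_id = tf.feature_column.categorical_column_with_identity('user_id', num_buckets=10000)
        linear_cols = [user_id]
        dnn_cols = [tf.feature_column.embedding_column(user_id, dimension=8)]

        model = DeepFEFMEstimator(linear_cols, dnn_cols, task='binary')
        # model.train(input_fn=...)  # then evaluate/predict via the usual Estimator API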
| 52 | + """ |
| 53 | + |
    def _model_fn(features, labels, mode, config):
        train_flag = (mode == tf.estimator.ModeKeys.TRAIN)

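        # Wide/linear component: a weighted sum over the raw linear feature columns.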
        linear_logits = get_linear_logit(features, linear_feature_columns, l2_reg_linear=l2_reg_linear)
        final_logit_components = [linear_logits]

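        # Embedding and DNN variables are created under DNN_SCOPE_NAME so that
        # deepctr_model_fn can route their gradients to dnn_optimizer (the linear
        # part above is handled by linear_optimizer).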
        with variable_scope(DNN_SCOPE_NAME):
            sparse_embedding_list, dense_value_list = input_from_feature_columns(features, dnn_feature_columns,
                                                                                 l2_reg_embedding=l2_reg_embedding_feat)

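            # FEFM pairwise interactions: one interaction term per field pair,
            # computed through the learned symmetric field-pair matrices of [1].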
            fefm_interaction_embedding = FEFMLayer(
                regularizer=l2_reg_embedding_field)(concat_func(sparse_embedding_list, axis=1))

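            # Shallow FEFM logit: sum the pairwise interaction terms into a single score.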
            fefm_logit = tf.keras.layers.Lambda(lambda x: reduce_sum(x, axis=1, keep_dims=True))(
                fefm_interaction_embedding)

            final_logit_components.append(fefm_logit)

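            # Deep component (skipped for the shallow FEFM variant when dnn_hidden_units
            # is empty): the DNN consumes the flattened feature embeddings, the dense
            # values, and the FEFM interaction embedding.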
            if dnn_hidden_units:
                dnn_input = combined_dnn_input(sparse_embedding_list, dense_value_list)
                dnn_input = concat_func([dnn_input, fefm_interaction_embedding], axis=1)

                dnn_output = DNN(dnn_hidden_units, dnn_activation, l2_reg_dnn, dnn_dropout, dnn_use_bn, seed=seed)(
                    dnn_input, training=train_flag)

                dnn_logit = tf.keras.layers.Dense(
                    1, use_bias=False, kernel_initializer=tf.keras.initializers.glorot_normal(seed))(dnn_output)

                final_logit_components.append(dnn_logit)

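        # Final logit: linear + FEFM (+ DNN when enabled) components summed together.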
        logits = add_func(final_logit_components)

        return deepctr_model_fn(features, mode, logits, labels, task, linear_optimizer, dnn_optimizer,
                                training_chief_hooks=training_chief_hooks)

    return tf.estimator.Estimator(_model_fn, model_dir=model_dir, config=config)