Commit 82efc6f

minor change
1. Modified the output shapes of BiInteractionPooling and InnerProductLayer so that, where possible, a layer's output has the same number of dimensions as its input.
2. Minimized the nesting of other layers inside custom layers, because I found that the parameter counts reported by model.summary() are incorrect when other layers are nested inside a custom layer.
1 parent 34e8fa5 commit 82efc6f
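
The second point in a nutshell: a Keras layer instantiated inside another layer's call() owns weights that the outer layer does not track, so model.summary() misses them, while weights created with add_weight() in build() are tracked and counted. A minimal sketch with hypothetical layer names (assuming the TF 1.x tf.python.keras API used throughout this repo; not code from this commit):

import tensorflow as tf
from tensorflow.python.keras.layers import Layer, Dense
from tensorflow.python.keras.initializers import Zeros, glorot_normal

class NestedDense(Layer):
    # Projection through a nested Dense layer: its kernel/bias do not appear in
    # this layer's trainable_weights, so model.summary() under-counts parameters.
    def call(self, inputs, **kwargs):
        return Dense(8)(inputs)

class ExplicitDense(Layer):
    # Same projection, but the weights are owned by the layer via add_weight,
    # which is the pattern the diffs below switch to.
    def build(self, input_shape):
        self.kernel = self.add_weight(name='kernel',
                                      shape=(int(input_shape[-1]), 8),
                                      initializer=glorot_normal(seed=1024))
        self.bias = self.add_weight(name='bias', shape=(8,), initializer=Zeros())
        super(ExplicitDense, self).build(input_shape)

    def call(self, inputs, **kwargs):
        return tf.nn.bias_add(tf.tensordot(inputs, self.kernel, axes=(-1, 0)), self.bias)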

File tree: 7 files changed (+61 -43 lines changed)

deepctr/activations.py
deepctr/layers.py
deepctr/models/din.py
deepctr/models/pnn.py
deepctr/sequence.py
docs/source/conf.py
setup.py

deepctr/activations.py (+1 -2)

@@ -34,8 +34,7 @@ def build(self, input_shape):
 
     def call(self, inputs, **kwargs):
 
-        inputs_normed = BatchNormalization(
-            axis=self.axis, epsilon=self.epsilon, center=False, scale=False)(inputs)
+        inputs_normed = tf.layers.batch_normalization(inputs,axis=self.axis, epsilon=self.epsilon, center=False, scale=False)
         x_p = tf.sigmoid(inputs_normed)
         return self.alphas * (1.0 - x_p) * inputs + x_p * inputs
 
    def get_config(self,):
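
For reference, the Dice forward pass after this change boils down to the following standalone sketch (dummy function, assumed TF 1.x API; alphas stands in for the layer's learned parameter and the epsilon default is only illustrative):

import tensorflow as tf

def dice_sketch(inputs, alphas, axis=-1, epsilon=1e-9):
    # Normalize the input, squash it into a gate x_p, then blend the identity
    # branch and the suppressed branch with the learned alphas.
    inputs_normed = tf.layers.batch_normalization(
        inputs, axis=axis, epsilon=epsilon, center=False, scale=False)
    x_p = tf.sigmoid(inputs_normed)
    return alphas * (1.0 - x_p) * inputs + x_p * inputs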

deepctr/layers.py (+53 -34)

@@ -1,8 +1,8 @@
-from tensorflow.python.keras.layers import Layer,Dense,Activation,Dropout,BatchNormalization,concatenate
+from tensorflow.python.keras.layers import Layer,Activation,BatchNormalization
 from tensorflow.python.keras.regularizers import l2
 from tensorflow.python.keras.initializers import RandomNormal,Zeros,glorot_normal,glorot_uniform
 from tensorflow.python.keras import backend as K
-from tensorflow.python.keras.activations import softmax
+
 import tensorflow as tf
 
 
@@ -101,12 +101,14 @@ def build(self, input_shape):
 
         embedding_size = input_shape[0][-1]
 
-        #self.attention_W = self.add_weight(shape=(embedding_size, self.attention_factor), initializer=glorot_normal(seed=self.seed),
-        #                                   name="attention_W")
-        #self.attention_b = self.add_weight(shape=(self.attention_factor,), initializer=Zeros(), name="attention_b")
+        self.attention_W = self.add_weight(shape=(embedding_size, self.attention_factor), initializer=glorot_normal(seed=self.seed),regularizer=l2(self.l2_reg_w),
+                                           name="attention_W")
+        self.attention_b = self.add_weight(shape=(self.attention_factor,), initializer=Zeros(), name="attention_b")
         self.projection_h = self.add_weight(shape=(self.attention_factor, 1), initializer=glorot_normal(seed=self.seed),
                                             name="projection_h")
         self.projection_p = self.add_weight(shape=(embedding_size, 1), initializer=glorot_normal(seed=self.seed), name="projection_p")
+
+
         super(AFMLayer, self).build(input_shape)  # Be sure to call this somewhere!
 
     def call(self, inputs,**kwargs):
@@ -127,14 +129,14 @@ def call(self, inputs,**kwargs):
         inner_product = p * q
 
         bi_interaction = inner_product
+        attention_temp = tf.nn.relu(tf.nn.bias_add(tf.tensordot(bi_interaction,self.attention_W,axes=(-1,0)),self.attention_b))
+        # Dense(self.attention_factor,'relu',kernel_regularizer=l2(self.l2_reg_w))(bi_interaction)
+        attention_weight = tf.nn.softmax(tf.tensordot(attention_temp,self.projection_h,axes=(-1,0)),dim=1)
+        attention_output = tf.reduce_sum(attention_weight*bi_interaction,axis=1)
 
-        attention_temp = Dense(self.attention_factor,'relu',kernel_regularizer=l2(self.l2_reg_w))(bi_interaction)
-        attention_weight = softmax(K.dot(attention_temp, self.projection_h),axis=1)
-
-        attention_output = K.sum(attention_weight*bi_interaction,axis=1)
         attention_output = tf.nn.dropout(attention_output,self.keep_prob,seed=1024)
-        # Dropout(1-self.keep_prob)(attention_output)
-        afm_out = K.dot(attention_output, self.projection_p)
+        # Dropout(1-self.keep_prob)(attention_output)
+        afm_out = tf.tensordot(attention_output,self.projection_p,axes=(-1,0))
 
         return afm_out
 
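A hedged shape walkthrough of the new attention path (dummy tensors with made-up sizes, not repo code): the pair-wise interactions keep their embedding dimension until the final projection.

import tensorflow as tf

batch, num_pairs, emb, factor = 2, 6, 4, 8
bi_interaction = tf.zeros((batch, num_pairs, emb))
attention_W = tf.zeros((emb, factor))
attention_b = tf.zeros((factor,))
projection_h = tf.zeros((factor, 1))
projection_p = tf.zeros((emb, 1))

att_temp = tf.nn.relu(tf.nn.bias_add(
    tf.tensordot(bi_interaction, attention_W, axes=(-1, 0)), attention_b))              # (batch, num_pairs, factor)
att_weight = tf.nn.softmax(tf.tensordot(att_temp, projection_h, axes=(-1, 0)), dim=1)   # (batch, num_pairs, 1), normalized over pairs
att_output = tf.reduce_sum(att_weight * bi_interaction, axis=1)                         # (batch, emb)
afm_out = tf.tensordot(att_output, projection_p, axes=(-1, 0))                          # (batch, 1)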

@@ -169,13 +171,14 @@ def build(self, input_shape):
     def call(self, inputs,**kwargs):
         x = inputs
         if self.use_bias:
-            x = K.bias_add(x, self.global_bias, data_format='channels_last')
+            x = tf.nn.bias_add(x,self.global_bias,data_format='NHWC')
+
         if isinstance(self.activation,str):
             output = Activation(self.activation)(x)
         else:
             output = self.activation(x)
 
-        output = K.reshape(output,(-1,1))
+        output = tf.reshape(output,(-1,1))
 
         return output
 

@@ -282,17 +285,28 @@ def __init__(self, hidden_size, activation,l2_reg, keep_prob, use_bn,seed,**kwa
         super(MLP, self).__init__(**kwargs)
 
     def build(self, input_shape):
+        input_size = input_shape[-1]
+        hidden_units = [int(input_size)] + self.hidden_size
+        self.kernels = [self.add_weight(name='kernel' + str(i),
+                                        shape=(hidden_units[i], hidden_units[i+1]),
+                                        initializer=glorot_normal(seed=self.seed),
+                                        regularizer=l2(self.l2_reg),
+                                        trainable=True) for i in range(len(self.hidden_size))]
+        self.bias = [self.add_weight(name='bias' + str(i),
+                                     shape=(self.hidden_size[i],),
+                                     initializer=Zeros(),
+                                     trainable=True) for i in range(len(self.hidden_size))]
 
         super(MLP, self).build(input_shape)  # Be sure to call this somewhere!
 
     def call(self, inputs,**kwargs):
         deep_input = inputs
-        #deep_input = Dropout(1 - self.keep_prob)(deep_input)
 
-        for l in range(len(self.hidden_size)):
-            fc = Dense(self.hidden_size[l], activation=None, \
-                       kernel_initializer=glorot_normal(seed=self.seed), \
-                       kernel_regularizer=l2(self.l2_reg))(deep_input)
+        for i in range(len(self.hidden_size)):
+            fc = tf.nn.bias_add(tf.tensordot(deep_input,self.kernels[i],axes=(-1,0)),self.bias[i])
+            #fc = Dense(self.hidden_size[i], activation=None, \
+            #           kernel_initializer=glorot_normal(seed=self.seed), \
+            #           kernel_regularizer=l2(self.l2_reg))(deep_input)
             if self.use_bn:
                 fc = BatchNormalization()(fc)
 

@@ -302,7 +316,7 @@ def call(self, inputs,**kwargs):
                 fc = self.activation()(fc)
             else:
                 raise ValueError("Invalid activation of MLP,found %s.You should use a str or a Activation Layer Class."%(self.activation))
-            fc = Dropout(1 - self.keep_prob)(fc)
+            fc = tf.nn.dropout(fc,self.keep_prob)
 
             deep_input = fc
 
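
A hedged usage check of the motivation behind this change (keyword names follow the __init__ signature in the hunk header above; not a test from the repo): because the kernels and biases are now created with add_weight, they should be reported under the MLP layer in model.summary().

from tensorflow.python.keras.layers import Input
from tensorflow.python.keras.models import Model
from deepctr.layers import MLP

inp = Input(shape=(32,))
out = MLP(hidden_size=[64, 32], activation='relu', l2_reg=0,
          keep_prob=1, use_bn=False, seed=1024)(inp)
Model(inputs=inp, outputs=out).summary()  # expect 32*64+64 plus 64*32+32 parameters attributed to the MLP layer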

@@ -327,7 +341,7 @@ class BiInteractionPooling(Layer):
         - A list of 3D tensor with shape:``(batch_size,field_size,embedding_size)``.
 
       Output shape
-        - 2D tensor with shape: ``(batch_size, embedding_size)``.
+        - 3D tensor with shape: ``(batch_size,1,embedding_size)``.
 
       References
         - [Neural Factorization Machines for Sparse Predictive Analytics](http://arxiv.org/abs/1708.05027)
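
The layer computes 0.5*((sum of embeddings)^2 - sum of squared embeddings) over the field axis; with keep_dims=True (see the call() hunk below) that axis is kept as size 1 instead of being squeezed away, hence the new 3D output. A hedged shape check with a dummy tensor (not repo code):

import tensorflow as tf

embeds = tf.zeros((32, 10, 4))  # (batch_size, field_size, embedding_size)
square_of_sum = tf.square(tf.reduce_sum(embeds, axis=1, keep_dims=True))
sum_of_square = tf.reduce_sum(embeds * embeds, axis=1, keep_dims=True)
cross_term = 0.5 * (square_of_sum - sum_of_square)
print(cross_term.get_shape())  # (32, 1, 4): same rank as the input, field axis pooled to 1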
@@ -350,14 +364,14 @@ def call(self, inputs,**kwargs):
             raise ValueError("Unexpected inputs dimensions %d, expect to be 3 dimensions"% (K.ndim(inputs)))
 
         concated_embeds_value = inputs
-        square_of_sum = K.square(K.sum(concated_embeds_value, axis=1, keepdims=True))
-        sum_of_square = K.sum(concated_embeds_value * concated_embeds_value, axis=1, keepdims=True)
+        square_of_sum = tf.square(tf.reduce_sum(concated_embeds_value, axis=1, keep_dims=True))
+        sum_of_square = tf.reduce_sum(concated_embeds_value * concated_embeds_value, axis=1, keep_dims=True)
         cross_term = 0.5*(square_of_sum - sum_of_square)
-        cross_term = K.reshape(cross_term,(-1,inputs.get_shape()[-1]))
+
         return cross_term
 
     def compute_output_shape(self, input_shape):
-        return (None, input_shape[-1])
+        return (None, 1, input_shape[-1])
 
 class OutterProductLayer(Layer):
     """OutterProduct Layer used in PNN.This implemention is adapted from code that the author of the paper published on https://github.com/Atomu2014/product-nets.
@@ -366,7 +380,7 @@ class OutterProductLayer(Layer):
         - A list of N 3D tensor with shape: ``(batch_size,1,embedding_size)``.
 
       Output shape
-        - 2D tensor with shape:``(batch_size, N*(N-1)/2 )``.
+        - 2D tensor with shape:``(batch_size,N*(N-1)/2 )``.
 
       Arguments
         - **kernel_type**: str. The kernel weight matrix type to use,can be mat,vec or num
@@ -434,8 +448,8 @@ def call(self, inputs,**kwargs):
            for j in range(i + 1, num_inputs):
                row.append(i)
                col.append(j)
-        p = K.concatenate([embed_list[idx] for idx in row],axis=1)  # batch num_pairs k
-        q = K.concatenate([embed_list[idx] for idx in col],axis=1)  # Reshape([num_pairs, self.embedding_size])
+        p = tf.concat([embed_list[idx] for idx in row],axis=1)  # batch num_pairs k
+        q = tf.concat([embed_list[idx] for idx in col],axis=1)  # Reshape([num_pairs, self.embedding_size])
 
         #-------------------------
         if self.kernel_type == 'mat':
@@ -499,7 +513,7 @@ class InnerProductLayer(Layer):
         - A list of N 3D tensor with shape: ``(batch_size,1,embedding_size)``.
 
       Output shape
-        - 2D tensor with shape: ``(batch_size, N*(N-1)/2 )`` if use reduce_sum. or 3D tensor with shape: ``(batch_size, N*(N-1)/2, embedding_size )`` if not use reduce_sum.
+        - 3D tensor with shape: ``(batch_size, N*(N-1)/2 ,1)`` if use reduce_sum. or 3D tensor with shape: ``(batch_size, N*(N-1)/2, embedding_size )`` if not use reduce_sum.
 
       Arguments
         - **reduce_sum**: bool. Whether return inner product or element-wise product
@@ -550,11 +564,11 @@ def call(self, inputs,**kwargs):
            for j in range(i + 1, num_inputs):
                row.append(i)
                col.append(j)
-        p = K.concatenate([embed_list[idx] for idx in row],axis=1)  # batch num_pairs k
-        q = K.concatenate([embed_list[idx] for idx in col],axis=1)  # Reshape([num_pairs, self.embedding_size])
+        p = tf.concat([embed_list[idx] for idx in row],axis=1)  # batch num_pairs k
+        q = tf.concat([embed_list[idx] for idx in col],axis=1)  # Reshape([num_pairs, self.embedding_size])
         inner_product = p * q
         if self.reduce_sum:
-            inner_product = K.sum(inner_product, axis=2, keepdims=False)
+            inner_product = tf.reduce_sum(inner_product, axis=2, keep_dims=True)
         return inner_product
 

@@ -564,7 +578,7 @@ def compute_output_shape(self, input_shape):
         input_shape = input_shape[0]
         embed_size = input_shape[-1]
         if self.reduce_sum:
-            return (input_shape[0],num_pairs)
+            return (input_shape[0],num_pairs,1)
         else:
             return (input_shape[0],num_pairs,embed_size)
 

@@ -623,6 +637,11 @@ def build(self, input_shape):
             raise ValueError('A `LocalActivationUnit` layer requires '
                              'inputs of a two inputs with shape (None,1,embedding_size) and (None,T,embedding_size)'
                              'Got different shapes: %s,%s' % (input_shape))
+        size = 4*int(input_shape[0][-1]) if len(self.hidden_size) == 0 else self.hidden_size[-1]
+        self.kernel = self.add_weight(shape=(size, 1),
+                                      initializer=glorot_normal(seed=self.seed),
+                                      name="kernel")
+        self.bias = self.add_weight(shape=(1,), initializer=Zeros(), name="bias")
         super(LocalActivationUnit, self).build(input_shape)  # Be sure to call this somewhere!
 
     def call(self, inputs,**kwargs):
634653
queries = K.repeat_elements(query,keys_len,1)
635654

636655
att_input = tf.concat([queries, keys, queries - keys, queries * keys], axis=-1)
637-
att_input = BatchNormalization()(att_input)
656+
att_input = tf.layers.batch_normalization(att_input)
638657
att_out = MLP(self.hidden_size, self.activation, self.l2_reg, self.keep_prob, self.use_bn, seed=self.seed)(att_input)
639-
attention_score = Dense(1, 'linear')(att_out)
658+
attention_score = tf.nn.bias_add(tf.tensordot(att_out,self.kernel,axes=(-1,0)),self.bias)
640659

641660
return attention_score
642661

deepctr/models/din.py (+1 -1)

@@ -30,7 +30,7 @@ def get_input(feature_dim_dict, seq_feature_list, seq_max_len):
 
 
 def DIN(feature_dim_dict, seq_feature_list, embedding_size=8, hist_len_max=16,
-        use_din=True, use_bn=False, hidden_size=[200, 80], activation='relu', att_hidden_size=[80, 40], att_activation='sigmoid', att_weight_normalization=True,
+        use_din=True, use_bn=False, hidden_size=[200, 80], activation='relu', att_hidden_size=[80, 40], att_activation=Dice, att_weight_normalization=False,
         l2_reg_deep=0, l2_reg_embedding=1e-5, final_activation='sigmoid', keep_prob=1, init_std=0.0001, seed=1024, ):
     """Instantiates the Deep Interest Network architecture.
 

deepctr/models/pnn.py (+2 -2)

@@ -7,7 +7,7 @@
 [1] Qu, Yanru, et al. "Product-based neural networks for user response prediction." Data Mining (ICDM), 2016 IEEE 16th International Conference on. IEEE, 2016.(https://arxiv.org/pdf/1611.00144.pdf)
 """
 
-from tensorflow.python.keras.layers import Dense, Embedding, Concatenate, Reshape
+from tensorflow.python.keras.layers import Dense, Embedding, Concatenate, Reshape,Flatten
 from tensorflow.python.keras.models import Model
 from tensorflow.python.keras.initializers import RandomNormal
 from tensorflow.python.keras.regularizers import l2

@@ -63,7 +63,7 @@ def PNN(feature_dim_dict, embedding_size=8, hidden_size=[128, 128], l2_reg_embed
         map(Reshape((1, embedding_size)), continuous_embedding_list))
     embed_list += continuous_embedding_list
 
-    inner_product = InnerProductLayer()(embed_list)
+    inner_product = Flatten()(InnerProductLayer()(embed_list))
     outter_product = OutterProductLayer(kernel_type)(embed_list)
 
     # ipnn deep input
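
Since InnerProductLayer now returns a 3D tensor of shape (batch_size, N*(N-1)/2, 1) when reduce_sum is used, PNN flattens it back to 2D before it is concatenated into the deep input. A hedged shape check (dummy inputs, assuming reduce_sum defaults to True as the call above implies; not repo test code):

from tensorflow.python.keras.layers import Input, Flatten
from deepctr.layers import InnerProductLayer

embed_list = [Input(shape=(1, 4)) for _ in range(3)]   # 3 fields -> 3*(3-1)/2 = 3 pairs
inner_product = InnerProductLayer()(embed_list)        # (None, 3, 1) after this commit
inner_product = Flatten()(inner_product)               # (None, 3), what the deep part expects
print(inner_product.get_shape())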

deepctr/sequence.py (+2 -2)

@@ -23,7 +23,7 @@ class SequencePoolingLayer(Layer):
         - **mode**:str.Pooling operation to be used,can be sum,mean or max.
     """
 
-    def __init__(self, seq_len_max, mode='sum', **kwargs):
+    def __init__(self, seq_len_max, mode='mean', **kwargs):
 
         if mode not in ['sum', 'mean', 'max']:
             raise ValueError("mode must be sum or mean")

@@ -91,7 +91,7 @@ class AttentionSequencePoolingLayer(Layer):
         - [Deep Interest Network for Click-Through Rate Prediction](https://arxiv.org/pdf/1706.06978.pdf)
     """
 
-    def __init__(self, hidden_size=(80, 40), activation='sigmoid', weight_normalization=True, **kwargs):
+    def __init__(self, hidden_size=(80, 40), activation='sigmoid', weight_normalization=False, **kwargs):
 
         self.hidden_size = hidden_size
         self.activation = activation

docs/source/conf.py (+1 -1)

@@ -26,7 +26,7 @@
 # The short X.Y version
 version = ''
 # The full version, including alpha/beta/rc tags
-release = '0.1.4'
+release = '0.1.5'
 
 
 # -- General configuration ---------------------------------------------------

setup.py (+1 -1)

@@ -5,7 +5,7 @@
 
 setuptools.setup(
     name="deepctr",
-    version="0.1.4",
+    version="0.1.5",
     author="Weichen Shen",
     author_email="[email protected]",
     description="DeepCTR is a Easy-to-use,Modular and Extendible package of deep-learning based CTR models ,including serval DNN-based CTR models and lots of core components layer of the models which can be used to build your own custom model.",
