forked from shenweichen/DeepCTR
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathactivation.py
72 lines (55 loc) · 2.87 KB
/
activation.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
# -*- coding:utf-8 -*-
"""
Author:
Weichen Shen,[email protected]
"""
import sys
import tensorflow as tf
from tensorflow.python.keras.initializers import Zeros
from tensorflow.python.keras.layers import Layer
class Dice(Layer):
    """The Data Adaptive Activation Function (Dice) proposed in DIN.

    Dice generalizes PReLU: instead of a fixed rectified point at zero, it
    adapts the rectified point to the distribution of the input data by
    gating with sigmoid(BatchNorm(x)).

    Input shape
        - Arbitrary. Use the keyword argument `input_shape` (tuple of integers, does not include the samples axis) when using this layer as the first layer in a model.
    Output shape
        - Same shape as the input.
    Arguments
        - **axis** : Integer, the axis that should be used to compute data distribution (typically the features axis).
        - **epsilon** : Small float added to variance to avoid dividing by zero.
    References
        - [Zhou G, Zhu X, Song C, et al. Deep interest network for click-through rate prediction[C]//Proceedings of the 24th ACM SIGKDD International Conference on Knowledge Discovery & Data Mining. ACM, 2018: 1059-1068.](https://arxiv.org/pdf/1706.06978.pdf)
    """

    def __init__(self, axis=-1, epsilon=1e-9, **kwargs):
        self.axis = axis
        self.epsilon = epsilon
        super(Dice, self).__init__(**kwargs)

    def build(self, input_shape):
        # BatchNorm supplies only the normalized statistics; centering and
        # scaling are handled by the learnable alpha gate below.
        self.bn = tf.keras.layers.BatchNormalization(
            axis=self.axis, epsilon=self.epsilon, center=False, scale=False)
        # One learnable slope per feature (last dimension), initialized to 0.
        self.alphas = self.add_weight(
            shape=(input_shape[-1],),
            initializer=Zeros(),
            dtype=tf.float32,
            name='dice_alpha')
        super(Dice, self).build(input_shape)  # registers the layer as built
        # Layer output depends on train/inference mode (via BatchNorm).
        self.uses_learning_phase = True

    def call(self, inputs, training=None, **kwargs):
        # gate = sigmoid(BN(x)) blends the identity branch with the
        # alpha-scaled branch, mirroring Eq. (4) of the DIN paper.
        normed = self.bn(inputs, training=training)
        gate = tf.sigmoid(normed)
        return self.alphas * (1.0 - gate) * inputs + gate * inputs

    def compute_output_shape(self, input_shape):
        # Dice is elementwise: output shape equals input shape.
        return input_shape

    def get_config(self):
        # Serialize our two hyperparameters alongside the base Layer config.
        own_config = {'axis': self.axis, 'epsilon': self.epsilon}
        base_config = super(Dice, self).get_config()
        return dict(list(base_config.items()) + list(own_config.items()))
def activation_layer(activation):
    """Construct an activation Layer instance from a flexible specifier.

    Arguments
        activation: one of
            - the strings ``"dice"`` / ``"Dice"`` -> a :class:`Dice` layer,
            - any other string -> ``tf.keras.layers.Activation(activation)``
              (a Keras built-in activation name such as ``"relu"``),
            - a ``Layer`` subclass (the class itself, not an instance),
              which is instantiated with no arguments.

    Returns
        An instantiated Keras ``Layer`` applying the activation.

    Raises
        ValueError: if ``activation`` is neither a string nor a
            ``Layer`` subclass.
    """
    if activation == "dice" or activation == "Dice":
        act_layer = Dice()
    elif (isinstance(activation, str)) or (sys.version_info.major == 2 and isinstance(activation, (str, unicode))):
        act_layer = tf.keras.layers.Activation(activation)
    # Guard with isinstance(..., type): issubclass raises TypeError on
    # non-class arguments (e.g. an int or a Layer *instance*), which would
    # mask the intended ValueError below.
    elif isinstance(activation, type) and issubclass(activation, Layer):
        act_layer = activation()
    else:
        raise ValueError(
            "Invalid activation,found %s.You should use a str or a Activation Layer Class." % (activation))
    return act_layer