@@ -14,8 +14,9 @@
                                                   glorot_uniform)
 from tensorflow.python.keras.layers import Layer
 from tensorflow.python.keras.regularizers import l2
+from tensorflow.python.layers import utils
 
-from .activation import activation_fun
+from .activation import activation_layer
 from .utils import concat_fun
 
 
@@ -87,6 +88,8 @@ def build(self, input_shape):
             embedding_size, 1), initializer=glorot_normal(seed=self.seed), name="projection_p")
         self.dropout = tf.keras.layers.Dropout(self.dropout_rate, seed=self.seed)
 
+        self.tensordot = tf.keras.layers.Lambda(lambda x: tf.tensordot(x[0], x[1], axes=(-1, 0)))
+
         # Be sure to call this somewhere!
         super(AFMLayer, self).build(input_shape)
 
@@ -119,9 +122,7 @@ def call(self, inputs, training=None, **kwargs):
 
         attention_output = self.dropout(attention_output)  # training
 
-        afm_out = tf.keras.layers.Lambda(lambda x: tf.tensordot(x[0], x[1]
-                                                                , axes=(-1, 0)))([attention_output, self.projection_p])
-
+        afm_out = self.tensordot([attention_output, self.projection_p])
         return afm_out
 
     def compute_output_shape(self, input_shape):
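Note: the two AFM hunks above follow one pattern: the Lambda wrapper around
tf.tensordot is now created once in build() and reused in call(), instead of
instantiating a fresh Keras layer on every forward pass. A minimal sketch of
the pattern, assuming a hypothetical PairTensordot layer (not part of this
commit):

    import tensorflow as tf

    class PairTensordot(tf.keras.layers.Layer):
        def build(self, input_shape):
            # Created once; Keras tracks it as a sub-layer.
            self.tensordot = tf.keras.layers.Lambda(
                lambda x: tf.tensordot(x[0], x[1], axes=(-1, 0)))
            super(PairTensordot, self).build(input_shape)

        def call(self, inputs, **kwargs):
            # Reuse the layer built above; no per-call construction.
            return self.tensordot(inputs)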
@@ -246,6 +247,8 @@ def build(self, input_shape):
             else:
                 self.field_nums.append(size)
 
+        self.activation_layers = [activation_layer(self.activation) for _ in self.layer_size]
+
         super(CIN, self).build(input_shape)  # Be sure to call this somewhere!
 
     def call(self, inputs, **kwargs):
@@ -275,7 +278,7 @@ def call(self, inputs, **kwargs):
 
             curr_out = tf.nn.bias_add(curr_out, self.bias[idx])
 
-            curr_out = activation_fun(self.activation, curr_out)
+            curr_out = self.activation_layers[idx](curr_out)
 
             curr_out = tf.transpose(curr_out, perm=[0, 2, 1])
 
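Note: activation_fun(self.activation, curr_out) applied the activation
functionally; the new code builds one activation layer per entry of
self.layer_size in build() and indexes it by idx in call(). A hedged sketch of
what such an activation_layer factory can look like (the actual
deepctr.layers.activation.activation_layer may handle more cases, e.g. Dice):

    import tensorflow as tf

    def activation_layer(activation):
        # String names go through the stock Keras Activation wrapper.
        if isinstance(activation, str):
            return tf.keras.layers.Activation(activation)
        # Layer classes are instantiated as-is.
        if issubclass(activation, tf.keras.layers.Layer):
            return activation()
        raise ValueError("Invalid activation: %s" % (activation,))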
@@ -783,6 +786,26 @@ def build(self, input_shape):
         if len(input_shape) != 3:
             raise ValueError(
                 "Unexpected inputs dimensions %d, expect to be 3 dimensions" % (len(input_shape)))
+        self.conv_layers = []
+        self.pooling_layers = []
+        self.dense_layers = []
+        pooling_shape = input_shape.as_list() + [1]
+        embedding_size = input_shape[-1].value
+        for i in range(1, len(self.filters) + 1):
+            filters = self.filters[i - 1]
+            width = self.kernel_width[i - 1]
+            new_filters = self.new_maps[i - 1]
+            pooling_width = self.pooling_width[i - 1]
+            conv_output_shape = self._conv_output_shape(pooling_shape, (width, 1))
+            pooling_shape = self._pooling_output_shape(conv_output_shape, (pooling_width, 1))
+            self.conv_layers.append(tf.keras.layers.Conv2D(filters=filters, kernel_size=(width, 1), strides=(1, 1),
+                                                           padding='same',
+                                                           activation='tanh', use_bias=True))
+            self.pooling_layers.append(tf.keras.layers.MaxPooling2D(pool_size=(pooling_width, 1)))
+            self.dense_layers.append(tf.keras.layers.Dense(pooling_shape[1] * embedding_size * new_filters,
+                                                           activation='tanh', use_bias=True))
+
+        self.flatten = tf.keras.layers.Flatten()
 
         super(FGCNNLayer, self).build(
             input_shape)  # Be sure to call this somewhere!
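Note: the shape bookkeeping above exists because tf.keras.layers.Dense needs
its unit count at construction time, so build() replays the conv/pool shape
arithmetic before any data flows. A worked example with illustrative numbers
(10 fields, kernel_width 7, pooling_width 2):

    from tensorflow.python.layers import utils

    rows = utils.conv_output_length(10, 7, padding='same', stride=1)     # -> 10
    rows = utils.conv_output_length(rows, 2, padding='valid', stride=2)  # -> 5
    # The matching Dense layer is therefore built with
    # 5 * embedding_size * new_filters units.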
@@ -794,24 +817,24 @@ def call(self, inputs, **kwargs):
                 "Unexpected inputs dimensions %d, expect to be 3 dimensions" % (K.ndim(inputs)))
 
         embedding_size = inputs.shape[-1].value
-        pooling_result = tf.keras.layers.Lambda(lambda x: tf.expand_dims(x, axis=3))(inputs)
+        pooling_result = tf.expand_dims(inputs, axis=3)
 
         new_feature_list = []
 
         for i in range(1, len(self.filters) + 1):
-            filters = self.filters[i - 1]
-            width = self.kernel_width[i - 1]
             new_filters = self.new_maps[i - 1]
-            pooling_width = self.pooling_width[i - 1]
-            conv_result = tf.keras.layers.Conv2D(filters=filters, kernel_size=(width, 1), strides=(1, 1),
-                                                 padding='same',
-                                                 activation='tanh', use_bias=True)(pooling_result)
-            pooling_result = tf.keras.layers.MaxPooling2D(pool_size=(pooling_width, 1))(conv_result)
-            flatten_result = tf.keras.layers.Flatten()(pooling_result)
-            new_result = tf.keras.layers.Dense(pooling_result.shape[1].value * embedding_size * new_filters,
-                                               activation='tanh', use_bias=True)(flatten_result)
+
+            conv_result = self.conv_layers[i - 1](pooling_result)
+
+            pooling_result = self.pooling_layers[i - 1](conv_result)
+
+            flatten_result = self.flatten(pooling_result)
+
+            new_result = self.dense_layers[i - 1](flatten_result)
+
             new_feature_list.append(
-                tf.keras.layers.Reshape((pooling_result.shape[1].value * new_filters, embedding_size))(new_result))
+                tf.reshape(new_result, (-1, pooling_result.shape[1].value * new_filters, embedding_size)))
+
         new_features = concat_fun(new_feature_list, axis=1)
         return new_features
 
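Note on the reshape change above: tf.keras.layers.Reshape takes a target shape
without the batch dimension, while tf.reshape needs the full shape, hence the
explicit -1 batch entry. A toy check with illustrative shapes:

    import tensorflow as tf

    x = tf.zeros((32, 5 * 4 * 3))      # (batch, rows * embedding * new_filters)
    y = tf.reshape(x, (-1, 5 * 3, 4))  # (batch, rows * new_filters, embedding)
    # y.shape == (32, 15, 4)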
@@ -832,3 +855,28 @@ def get_config(self, ):
                   'pooling_width': self.pooling_width}
         base_config = super(FGCNNLayer, self).get_config()
         return dict(list(base_config.items()) + list(config.items()))
+
+    def _conv_output_shape(self, input_shape, kernel_size):
+        # channels_last
+        space = input_shape[1:-1]
+        new_space = []
+        for i in range(len(space)):
+            new_dim = utils.conv_output_length(
+                space[i],
+                kernel_size[i],
+                padding='same',
+                stride=1,
+                dilation=1)
+            new_space.append(new_dim)
+        return [input_shape[0]] + new_space + [self.filters]
+
+    def _pooling_output_shape(self, input_shape, pool_size):
+        # channels_last
+
+        rows = input_shape[1]
+        cols = input_shape[2]
+        rows = utils.conv_output_length(rows, pool_size[0], 'valid',
+                                        pool_size[0])
+        cols = utils.conv_output_length(cols, pool_size[1], 'valid',
+                                        pool_size[1])
+        return [input_shape[0], rows, cols, input_shape[3]]
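Note: both helpers delegate to tensorflow.python.layers.utils.conv_output_length,
and only the rows entry of the returned shapes is consumed by build(). For
reference, a pure-Python restatement of the standard formulas it implements
('full' and 'causal' padding omitted; this is not the library source):

    def conv_output_length(input_length, filter_size, padding, stride, dilation=1):
        dilated = filter_size + (filter_size - 1) * (dilation - 1)
        if padding == 'same':
            output_length = input_length
        elif padding == 'valid':
            output_length = input_length - dilated + 1
        return (output_length + stride - 1) // stride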