Skip to content

Commit ab595ec

Browse files
authored
Add a transform_fn into DenseFeat (shenweichen#309)
* Add transform_fn for DenseFeat
1 parent e9c8f08 commit ab595ec

File tree

3 files changed

+28
-6
lines changed

3 files changed

+28
-6
lines changed

deepctr/feature_column.py

+13-3
Original file line numberDiff line numberDiff line change
@@ -87,11 +87,21 @@ def __hash__(self):
8787
return self.name.__hash__()
8888

8989

90-
class DenseFeat(namedtuple('DenseFeat', ['name', 'dimension', 'dtype'])):
90+
class DenseFeat(namedtuple('DenseFeat', ['name', 'dimension', 'dtype', 'transform_fn'])):
91+
""" Dense feature
92+
Args:
93+
name: feature name.
94+
dimension: dimension of the feature, default = 1.
95+
dtype: dtype of the feature, default="float32".
96+
transform_fn: If not None, a function that can be used to transform
97+
values of the feature. The function takes the input Tensor as its
98+
argument, and returns the output Tensor.
99+
(e.g. lambda x: (x - 3.0) / 4.2).
100+
"""
91101
__slots__ = ()
92102

93-
def __new__(cls, name, dimension=1, dtype="float32"):
94-
return super(DenseFeat, cls).__new__(cls, name, dimension, dtype)
103+
def __new__(cls, name, dimension=1, dtype="float32", transform_fn=None):
104+
return super(DenseFeat, cls).__new__(cls, name, dimension, dtype, transform_fn)
95105

96106
def __hash__(self):
97107
return self.name.__hash__()

deepctr/inputs.py

+6-2
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
from collections import defaultdict
1010
from itertools import chain
1111

12-
from tensorflow.python.keras.layers import Embedding
12+
from tensorflow.python.keras.layers import Embedding, Lambda
1313
from tensorflow.python.keras.regularizers import l2
1414

1515
from .layers.sequence import SequencePoolingLayer, WeightedSequenceLayer
@@ -138,7 +138,11 @@ def get_dense_input(features, feature_columns):
138138
filter(lambda x: isinstance(x, fc_lib.DenseFeat), feature_columns)) if feature_columns else []
139139
dense_input_list = []
140140
for fc in dense_feature_columns:
141-
dense_input_list.append(features[fc.name])
141+
if fc.transform_fn is None:
142+
dense_input_list.append(features[fc.name])
143+
else:
144+
transform_result = Lambda(fc.transform_fn)(features[fc.name])
145+
dense_input_list.append(transform_result)
142146
return dense_input_list
143147

144148

tests/utils.py

+9-1
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,15 @@ def get_test_data(sample_size=1000, embedding_size=4, sparse_feature_num=1, dens
5252
SparseFeat(prefix + 'sparse_feature_' + str(i), dim, embedding_size, use_hash=hash_flag, dtype=tf.int32,group_name=group_name))
5353

5454
for i in range(dense_feature_num):
55-
feature_columns.append(DenseFeat(prefix + 'dense_feature_' + str(i), 1, dtype=tf.float32))
55+
transform_fn = lambda x: (x - 0.0)/ 1.0
56+
feature_columns.append(
57+
DenseFeat(
58+
prefix + 'dense_feature_' + str(i),
59+
1,
60+
dtype=tf.float32,
61+
transform_fn=transform_fn
62+
)
63+
)
5664
for i, mode in enumerate(sequence_feature):
5765
dim = np.random.randint(1, 10)
5866
maxlen = np.random.randint(1, 10)

0 commit comments

Comments
 (0)