Commit dcf583f (v0.7.1)

Author: 浅梦

* Simplify `VarLenSparseFeat`, support setting weight_normalization.
* Fix problem of embedding size of `SparseFeat` in `linear_feature_columns`.

1 parent: db229dc

22 files changed, +129/-109 lines
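Taken together, the two changes mean a `VarLenSparseFeat` is now built around a `SparseFeat` instead of repeating its fields. A minimal before/after sketch, assuming deepctr==0.7.1; the feature name and sizes below are illustrative, not from this commit:

```python
from deepctr.inputs import SparseFeat, VarLenSparseFeat

# Before (v0.7.0): VarLenSparseFeat carried its own embedding fields.
# genres = VarLenSparseFeat('genres', maxlen=5, vocabulary_size=101,
#                           embedding_dim=4, combiner='mean')

# After (v0.7.1): wrap a SparseFeat and keep only the sequence options,
# including the new weight_norm switch for WeightedSequenceLayer.
genres = VarLenSparseFeat(SparseFeat('genres', vocabulary_size=101, embedding_dim=4),
                          maxlen=5, combiner='mean',
                          weight_name='genres_weight', weight_norm=True)
```
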

.github/ISSUE_TEMPLATE/bug_report.md (+1/-1)

```diff
@@ -20,7 +20,7 @@ Steps to reproduce the behavior:
 **Operating environment(运行环境):**
 - python version [e.g. 3.4, 3.6]
 - tensorflow version [e.g. 1.4.0, 1.12.0]
-- deepctr version [e.g. 0.5.2,]
+- deepctr version [e.g. 0.7.1,]
 
 **Additional context**
 Add any other context about the problem here.
```

.github/ISSUE_TEMPLATE/question.md (+1/-1)

```diff
@@ -17,4 +17,4 @@ Add any other context about the problem here.
 **Operating environment(运行环境):**
 - python version [e.g. 3.6]
 - tensorflow version [e.g. 1.4.0,]
-- deepctr version [e.g. 0.5.2,]
+- deepctr version [e.g. 0.7.1,]
```

deepctr/__init__.py (+1/-1)

```diff
@@ -1,4 +1,4 @@
 from .utils import check_version
 
-__version__ = '0.7.0'
+__version__ = '0.7.1'
 check_version(__version__)
```

deepctr/inputs.py (+62/-47)

```diff
@@ -14,67 +14,74 @@
 from tensorflow.python.keras.regularizers import l2
 
 from .layers.sequence import SequencePoolingLayer, WeightedSequenceLayer
-from .layers.utils import Hash, concat_func, Linear,add_func
+from .layers.utils import Hash, concat_func, Linear, add_func
 
 DEFAULT_GROUP_NAME = "default_group"
 
 
 class SparseFeat(namedtuple('SparseFeat',
-                            ['name', 'vocabulary_size', 'embedding_dim', 'use_hash', 'dtype', 'embedding_name', 'group_name'])):
+                            ['name', 'vocabulary_size', 'embedding_dim', 'use_hash', 'dtype', 'embedding_name',
+                             'group_name'])):
     __slots__ = ()
 
     def __new__(cls, name, vocabulary_size, embedding_dim=4, use_hash=False, dtype="int32", embedding_name=None,
                 group_name=DEFAULT_GROUP_NAME):
         if embedding_name is None:
             embedding_name = name
         if embedding_dim == "auto":
-             embedding_dim = 6 * int(pow(vocabulary_size, 0.25))
+            embedding_dim = 6 * int(pow(vocabulary_size, 0.25))
         return super(SparseFeat, cls).__new__(cls, name, vocabulary_size, embedding_dim, use_hash, dtype,
                                               embedding_name, group_name)
 
     def __hash__(self):
         return self.name.__hash__()
 
-    # def __eq__(self, other):
-    #     if self.name == other.name and self.embedding_name == other.embedding_name:
-    #         return True
-    #     return False
 
-    # def __repr__(self):
-    #     return 'SparseFeat:'+self.name
+class VarLenSparseFeat(namedtuple('VarLenSparseFeat',
+                                  ['sparsefeat', 'maxlen', 'combiner', 'length_name', 'weight_name', 'weight_norm'])):
+    __slots__ = ()
 
+    def __new__(cls, sparsefeat, maxlen, combiner="mean", length_name=None, weight_name=None, weight_norm=True):
+        return super(VarLenSparseFeat, cls).__new__(cls, sparsefeat, maxlen, combiner, length_name, weight_name,
+                                                    weight_norm)
 
-class DenseFeat(namedtuple('DenseFeat', ['name', 'dimension', 'dtype'])):
-    __slots__ = ()
+    @property
+    def name(self):
+        return self.sparsefeat.name
 
-    def __new__(cls, name, dimension=1, dtype="float32"):
-        return super(DenseFeat, cls).__new__(cls, name, dimension, dtype)
+    @property
+    def vocabulary_size(self):
+        return self.sparsefeat.vocabulary_size
 
-    def __hash__(self):
-        return self.name.__hash__()
+    @property
+    def embedding_dim(self):
+        return self.sparsefeat.embedding_dim
 
-    # def __eq__(self, other):
-    #     if self.name == other.name:
-    #         return True
-    #     return False
+    @property
+    def use_hash(self):
+        return self.sparsefeat.use_hash
 
-    # def __repr__(self):
-    #     return 'DenseFeat:'+self.name
+    @property
+    def dtype(self):
+        return self.sparsefeat.dtype
 
+    @property
+    def embedding_name(self):
+        return self.sparsefeat.embedding_name
 
-class VarLenSparseFeat(namedtuple('VarLenFeat',
-                                  ['name', 'maxlen', 'vocabulary_size', 'embedding_dim', 'combiner', 'use_hash',
-                                   'dtype','length_name' ,'weight_name', 'embedding_name', 'group_name'])):
+    @property
+    def group_name(self):
+        return self.sparsefeat.group_name
+
+    def __hash__(self):
+        return self.name.__hash__()
+
+
+class DenseFeat(namedtuple('DenseFeat', ['name', 'dimension', 'dtype'])):
     __slots__ = ()
 
-    def __new__(cls, name, maxlen, vocabulary_size, embedding_dim=4, combiner="mean", use_hash=False, dtype="float32",
-                length_name=None, weight_name=None, embedding_name=None, group_name=DEFAULT_GROUP_NAME):
-        if embedding_name is None:
-            embedding_name = name
-        if embedding_dim == "auto":
-            embedding_dim = 6 * int(pow(vocabulary_size, 0.25))
-        return super(VarLenSparseFeat, cls).__new__(cls, name, maxlen, vocabulary_size, embedding_dim, combiner,
-                                                    use_hash, dtype, length_name,weight_name, embedding_name, group_name)
+    def __new__(cls, name, dimension=1, dtype="float32"):
+        return super(DenseFeat, cls).__new__(cls, name, dimension, dtype)
 
     def __hash__(self):
         return self.name.__hash__()
```

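The pass-through properties above are what keep call sites working after the refactor: code that reads `fc.embedding_dim` or `fc.embedding_name` off a `VarLenSparseFeat` still does, with the lookup delegated to the wrapped `SparseFeat`. A quick sketch of that behaviour, assuming deepctr==0.7.1 (the feature name and sizes are illustrative):

```python
from deepctr.inputs import SparseFeat, VarLenSparseFeat

feat = VarLenSparseFeat(SparseFeat('genres', vocabulary_size=101, embedding_dim=4),
                        maxlen=5, combiner='mean')

assert feat.name == 'genres'            # delegated to feat.sparsefeat.name
assert feat.embedding_dim == 4          # delegated to feat.sparsefeat.embedding_dim
assert feat.embedding_name == 'genres'  # SparseFeat defaults embedding_name to name
```
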
```diff
@@ -85,7 +92,7 @@ def __hash__(self):
     #         return False
 
     # def __repr__(self):
-    #     return 'VarLenSparseFeat:'+self.name
+    #     return 'DenseFeat:'+self.name
 
 
 def get_feature_names(feature_columns):
@@ -111,9 +118,9 @@ def build_input_features(feature_columns, prefix=''):
                                         dtype=fc.dtype)
         if fc.weight_name is not None:
             input_features[fc.weight_name] = Input(shape=(fc.maxlen, 1), name=prefix + fc.weight_name,
-                                                    dtype="float32")
+                                                   dtype="float32")
         if fc.length_name is not None:
-            input_features[fc.length_name] = Input((1,),name=prefix+fc.length_name,dtype='int32')
+            input_features[fc.length_name] = Input((1,), name=prefix + fc.length_name, dtype='int32')
 
     else:
         raise TypeError("Invalid feature column type,got", type(fc))
@@ -123,12 +130,12 @@ def build_input_features(feature_columns, prefix=''):
 
 def create_embedding_dict(sparse_feature_columns, varlen_sparse_feature_columns, init_std, seed, l2_reg,
                           prefix='sparse_', seq_mask_zero=True):
-    sparse_embedding = {feat.embedding_name:Embedding(feat.vocabulary_size, feat.embedding_dim,
-                                                      embeddings_initializer=RandomNormal(
-                                                          mean=0.0, stddev=init_std, seed=seed),
-                                                      embeddings_regularizer=l2(l2_reg),
-                                                      name=prefix + '_emb_' + feat.embedding_name) for feat in sparse_feature_columns}
-
+    sparse_embedding = {feat.embedding_name: Embedding(feat.vocabulary_size, feat.embedding_dim,
+                                                       embeddings_initializer=RandomNormal(
+                                                           mean=0.0, stddev=init_std, seed=seed),
+                                                       embeddings_regularizer=l2(l2_reg),
+                                                       name=prefix + '_emb_' + feat.embedding_name) for feat in
+                        sparse_feature_columns}
 
     if varlen_sparse_feature_columns and len(varlen_sparse_feature_columns) > 0:
         for feat in varlen_sparse_feature_columns:
@@ -160,7 +167,7 @@ def get_embedding_vec_list(embedding_dict, input_dict, sparse_feature_columns, r
 
 def create_embedding_matrix(feature_columns, l2_reg, init_std, seed, prefix="", seq_mask_zero=True):
     sparse_feature_columns = list(
-        filter(lambda x: isinstance(x, SparseFeat) , feature_columns)) if feature_columns else []
+        filter(lambda x: isinstance(x, SparseFeat), feature_columns)) if feature_columns else []
     varlen_sparse_feature_columns = list(
         filter(lambda x: isinstance(x, VarLenSparseFeat), feature_columns)) if feature_columns else []
     sparse_emb_dict = create_embedding_dict(sparse_feature_columns, varlen_sparse_feature_columns, init_std, seed,
@@ -170,25 +177,32 @@ def create_embedding_matrix(feature_columns, l2_reg, init_std, seed, prefix="",
 
 def get_linear_logit(features, feature_columns, units=1, use_bias=False, init_std=0.0001, seed=1024, prefix='linear',
                      l2_reg=0):
+    for i in range(len(feature_columns)):
+        if isinstance(feature_columns[i], SparseFeat):
+            feature_columns[i] = feature_columns[i]._replace(embedding_dim=1)
+        if isinstance(feature_columns[i], VarLenSparseFeat):
+            feature_columns[i] = feature_columns[i]._replace(
+                sparsefeat=feature_columns[i].sparsefeat._replace(embedding_dim=1))
+
     linear_emb_list = [input_from_feature_columns(features, feature_columns, l2_reg, init_std, seed,
                                                   prefix=prefix + str(i))[0] for i in range(units)]
     _, dense_input_list = input_from_feature_columns(features, feature_columns, l2_reg, init_std, seed, prefix=prefix)
 
     linear_logit_list = []
     for i in range(units):
 
-        if len(linear_emb_list[0]) > 0 and len(dense_input_list) > 0:
+        if len(linear_emb_list[i]) > 0 and len(dense_input_list) > 0:
             sparse_input = concat_func(linear_emb_list[i])
             dense_input = concat_func(dense_input_list)
             linear_logit = Linear(l2_reg, mode=2, use_bias=use_bias)([sparse_input, dense_input])
-        elif len(linear_emb_list[0]) > 0:
+        elif len(linear_emb_list[i]) > 0:
             sparse_input = concat_func(linear_emb_list[i])
             linear_logit = Linear(l2_reg, mode=0, use_bias=use_bias)(sparse_input)
         elif len(dense_input_list) > 0:
             dense_input = concat_func(dense_input_list)
             linear_logit = Linear(l2_reg, mode=1, use_bias=use_bias)(dense_input)
         else:
-            #raise NotImplementedError
+            # raise NotImplementedError
             return add_func([])
         linear_logit_list.append(linear_logit)
```

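The loop added at the top of `get_linear_logit` is the `linear_feature_columns` embedding-size fix from the commit message: the linear (order-1) term should always use 1-dimensional embeddings, whatever `embedding_dim` the user configured for the deep part, so each column is swapped for a copy with `embedding_dim=1` via `namedtuple._replace` (note this rebinds entries of the list that was passed in). A small sketch of the `_replace` mechanics, with illustrative feature names and sizes:

```python
from deepctr.inputs import SparseFeat, VarLenSparseFeat

sf = SparseFeat('user_id', vocabulary_size=1000, embedding_dim=8)
linear_sf = sf._replace(embedding_dim=1)   # returns a copy; sf is untouched
assert (sf.embedding_dim, linear_sf.embedding_dim) == (8, 1)

vf = VarLenSparseFeat(SparseFeat('genres', vocabulary_size=101, embedding_dim=4), maxlen=5)
# For the wrapper, it is the inner SparseFeat that has to be replaced.
linear_vf = vf._replace(sparsefeat=vf.sparsefeat._replace(embedding_dim=1))
assert linear_vf.embedding_dim == 1
```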

```diff
@@ -235,15 +249,15 @@ def get_varlen_pooling_list(embedding_dict, features, varlen_sparse_feature_colu
         feature_length_name = fc.length_name
         if feature_length_name is not None:
             if fc.weight_name is not None:
-                seq_input = WeightedSequenceLayer()(
+                seq_input = WeightedSequenceLayer(weight_normalization=fc.weight_norm)(
                     [embedding_dict[feature_name], features[feature_length_name], features[fc.weight_name]])
             else:
                 seq_input = embedding_dict[feature_name]
             vec = SequencePoolingLayer(combiner, supports_masking=False)(
                 [seq_input, features[feature_length_name]])
         else:
             if fc.weight_name is not None:
-                seq_input = WeightedSequenceLayer(supports_masking=True)(
+                seq_input = WeightedSequenceLayer(weight_normalization=fc.weight_norm, supports_masking=True)(
                     [embedding_dict[feature_name], features[fc.weight_name]])
             else:
                 seq_input = embedding_dict[feature_name]
@@ -254,6 +268,7 @@ def get_varlen_pooling_list(embedding_dict, features, varlen_sparse_feature_colu
         return chain.from_iterable(pooling_vec_list.values())
     return pooling_vec_list
 
+
 def get_dense_input(features, feature_columns):
     dense_feature_columns = list(filter(lambda x: isinstance(x, DenseFeat), feature_columns)) if feature_columns else []
     dense_input_list = []
```

deepctr/layers/core.py (+1/-1)

```diff
@@ -66,7 +66,7 @@ def build(self, input_shape):
         if input_shape[0][-1] != input_shape[1][-1] or input_shape[0][1] != 1:
             raise ValueError('A `LocalActivationUnit` layer requires '
                              'inputs of a two inputs with shape (None,1,embedding_size) and (None,T,embedding_size)'
-                             'Got different shapes: %s,%s' % (input_shape))
+                             'Got different shapes: %s,%s' % (input_shape[0],input_shape[1]))
         size = 4 * \
                int(input_shape[0][-1]
                    ) if len(self.hidden_units) == 0 else self.hidden_units[-1]
```

deepctr/layers/sequence.py (+2/-2)

```diff
@@ -121,12 +121,12 @@ class WeightedSequenceLayer(Layer):
       - 3D tensor with shape: ``(batch_size, T, embedding_size)``.
 
     Arguments
-        - **weight_normalization**: bool.Whether normalize the weight socre before applying to sequence.
+        - **weight_normalization**: bool.Whether normalize the weight score before applying to sequence.
 
         - **supports_masking**:If True,the input need to support masking.
     """
 
-    def __init__(self,weight_normalization=False, supports_masking=False, **kwargs):
+    def __init__(self,weight_normalization=True, supports_masking=False, **kwargs):
         super(WeightedSequenceLayer, self).__init__(**kwargs)
         self.weight_normalization = weight_normalization
         self.supports_masking = supports_masking
```

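Beyond the typo fix, flipping the default to `weight_normalization=True` changes behaviour for anyone constructing `WeightedSequenceLayer` without arguments: weight scores are now normalized before being applied unless callers opt out. With feature columns, this is steered per feature through `weight_norm` (see the `inputs.py` diff above). A sketch of opting out to keep raw, un-normalized weight scores, with illustrative names and sizes:

```python
from deepctr.inputs import SparseFeat, VarLenSparseFeat

# weight_norm=False restores the pre-0.7.1 default of applying raw weight scores.
genres = VarLenSparseFeat(SparseFeat('genres', vocabulary_size=101, embedding_dim=4),
                          maxlen=5, combiner='mean',
                          weight_name='genres_weight', weight_norm=False)
```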

deepctr/utils.py (+1/-1)

```diff
@@ -40,7 +40,7 @@ def check(version):
                     '\nDeepCTR version {0} detected. Your version is {1}.\nUse `pip install -U deepctr` to upgrade.Changelog: https://github.com/shenweichen/DeepCTR/releases/tag/v{0}'.format(
                         latest_version, version))
         except Exception as e:
-            print(e)
+            print("Please check the latest version manually on https://pypi.org/project/deepctr/#history")
         return
 
     Thread(target=check, args=(version,)).start()
```

docs/source/Examples.md (+12/-10)

```diff
@@ -246,11 +246,11 @@ if __name__ == "__main__":
 
     use_weighted_sequence = False
     if use_weighted_sequence:
-        varlen_feature_columns = [VarLenSparseFeat('genres', maxlen= max_len,vocabulary_size=len(
-            key2index) + 1,embedding_dim=4, combiner='mean',weight_name='genres_weight')]  # Notice : value 0 is for padding for sequence input feature
+        varlen_feature_columns = [VarLenSparseFeat(SparseFeat('genres',vocabulary_size=len(
+            key2index) + 1,embedding_dim=4), maxlen= max_len, combiner='mean',weight_name='genres_weight')]  # Notice : value 0 is for padding for sequence input feature
     else:
-        varlen_feature_columns = [VarLenSparseFeat('genres', maxlen=max_len,vocabulary_size= len(
-            key2index) + 1,embedding_dim=4, combiner='mean',weight_name=None)]  # Notice : value 0 is for padding for sequence input feature
+        varlen_feature_columns = [VarLenSparseFeat(SparseFeat('genres',vocabulary_size= len(
+            key2index) + 1,embedding_dim=4), maxlen=max_len, combiner='mean',weight_name=None)]  # Notice : value 0 is for padding for sequence input feature
 
     linear_feature_columns = fixlen_feature_columns + varlen_feature_columns
     dnn_feature_columns = fixlen_feature_columns + varlen_feature_columns
@@ -279,8 +279,8 @@ import numpy as np
 import pandas as pd
 from tensorflow.python.keras.preprocessing.sequence import pad_sequences
 
+from deepctr.inputs import SparseFeat, VarLenSparseFeat, get_feature_names
 from deepctr.models import DeepFM
-from deepctr.inputs import SparseFeat, VarLenSparseFeat,get_feature_names
 
 if __name__ == "__main__":
     data = pd.read_csv("./movielens_sample.txt")
@@ -301,20 +301,22 @@ if __name__ == "__main__":
 
     # 2.set hashing space for each sparse field and generate feature config for sequence feature
 
-    fixlen_feature_columns = [SparseFeat(feat, data[feat].nunique() * 5,embedding_dim=4, use_hash=True, dtype='string')
+    fixlen_feature_columns = [SparseFeat(feat, data[feat].nunique() * 5, embedding_dim=4, use_hash=True, dtype='string')
                               for feat in sparse_features]
-    varlen_feature_columns = [VarLenSparseFeat('genres', maxlen=max_len,vocabulary_size=100,embedding_dim=4,combiner= 'mean', use_hash=True,
-                                               dtype="string")]  # Notice : value 0 is for padding for sequence input feature
+    varlen_feature_columns = [
+        VarLenSparseFeat(SparseFeat('genres', vocabulary_size=100, embedding_dim=4, use_hash=True, dtype="string"),
+                         maxlen=max_len, combiner='mean',
+                         )]  # Notice : value 0 is for padding for sequence input feature
     linear_feature_columns = fixlen_feature_columns + varlen_feature_columns
     dnn_feature_columns = fixlen_feature_columns + varlen_feature_columns
     feature_names = get_feature_names(linear_feature_columns + dnn_feature_columns)
 
     # 3.generate input data for model
-    model_input = {name:data[name] for name in feature_names}
+    model_input = {name: data[name] for name in feature_names}
     model_input['genres'] = genres_list
 
     # 4.Define Model,compile and train
-    model = DeepFM(linear_feature_columns,dnn_feature_columns, task='regression')
+    model = DeepFM(linear_feature_columns, dnn_feature_columns, task='regression')
 
     model.compile("adam", "mse", metrics=['mse'], )
     history = model.fit(model_input, data[target].values,
```

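One detail worth noting for the weighted variant of this example: `build_input_features` (see the `inputs.py` diff above) declares the weight input as `Input(shape=(fc.maxlen, 1))`, so the array fed for `weight_name` needs a trailing singleton axis. A sketch of preparing such an input; the shape comes from the diff, while the sizes and random values are illustrative:

```python
import numpy as np

max_len, n_samples = 5, 3
# One weight per sequence position, plus the trailing axis expected by
# the Input(shape=(max_len, 1)) created for weight_name.
genres_weight = np.random.rand(n_samples, max_len, 1).astype('float32')
# model_input['genres_weight'] = genres_weight  # paired with weight_name='genres_weight'
```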

docs/source/Features.md (+3/-8)

```diff
@@ -42,19 +42,14 @@ DNN based CTR prediction models usually have following 4 modules:
 
 ### VarLenSparseFeat
 
-``VarLenSparseFeat`` is a namedtuple with signature ``VarLenSparseFeat(name, maxlen, vocabulary_size, embedding_dim, combiner,use_hash, dtype, length_name,weight_name, embedding_name, group_name)``
+``VarLenSparseFeat`` is a namedtuple with signature ``VarLenSparseFeat(sparsefeat, maxlen, combiner, length_name, weight_name,weight_norm)``
 
-- name : feature name
+- sparsefeat : a instance of `SparseFeat`
 - maxlen : maximum length of this feature for all samples
-- vocabulary_size : number of unique feature values for sprase feature or hashing space when `use_hash=True`
-- embedding_dim : embedding dimension
 - combiner : pooling method,can be ``sum``,``mean`` or ``max``
-- use_hash : defualt `False`.if `True` the input will be hashed to space of size `vocabulary_size`.
-- dtype : default `float32`.dtype of input tensor.
 - length_name : feature length name,if `None`, value 0 in feature is for padding.
 - weight_name : default `None`. If not None, the sequence feature will be multiplyed by the feature whose name is `weight_name`.
-- embedding_name : default `None`. If None, the `embedding_name` will be same as `name`.
-- group_name : feature group of this feature.
+- weight_norm : default `True`. Whether normalize the weight score or not.
 
 ## Models
 
```
docs/source/History.md (+1/-1)

```diff
@@ -1,6 +1,6 @@
 # History
+- 01/28/2020 : [v0.7.1](https://github.com/shenweichen/DeepCTR/releases/tag/v0.7.1) released.Simplify [VarLenSparseFeat](./Features.html#varlensparsefeat),support setting weight_normalization.Fix problem of embedding size of `SparseFeat` in `linear_feature_columns`.
 - 11/24/2019 : [v0.7.0](https://github.com/shenweichen/DeepCTR/releases/tag/v0.7.0) released.Refactor [feature columns](./Features.html#feature-columns).Different features can use different `embedding_dim` and group-wise interaction is available by setting `group_name`.
-
 - 11/06/2019 : [v0.6.3](https://github.com/shenweichen/DeepCTR/releases/tag/v0.6.3) released.Add `WeightedSequenceLayer` and support [weighted sequence feature input](./Examples.html#multi-value-input-movielens).
 - 10/03/2019 : [v0.6.2](https://github.com/shenweichen/DeepCTR/releases/tag/v0.6.2) released.Simplify the input logic.
 - 09/08/2019 : [v0.6.1](https://github.com/shenweichen/DeepCTR/releases/tag/v0.6.1) released.Fix bugs in `CCPM` and `DynamicGRU`.
```

docs/source/conf.py (+1/-1)

```diff
@@ -26,7 +26,7 @@
 # The short X.Y version
 version = ''
 # The full version, including alpha/beta/rc tags
-release = '0.7.0'
+release = '0.7.1'
 
 
 # -- General configuration ---------------------------------------------------
```
