Skip to content

Commit

Permalink
Fix list comprehension for bucket features in tfx template.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 312005513
  • Loading branch information
jiyongjung authored and tfx-copybara committed May 18, 2020
1 parent 78e4c70 commit 62ba786
Show file tree
Hide file tree
Showing 3 changed files with 8 additions and 7 deletions.
4 changes: 2 additions & 2 deletions tfx/experimental/templates/taxi/models/features.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,10 +32,10 @@

# Name of features which have continuous float values. These features will be
# bucketized using `tft.bucketize`, and will be used as categorical features.
BUCKET_FEATURE_KEYS = []
BUCKET_FEATURE_KEYS = ['pickup_latitude']
# Number of buckets used by tf.transform for encoding each feature. The length
# of this list should be the same with BUCKET_FEATURE_KEYS.
BUCKET_FEATURE_BUCKET_COUNT = []
BUCKET_FEATURE_BUCKET_COUNT = [10]

# Name of features which have categorical values which are mapped to integers.
# These features will be used as categorical features.
Expand Down
7 changes: 4 additions & 3 deletions tfx/experimental/templates/taxi/models/keras/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,9 +106,10 @@ def _build_keras_model(hidden_units, learning_rate):
categorical_columns += [
tf.feature_column.categorical_column_with_identity( # pylint: disable=g-complex-comprehension
key,
num_buckets=features.BUCKET_FEATURE_BUCKET_COUNT,
default_value=0)
for key in features.transformed_names(features.BUCKET_FEATURE_KEYS)
num_buckets=num_buckets,
default_value=0) for key, num_buckets in zip(
features.transformed_names(features.BUCKET_FEATURE_KEYS),
features.BUCKET_FEATURE_BUCKET_COUNT)
]
categorical_columns += [
tf.feature_column.categorical_column_with_identity( # pylint: disable=g-complex-comprehension
Expand Down
4 changes: 2 additions & 2 deletions tfx/experimental/templates/taxi/models/keras/model_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,10 +27,10 @@ class ModelTest(tf.test.TestCase):
def testBuildKerasModel(self):
built_model = model._build_keras_model(
hidden_units=[1, 1], learning_rate=0.1) # pylint: disable=protected-access
self.assertEqual(len(built_model.layers), 8)
self.assertEqual(len(built_model.layers), 9)

built_model = model._build_keras_model(hidden_units=[1], learning_rate=0.1) # pylint: disable=protected-access
self.assertEqual(len(built_model.layers), 7)
self.assertEqual(len(built_model.layers), 8)


if __name__ == '__main__':
Expand Down

0 comments on commit 62ba786

Please sign in to comment.