Skip to content

Commit

Permalink
Fixes for bare-column and column generators as inputs
Browse files Browse the repository at this point in the history
1) Passing a single feature column (not wrapped in a list) should succeed, or fail clearly (In this change I wrap it in a list so it succeeds)

Currently, since concrete `_FeatureColumn` classes are iterable (inherit from named_tuple), the code attempts to interpret the _FeatureColumn object as a list of feature columns.

This can give confusing errors:

    fc.input_layer({'a':[1]},fc.numeric_column('a'))

    ValueError: Items of feature_columns must be a _FeatureColumn. Given (type <class 'str'>): a.

    fc.input_layer(
        {'a':[1]},
        fc.indicator_column(
            fc.categorical_column_with_identity('a', 100)))

    ValueError: Items of feature_columns must be a _DenseColumn. You can wrap a categorical column with an embedding_column or indicator_column. Given: _IdentityCategoricalColumn(key='a', num_buckets=100, default_value=None)

2) Passing an `iterator` should fail, or convert it to a list (so it can be iterated multiple times).

In this change I convert it to a list

Currently it throws an unclear error:

    features = {'a':[1],'b':[2]}
    columns = (fc.numeric_column(key) for key in features)
    fc.input_layer(features,columns)

    ValueError: List argument 'values' to 'ConcatV2' Op with length 0 shorter than minimum length 2.

PiperOrigin-RevId: 169314383
  • Loading branch information
MarkDaoust authored and tensorflower-gardener committed Sep 19, 2017
1 parent d17e073 commit 76293c2
Show file tree
Hide file tree
Showing 2 changed files with 29 additions and 5 deletions.
19 changes: 14 additions & 5 deletions tensorflow/python/feature_column/feature_column.py
Original file line number Diff line number Diff line change
Expand Up @@ -204,7 +204,7 @@ def input_layer(features,
Raises:
ValueError: if an item in `feature_columns` is not a `_DenseColumn`.
"""
_check_feature_columns(feature_columns)
feature_columns = _clean_feature_columns(feature_columns)
for column in feature_columns:
if not isinstance(column, _DenseColumn):
raise ValueError(
Expand Down Expand Up @@ -294,7 +294,7 @@ def linear_model(features,
ValueError: if an item in `feature_columns` is neither a `_DenseColumn`
nor `_CategoricalColumn`.
"""
_check_feature_columns(feature_columns)
feature_columns = _clean_feature_columns(feature_columns)
for column in feature_columns:
if not isinstance(column, (_DenseColumn, _CategoricalColumn)):
raise ValueError('Items of feature_columns must be either a _DenseColumn '
Expand Down Expand Up @@ -367,7 +367,7 @@ def _transform_features(features, feature_columns):
Returns:
A `dict` mapping `_FeatureColumn` to `Tensor` and `SparseTensor` values.
"""
_check_feature_columns(feature_columns)
feature_columns = _clean_feature_columns(feature_columns)
outputs = {}
with ops.name_scope(
None, default_name='transform_features', values=features.values()):
Expand Down Expand Up @@ -1647,10 +1647,17 @@ def _to_sparse_input(input_tensor, ignore_value=None):
return sparse_tensor_lib.SparseTensor(indices, values, dense_shape)


def _check_feature_columns(feature_columns):
"""Verifies feature_columns input."""
def _clean_feature_columns(feature_columns):
"""Verifies and normalizes `feature_columns` input."""
if isinstance(feature_columns, _FeatureColumn):
feature_columns = [feature_columns]

if isinstance(feature_columns, collections.Iterator):
feature_columns = list(feature_columns)

if isinstance(feature_columns, dict):
raise ValueError('Expected feature_columns to be iterable, found dict.')

for column in feature_columns:
if not isinstance(column, _FeatureColumn):
raise ValueError('Items of feature_columns must be a _FeatureColumn. '
Expand All @@ -1668,6 +1675,8 @@ def _check_feature_columns(feature_columns):
name_to_column[column.name]))
name_to_column[column.name] = column

return feature_columns


class _NumericColumn(_DenseColumn,
collections.namedtuple('_NumericColumn', [
Expand Down
15 changes: 15 additions & 0 deletions tensorflow/python/feature_column/feature_column_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -1654,6 +1654,21 @@ def test_does_not_support_dict_columns(self):
fc.input_layer(
features={'a': [[0]]}, feature_columns={'a': fc.numeric_column('a')})

def test_bare_column(self):
with ops.Graph().as_default():
features = features = {'a': [0.]}
net = fc.input_layer(features, fc.numeric_column('a'))
with _initialized_session():
self.assertAllClose([[0.]], net.eval())

def test_column_generator(self):
with ops.Graph().as_default():
features = features = {'a': [0.], 'b': [1.]}
columns = (fc.numeric_column(key) for key in features)
net = fc.input_layer(features, columns)
with _initialized_session():
self.assertAllClose([[0., 1.]], net.eval())

def test_raises_if_duplicate_name(self):
with self.assertRaisesRegexp(
ValueError, 'Duplicate feature column name found for columns'):
Expand Down

0 comments on commit 76293c2

Please sign in to comment.