|
| 1 | +# pylint: disable=g-bad-file-header |
| 2 | +# Copyright 2016 The TensorFlow Authors. All Rights Reserved. |
| 3 | +# |
| 4 | +# Licensed under the Apache License, Version 2.0 (the "License"); |
| 5 | +# you may not use this file except in compliance with the License. |
| 6 | +# You may obtain a copy of the License at |
| 7 | +# |
| 8 | +# http://www.apache.org/licenses/LICENSE-2.0 |
| 9 | +# |
| 10 | +# Unless required by applicable law or agreed to in writing, software |
| 11 | +# distributed under the License is distributed on an "AS IS" BASIS, |
| 12 | +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| 13 | +# See the License for the specific language governing permissions and |
| 14 | +# limitations under the License. |
| 15 | +# ============================================================================== |
| 16 | +"""Utility functions relating DataFrames to Estimators.""" |
| 17 | + |
| 18 | +from __future__ import absolute_import |
| 19 | +from __future__ import division |
| 20 | +from __future__ import print_function |
| 21 | + |
| 22 | +from tensorflow.contrib.layers import feature_column |
| 23 | +from tensorflow.contrib.learn.python.learn.dataframe import series as ss |
| 24 | +from tensorflow.python.framework import ops |
| 25 | +from tensorflow.python.ops import parsing_ops |
| 26 | + |
| 27 | + |
| 28 | +def _to_feature_spec(tensor, default_value=None): |
| 29 | + if isinstance(tensor, ops.SparseTensor): |
| 30 | + return parsing_ops.VarLenFeature(dtype=tensor.dtype) |
| 31 | + else: |
| 32 | + return parsing_ops.FixedLenFeature(shape=tensor.get_shape(), |
| 33 | + dtype=tensor.dtype, |
| 34 | + default_value=default_value) |
| 35 | + |
| 36 | + |
| 37 | +def _infer_feature_specs(dataframe, keys_with_defaults): |
| 38 | + with ops.Graph().as_default(): |
| 39 | + tensors = dataframe.build() |
| 40 | + feature_specs = { |
| 41 | + name: _to_feature_spec(tensor, keys_with_defaults.get(name)) |
| 42 | + for name, tensor in tensors.items()} |
| 43 | + return feature_specs |
| 44 | + |
| 45 | + |
| 46 | +def _build_alternate_universe( |
| 47 | + dataframe, base_input_keys_with_defaults, feature_keys): |
| 48 | + """Create an alternate universe assuming that the base series are defined. |
| 49 | +
|
| 50 | + The resulting graph will be used with an `input_fn` that provides exactly |
| 51 | + those features. |
| 52 | +
|
| 53 | + Args: |
| 54 | + dataframe: the underlying `DataFrame` |
| 55 | + base_input_keys_with_defaults: a `dict` from the names of columns to |
| 56 | + considered base features to their default values. |
| 57 | + feature_keys: the names of columns to be used as features (including base |
| 58 | + features and derived features). |
| 59 | +
|
| 60 | + Returns: |
| 61 | + A `dict` mapping names to rebuilt `Series`. |
| 62 | + """ |
| 63 | + feature_specs = _infer_feature_specs(dataframe, base_input_keys_with_defaults) |
| 64 | + |
| 65 | + alternate_universe_map = { |
| 66 | + dataframe[name]: ss.PredefinedSeries(name, feature_specs[name]) |
| 67 | + for name in base_input_keys_with_defaults.keys() |
| 68 | + } |
| 69 | + |
| 70 | + def _in_alternate_universe(orig_series): |
| 71 | + # pylint: disable=protected-access |
| 72 | + # Map Series in the original DataFrame to series rebuilt assuming base_keys. |
| 73 | + try: |
| 74 | + return alternate_universe_map[orig_series] |
| 75 | + except KeyError: |
| 76 | + rebuilt_inputs = [] |
| 77 | + for i in orig_series._input_series: |
| 78 | + rebuilt_inputs.append(_in_alternate_universe(i)) |
| 79 | + rebuilt_series = ss.TransformedSeries(rebuilt_inputs, |
| 80 | + orig_series._transform, |
| 81 | + orig_series._output_name) |
| 82 | + alternate_universe_map[orig_series] = rebuilt_series |
| 83 | + return rebuilt_series |
| 84 | + |
| 85 | + orig_feature_series_dict = {fk: dataframe[fk] for fk in feature_keys} |
| 86 | + new_feature_series_dict = ({name: _in_alternate_universe(x) |
| 87 | + for name, x in orig_feature_series_dict.items()}) |
| 88 | + return new_feature_series_dict, feature_specs |
| 89 | + |
| 90 | + |
| 91 | +def to_feature_columns_and_input_fn(dataframe, |
| 92 | + base_input_keys_with_defaults, |
| 93 | + feature_keys, |
| 94 | + target_keys=None): |
| 95 | + """Build a list of FeatureColumns and an input_fn for use with Estimator. |
| 96 | +
|
| 97 | + Args: |
| 98 | + dataframe: the underlying dataframe |
| 99 | + base_input_keys_with_defaults: a dict from the names of columns to be |
| 100 | + considered base features to their default values. These columns will be |
| 101 | + fed via input_fn. |
| 102 | + feature_keys: the names of columns from which to generate FeatureColumns. |
| 103 | + These may include base features and/or derived features. |
| 104 | + target_keys: the names of columns to be used as targets. None is |
| 105 | + acceptable for unsupervised learning. |
| 106 | +
|
| 107 | + Returns: |
| 108 | + A tuple of two elements: |
| 109 | + * A list of `FeatureColumn`s to be used when constructing an Estimator |
| 110 | + * An input_fn, i.e. a function that returns a pair of dicts |
| 111 | + (features, targets), each mapping string names to Tensors. |
| 112 | + the feature dict provides mappings for all the base columns required |
| 113 | + by the FeatureColumns. |
| 114 | +
|
| 115 | + Raises: |
| 116 | + ValueError: when the feature and target key sets are non-disjoint, or the |
| 117 | + base_input and target sets are non-disjoint. |
| 118 | + """ |
| 119 | + if feature_keys is None or not feature_keys: |
| 120 | + raise ValueError("feature_keys must be specified.") |
| 121 | + |
| 122 | + if target_keys is None: |
| 123 | + target_keys = [] |
| 124 | + |
| 125 | + base_input_keys = base_input_keys_with_defaults.keys() |
| 126 | + |
| 127 | + in_two = (set(feature_keys) & set(target_keys)) or (set(base_input_keys) & |
| 128 | + set(target_keys)) |
| 129 | + if in_two: |
| 130 | + raise ValueError("Columns cannot be used for both features and targets: %s" |
| 131 | + % ", ".join(in_two)) |
| 132 | + |
| 133 | + # Obtain the feature series in the alternate universe |
| 134 | + new_feature_series_dict, feature_specs = _build_alternate_universe( |
| 135 | + dataframe, base_input_keys_with_defaults, feature_keys) |
| 136 | + |
| 137 | + # TODO(soergel): Allow non-real, non-dense DataFrameColumns |
| 138 | + for key in new_feature_series_dict.keys(): |
| 139 | + spec = feature_specs[key] |
| 140 | + if not ( |
| 141 | + isinstance(spec, parsing_ops.FixedLenFeature) |
| 142 | + and (spec.dtype.is_integer or spec.dtype.is_floating)): |
| 143 | + raise ValueError("For now, only real dense columns can be passed from " |
| 144 | + "DataFrame to Estimator. %s is %s of %s" % ( |
| 145 | + (key, type(spec).__name__, spec.dtype))) |
| 146 | + |
| 147 | + # Make FeatureColumns from these |
| 148 | + feature_columns = [feature_column.DataFrameColumn(name, s) |
| 149 | + for name, s in new_feature_series_dict.items()] |
| 150 | + |
| 151 | + # Make a new DataFrame with only the Series needed for input_fn. |
| 152 | + # This is important to avoid starting queue feeders that won't be used. |
| 153 | + limited_dataframe = dataframe.select_columns( |
| 154 | + list(base_input_keys) + list(target_keys)) |
| 155 | + |
| 156 | + # Build an input_fn suitable for use with Estimator. |
| 157 | + def input_fn(): |
| 158 | + # It's important to build all the tensors together in one DataFrame. |
| 159 | + # If we did df.select() for both key sets and then build those, the two |
| 160 | + # resulting DataFrames would be shuffled independently. |
| 161 | + tensors = limited_dataframe.build() |
| 162 | + |
| 163 | + base_input_features = {key: tensors[key] for key in base_input_keys} |
| 164 | + targets = {key: tensors[key] for key in target_keys} |
| 165 | + |
| 166 | + # TODO(soergel): Remove this special case when b/30367437 is fixed. |
| 167 | + if len(targets) == 1: |
| 168 | + targets = list(targets.values())[0] |
| 169 | + |
| 170 | + return base_input_features, targets |
| 171 | + |
| 172 | + return feature_columns, input_fn |
0 commit comments