Skip to content

Commit

Permalink
Fix squeeze_dims warnings in tf.contrib.learn
Browse files Browse the repository at this point in the history
Signed-off-by: Yong Tang <[email protected]>
  • Loading branch information
yongtang committed Apr 22, 2018
1 parent 8257b90 commit 685fec3
Show file tree
Hide file tree
Showing 2 changed files with 3 additions and 3 deletions.
4 changes: 2 additions & 2 deletions tensorflow/contrib/learn/python/learn/estimators/head.py
Original file line number Diff line number Diff line change
Expand Up @@ -777,7 +777,7 @@ def _logits_to_predictions(self, logits):
key = prediction_key.PredictionKey.SCORES
with ops.name_scope(None, "predictions", (logits,)):
if self.logits_dimension == 1:
logits = array_ops.squeeze(logits, squeeze_dims=(1,), name=key)
logits = array_ops.squeeze(logits, axis=(1,), name=key)
return {key: self._link_fn(logits)}

def _metrics(self, eval_loss, predictions, labels, weights):
Expand Down Expand Up @@ -974,7 +974,7 @@ def _softmax_cross_entropy_loss(labels, logits, weights=None):
is_squeezed_labels = False
# TODO(ptucker): This will break for dynamic shapes.
if len(labels.get_shape()) == 2:
labels = array_ops.squeeze(labels, squeeze_dims=(1,))
labels = array_ops.squeeze(labels, axis=(1,))
is_squeezed_labels = True

loss = nn.sparse_softmax_cross_entropy_with_logits(
Expand Down
2 changes: 1 addition & 1 deletion tensorflow/contrib/learn/python/learn/ops/losses_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ def mean_squared_error_regressor(tensor_in, labels, weights, biases, name=None):
[tensor_in, labels]):
predictions = nn.xw_plus_b(tensor_in, weights, biases)
if len(labels.get_shape()) == 1 and len(predictions.get_shape()) == 2:
predictions = array_ops_.squeeze(predictions, squeeze_dims=[1])
predictions = array_ops_.squeeze(predictions, axis=[1])
return predictions, losses.mean_squared_error(labels, predictions)


Expand Down

0 comments on commit 685fec3

Please sign in to comment.