Skip to content

Commit

Permalink
Cast x to logits.dtype in onehot_categorical.
Browse files Browse the repository at this point in the history
Otherwise, dist.log_prob(dist.sample()) has an error in reduce_logsumexp inside _assert_valid_sample when validate_args=True (TypeError: Value passed to parameter 'x' has DataType int32 not in list of allowed values: bfloat16, float16, float32, float64).

PiperOrigin-RevId: 239992007
  • Loading branch information
brianwa84 authored and tensorflower-gardener committed Mar 24, 2019
1 parent 17dc407 commit 50c8b9e
Showing 1 changed file with 2 additions and 1 deletion.
Original file line number Diff line number Diff line change
Expand Up @@ -181,6 +181,7 @@ def _sample_n(self, n, seed=None):
return ret

def _log_prob(self, x):
x = tf.cast(x, self.logits.dtype)
x = self._assert_valid_sample(x)
# broadcast logits or x if need be.
logits = self.logits
Expand Down Expand Up @@ -227,7 +228,7 @@ def _assert_valid_sample(self, x):
return distribution_util.with_dependencies([
tf.compat.v1.assert_non_positive(x),
tf.compat.v1.assert_near(
tf.zeros([], dtype=self.dtype),
tf.zeros([], dtype=self.logits.dtype),
tf.reduce_logsumexp(input_tensor=x, axis=[-1])),
], x)

Expand Down

0 comments on commit 50c8b9e

Please sign in to comment.