Skip to content

Commit

Permalink
Use common_dtype to ensure we get a base dtype rather than a float32_…
Browse files Browse the repository at this point in the history
…ref.

PiperOrigin-RevId: 272673563
  • Loading branch information
brianwa84 authored and tensorflower-gardener committed Oct 3, 2019
1 parent 9f0bea7 commit ddf838d
Showing 1 changed file with 4 additions and 4 deletions.
8 changes: 4 additions & 4 deletions tensorflow_probability/python/distributions/geometric.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,13 +86,13 @@ def __init__(self,
if (probs is None) == (logits is None):
raise ValueError('Must pass probs or logits, but not both.')
with tf.name_scope(name) as name:
dtype = dtype_util.common_dtype([logits, probs], dtype_hint=tf.float32)
self._probs = tensor_util.convert_nonref_to_tensor(
probs, dtype_hint=tf.float32, name='probs')
probs, dtype=dtype, name='probs')
self._logits = tensor_util.convert_nonref_to_tensor(
logits, dtype_hint=tf.float32, name='logits')
logits, dtype=dtype, name='logits')
super(Geometric, self).__init__(
dtype=(self._logits.dtype if self._probs is None
else self._probs.dtype),
dtype=dtype,
reparameterization_type=reparameterization.NOT_REPARAMETERIZED,
validate_args=validate_args,
allow_nan_stats=allow_nan_stats,
Expand Down

0 comments on commit ddf838d

Please sign in to comment.