Skip to content

Commit

Permalink
Make sure low and high have matching dtype. (So that `tfb.Sigmoid(low…
Browse files Browse the repository at this point in the history
…=0., high=5)` works.)

PiperOrigin-RevId: 294926176
  • Loading branch information
brianwa84 authored and tensorflower-gardener committed Feb 13, 2020
1 parent 8c019ce commit 6e97957
Showing 1 changed file with 4 additions and 2 deletions.
6 changes: 4 additions & 2 deletions tensorflow_probability/python/bijectors/sigmoid.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
import tensorflow.compat.v2 as tf
from tensorflow_probability.python.bijectors import bijector
from tensorflow_probability.python.internal import assert_util
from tensorflow_probability.python.internal import dtype_util
from tensorflow_probability.python.internal import tensor_util


Expand Down Expand Up @@ -74,8 +75,9 @@ def __init__(self, low=None, high=None, validate_args=False, name='sigmoid'):
'Either both or neither of `low` and `high` must be passed. '
'Received `low={}`, `high={}`'.format(low, high))

self._low = tensor_util.convert_nonref_to_tensor(low)
self._high = tensor_util.convert_nonref_to_tensor(high)
dtype = dtype_util.common_dtype([low, high], dtype_hint=tf.float32)
self._low = tensor_util.convert_nonref_to_tensor(low, dtype=dtype)
self._high = tensor_util.convert_nonref_to_tensor(high, dtype=dtype)
super(Sigmoid, self).__init__(
forward_min_event_ndims=0,
validate_args=validate_args,
Expand Down

0 comments on commit 6e97957

Please sign in to comment.