Skip to content

Commit

Permalink
making int64 cast to float64
Browse files Browse the repository at this point in the history
  • Loading branch information
rupei committed Aug 13, 2020
1 parent 1b0ca2f commit 5e9e66f
Show file tree
Hide file tree
Showing 2 changed files with 21 additions and 11 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ def _get_sample(current_state, kernel_results):
class ExpectationsReducer(reducer_base.Reducer):
"""`Reducer` that computes a running expectation.
`ExpectationsReducer` calculates expectation over some arbitrary structure
`ExpectationsReducer` calculates expectations over some arbitrary structure
of `transform_fn`s. A `transform_fn` is a function that accepts a Markov
chain sample and kernel results, and outputs the relevant value for
expectation calculation. In other words, if we denote a `transform_fn`
Expand Down
30 changes: 20 additions & 10 deletions tensorflow_probability/python/experimental/stats/sample_stats.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,8 +144,10 @@ def __init__(self, shape, event_ndims=None, dtype=tf.float32):
the inner-most dimensions. Specifying `None` returns all cross
product terms (no batching) and is the default.
dtype: Dtype of incoming samples and the resulting statistics.
By default, the dtype is `tf.float32`. Any integer dtypes will also
be treated as `tf.float32` (to not lose significant precision).
By default, the dtype is `tf.float32`. Any integer dtypes will be
cast to corresponding floats (i.e. `tf.int32` will be cast to
`tf.float32`), as intermediate calculations should be performing
floating-point division.
Raises:
ValueError: if `event_ndims` is greater than the rank of the intended
Expand All @@ -160,7 +162,9 @@ def __init__(self, shape, event_ndims=None, dtype=tf.float32):
raise ValueError('`event_ndims` over 13 not supported')
self.shape = shape
self.event_ndims = event_ndims
if dtype.is_integer:
if dtype is tf.int64:
dtype = tf.float64
elif dtype.is_integer:
dtype = tf.float32
self.dtype = dtype

Expand Down Expand Up @@ -271,8 +275,10 @@ def __init__(self, shape=(), dtype=tf.float32):
shape: Python `Tuple` or `TensorShape` representing the shape of
incoming samples. By default, the shape is assumed to be scalar.
dtype: Dtype of incoming samples and the resulting statistics.
By default, the dtype is `tf.float32`. Any integer dtypes will also
be treated as `tf.float32` (to not lose significant precision).
By default, the dtype is `tf.float32`. Any integer dtypes will be
cast to corresponding floats (i.e. `tf.int32` will be cast to
`tf.float32`), as intermediate calculations should be performing
floating-point division.
"""
super(RunningVariance, self).__init__(shape, event_ndims=0, dtype=dtype)

Expand All @@ -287,8 +293,8 @@ class RunningMean(object):
In computation, samples can be provided individually or in chunks. A
"chunk" of size M implies incorporating M samples into a single expectation
computation at once, which is more efficient than one by one. If more than one
callable is accepted and chunking is enabled, the chunked `axis` will define
chunking semantics for all callables.
sample is accepted and chunking is enabled, the chunked `axis` will define
chunking semantics for all samples.
`RunningMean` objects do not hold state information. That information,
which includes intermediate calculations, are held in a
Expand All @@ -307,11 +313,15 @@ def __init__(self, shape, dtype=tf.float32):
shape: Python `Tuple` or `TensorShape` representing the shape of
incoming samples.
dtype: Dtype of incoming samples and the resulting statistics.
By default, the dtype is `tf.float32`. Any integer dtypes will also
be treated as `tf.float32` (to not lose significant precision).
By default, the dtype is `tf.float32`. Any integer dtypes will be
cast to corresponding floats (i.e. `tf.int32` will be cast to
`tf.float32`), as intermediate calculations should be performing
floating-point division.
"""
self.shape = shape
if dtype.is_integer:
if dtype is tf.int64:
dtype = tf.float64
elif dtype.is_integer:
dtype = tf.float32
self.dtype = dtype

Expand Down

0 comments on commit 5e9e66f

Please sign in to comment.