Skip to content

Commit

Permalink
Fix broadcasting of event and samples in Empirical.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 238975493
  • Loading branch information
brianwa84 authored and tensorflower-gardener committed Mar 18, 2019
1 parent bde41b6 commit c3f7271
Show file tree
Hide file tree
Showing 2 changed files with 22 additions and 3 deletions.
13 changes: 10 additions & 3 deletions tensorflow_probability/python/distributions/empirical.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,9 +31,16 @@

def _broadcast_event_and_samples(event, samples, event_ndims):
"""Broadcasts the event or samples."""
samples_shape = tf.shape(input=samples)[:-event_ndims - 1]
# This is the shape of self.samples, without the samples axis, i.e. the shape
# of the result of a call to dist.sample(). This way we can broadcast it with
# event to get a properly-sized event, then add the singleton dim back at
# -event_ndims - 1.
samples_shape = tf.concat(
[tf.shape(input=samples)[:-event_ndims - 1],
tf.shape(input=samples)[tf.rank(samples) - event_ndims:]],
axis=0)
event *= tf.ones(samples_shape, dtype=event.dtype)
event = tf.expand_dims(event, axis=-event_ndims-1)
event = tf.expand_dims(event, axis=-event_ndims - 1)
samples *= tf.ones_like(event, dtype=samples.dtype)

return event, samples
Expand Down Expand Up @@ -259,7 +266,7 @@ def _cdf(self, event):
cdf = tf.reduce_sum(
input_tensor=tf.cast(
tf.reduce_all(
input_tensor=samples - event <= 0,
input_tensor=samples <= event,
axis=tf.range(-self._event_ndims, 0)),
dtype=tf.int32),
axis=-1) / self.num_samples
Expand Down
12 changes: 12 additions & 0 deletions tensorflow_probability/python/distributions/empirical_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -650,6 +650,18 @@ def testVarianceAndStd(self):
self.assertAllClose(self.evaluate(dist.stddev()),
np.sqrt(expected_variance))

def testLogProbAfterSlice(self):
samples = np.random.randn(6, 5, 4)
dist = empirical.Empirical(samples=samples, event_ndims=1)
self.assertAllEqual((6,), dist.batch_shape)
self.assertAllEqual((4,), dist.event_shape)
sliced_dist = dist[:, tf.newaxis]
samples = self.evaluate(dist.sample())
self.assertAllEqual((6, 4), samples.shape)
lp, sliced_lp = self.evaluate([
dist.log_prob(samples), sliced_dist.log_prob(samples[:, tf.newaxis])])
self.assertAllEqual(lp[:, tf.newaxis], sliced_lp)


class EmpiricalScalarStaticShapeTest(
EmpiricalScalarTest, tf.test.TestCase):
Expand Down

0 comments on commit c3f7271

Please sign in to comment.