Skip to content

Commit

Permalink
Stop doing an in-place update in sample_stats.py. This has issues wit…
Browse files Browse the repository at this point in the history
…h the

numpy backend.

Also use self.assertAllClose for some array comparisons (rahter than assertAllEqual). The numpy backend doesn't get exact results.

PiperOrigin-RevId: 335923604
  • Loading branch information
langmore authored and tensorflower-gardener committed Oct 7, 2020
1 parent 283f014 commit ea5351d
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 8 deletions.
6 changes: 3 additions & 3 deletions tensorflow_probability/python/stats/sample_stats.py
Original file line number Diff line number Diff line change
Expand Up @@ -351,7 +351,7 @@ def covariance(x,
with tf.name_scope(name or 'covariance'):
x = tf.convert_to_tensor(x, name='x')
# Covariance *only* uses the centered versions of x (and y).
x -= tf.reduce_mean(x, axis=sample_axis, keepdims=True)
x = x - tf.reduce_mean(x, axis=sample_axis, keepdims=True)

if y is None:
y = x
Expand All @@ -360,7 +360,7 @@ def covariance(x,
# If x and y have different shape, sample_axis and event_axis will likely
# be wrong for one of them!
tensorshape_util.assert_is_compatible_with(x.shape, y.shape)
y -= tf.reduce_mean(y, axis=sample_axis, keepdims=True)
y = y - tf.reduce_mean(y, axis=sample_axis, keepdims=True)

if event_axis is None:
return tf.reduce_mean(
Expand Down Expand Up @@ -821,6 +821,6 @@ def _squeeze(x, axis):
if axis is None:
return tf.squeeze(x, axis=None)
axis = ps.convert_to_shape_tensor(axis, name='axis', dtype=tf.int32)
axis += ps.zeros([1], dtype=axis.dtype) # Make axis at least 1d.
axis = axis + ps.zeros([1], dtype=axis.dtype) # Make axis at least 1d.
keep_axis = ps.setdiff1d(ps.range(0, ps.rank(x)), axis)
return tf.reshape(x, ps.gather(ps.shape(x), keep_axis))
10 changes: 5 additions & 5 deletions tensorflow_probability/python/stats/sample_stats_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -299,7 +299,7 @@ def test_batch_vector_sampaxis0_eventaxisn1(self):
cov_kd = tfp.stats.covariance(x, y, event_axis=-1, keepdims=True)
self.assertAllEqual((1, 3, 2, 2), cov_kd.shape)
cov_kd = self.evaluate(cov_kd)
self.assertAllEqual(cov, cov_kd[0, ...])
self.assertAllClose(cov, cov_kd[0, ...])

for i in range(3): # Iterate over batch index.
x_i = x[:, i, :] # Pick out ith batch of samples.
Expand All @@ -325,7 +325,7 @@ def test_batch_vector_sampaxis13_eventaxis2(self):
x, y, sample_axis=[1, 3], event_axis=[2], keepdims=True)
self.assertAllEqual((4, 1, 2, 2, 1), cov_kd.shape)
cov_kd = self.evaluate(cov_kd)
self.assertAllEqual(cov, cov_kd[:, 0, :, :, 0])
self.assertAllClose(cov, cov_kd[:, 0, :, :, 0])

for i in range(4): # Iterate over batch index.
# Get ith batch of samples, and permute/reshape to [n_samples, n_events]
Expand All @@ -352,7 +352,7 @@ def test_batch_vector_sampaxis02_eventaxis1(self):
x, y, sample_axis=[0, 2], event_axis=[1], keepdims=True)
self.assertAllEqual((1, 3, 3, 1, 5), cov_kd.shape)
cov_kd = self.evaluate(cov_kd)
self.assertAllEqual(cov, cov_kd[0, :, :, 0, :])
self.assertAllClose(cov, cov_kd[0, :, :, 0, :])

for i in range(5): # Iterate over batch index.
# Get ith batch of samples, and permute/reshape to [n_samples, n_events]
Expand Down Expand Up @@ -383,7 +383,7 @@ def test_batch_vector_sampaxis03_eventaxis12_dynamic(self):
x_ph, y_ph, sample_axis=[0, 3], event_axis=[1, 2], keepdims=True)
cov_kd = self.evaluate(cov_kd)
self.assertAllEqual((1, 3, 4, 3, 4, 1, 6), cov_kd.shape)
self.assertAllEqual(cov, cov_kd[0, :, :, :, :, 0, :])
self.assertAllClose(cov, cov_kd[0, :, :, :, :, 0, :])

for i in range(6): # Iterate over batch index.
# Get ith batch of samples, and permute/reshape to [n_samples, n_events]
Expand Down Expand Up @@ -476,7 +476,7 @@ def test_batch_vector_sampaxis0_eventaxisn1(self):
corr_kd = tfp.stats.correlation(x, y, event_axis=-1, keepdims=True)
self.assertAllEqual((1, 3, 2, 2), corr_kd.shape)
corr_kd = self.evaluate(corr_kd)
self.assertAllEqual(corr, corr_kd[0, ...])
self.assertAllClose(corr, corr_kd[0, ...])

for i in range(3): # Iterate over batch index.
x_i = x[:, i, :] # Pick out ith batch of samples.
Expand Down

0 comments on commit ea5351d

Please sign in to comment.