Skip to content

Commit

Permalink
Use float64 instead of float32 in tfd.Beta and tfd.Binomial tests to …
Browse files Browse the repository at this point in the history
…work around a breakage in OSS following scipy 1.7.2 release. Unpin the scipy version.

PiperOrigin-RevId: 408643791
  • Loading branch information
emilyfertig authored and tensorflower-gardener committed Nov 9, 2021
1 parent ed5f42e commit cd35db9
Show file tree
Hide file tree
Showing 3 changed files with 8 additions and 5 deletions.
3 changes: 2 additions & 1 deletion tensorflow_probability/python/distributions/beta_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -282,9 +282,10 @@ def testBetaSampleMultidimensional(self):
sample_values = self.evaluate(samples)
self.assertEqual(sample_values.shape, (100000, 3, 2, 2))
self.assertFalse(np.any(sample_values < 0.0))
# Pass f64 values to avoid errors in scipy.
self.assertAllClose(
sample_values[:, 1, :].mean(axis=0),
sp_stats.beta.mean(a, b)[1, :],
sp_stats.beta.mean(a.astype(np.float64), b.astype(np.float64))[1, :],
atol=1e-1)

@parameterized.parameters((np.float32, 5e-3), (np.float64, 1e-4))
Expand Down
7 changes: 5 additions & 2 deletions tensorflow_probability/python/distributions/binomial_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -282,11 +282,14 @@ def testSampleUnbiasedVectorBatch(self):
sample_variance,
])
self.assertAllEqual([4], sample_mean.shape)
# Pass f64 values to avoid errors in scipy.
self.assertAllClose(
stats.binom.mean(counts, probs), sample_mean_, atol=0., rtol=0.10)
stats.binom.mean(counts.astype(np.float64), probs.astype(np.float64)),
sample_mean_, atol=0., rtol=0.10)
self.assertAllEqual([4], sample_variance.shape)
self.assertAllClose(
stats.binom.var(counts, probs), sample_variance_, atol=0., rtol=0.20)
stats.binom.var(counts.astype(np.float64), probs.astype(np.float64)),
sample_variance_, atol=0., rtol=0.20)

def testSampleExtremeValues(self):
total_count = tf.constant(17., dtype=tf.float32)
Expand Down
3 changes: 1 addition & 2 deletions testing/dependency_install_lib.sh
Original file line number Diff line number Diff line change
Expand Up @@ -80,8 +80,7 @@ install_common_packages() {
install_test_only_packages() {
# The following unofficial dependencies are used only by tests.
PIP_FLAGS=${1-}
# TODO(b/205627112): Unpin scipy version.
python -m pip install $PIP_FLAGS hypothesis matplotlib mock mpmath scipy==1.7.1 pandas
python -m pip install $PIP_FLAGS hypothesis matplotlib mock mpmath scipy pandas
}

dump_versions() {
Expand Down

0 comments on commit cd35db9

Please sign in to comment.