diff --git a/tensorflow_probability/python/distributions/dirichlet_multinomial_test.py b/tensorflow_probability/python/distributions/dirichlet_multinomial_test.py index d2fb47a977..18227e5852 100644 --- a/tensorflow_probability/python/distributions/dirichlet_multinomial_test.py +++ b/tensorflow_probability/python/distributions/dirichlet_multinomial_test.py @@ -474,15 +474,13 @@ def testSamplesHaveCorrectTotalCounts(self): seed_stream = test_util.test_seed_stream() concentration = 1. + 2. * tf.random.uniform( shape=[4], dtype=np.float32, seed=seed_stream()) - total_count = tf.constant(list(range(1000)), dtype=np.float32) + total_count = tf.constant(list(range(int(1e4))), dtype=np.float32) dist = tfd.DirichletMultinomial( total_count=total_count, concentration=concentration, validate_args=True) - n = int(1e2) - x = dist.sample(n, seed=seed_stream()) - for i in range(n): - self.assertAllEqual(tf.reduce_sum(x[i, ...], axis=-1), total_count) + x = dist.sample(seed=seed_stream()) + self.assertAllEqual(tf.reduce_sum(x, axis=-1), total_count) @test_util.test_all_tf_execution_regimes