Skip to content
This repository has been archived by the owner on Jan 23, 2022. It is now read-only.

Commit

Permalink
Speed up Dirichlet multinomial test by reducing total work, and
Browse files Browse the repository at this point in the history
folding the remainder into the batch dimension.

PiperOrigin-RevId: 305078865
  • Loading branch information
axch authored and tensorflower-gardener committed Apr 6, 2020
1 parent fff9f3c commit 88bddf1
Showing 1 changed file with 3 additions and 5 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -474,15 +474,13 @@ def testSamplesHaveCorrectTotalCounts(self):
seed_stream = test_util.test_seed_stream()
concentration = 1. + 2. * tf.random.uniform(
shape=[4], dtype=np.float32, seed=seed_stream())
total_count = tf.constant(list(range(1000)), dtype=np.float32)
total_count = tf.constant(list(range(int(1e4))), dtype=np.float32)
dist = tfd.DirichletMultinomial(
total_count=total_count,
concentration=concentration,
validate_args=True)
n = int(1e2)
x = dist.sample(n, seed=seed_stream())
for i in range(n):
self.assertAllEqual(tf.reduce_sum(x[i, ...], axis=-1), total_count)
x = dist.sample(seed=seed_stream())
self.assertAllEqual(tf.reduce_sum(x, axis=-1), total_count)


@test_util.test_all_tf_execution_regimes
Expand Down

0 comments on commit 88bddf1

Please sign in to comment.