Commit 62cf00a
Improve accuracy of binomial log_prob for large counts and small probabilities.

The largest accuracy gain comes from using the new lbeta.  There is
also a considerable additional gain from not converting input probs to
logits-space and back.

Because Binomial's sampler is a rejection sampler, whose accept/reject step
evaluates the target probability, this change affects sampling too.

PiperOrigin-RevId: 305297126
axch authored and tensorflower-gardener committed Apr 7, 2020
1 parent 46456f7 commit 62cf00a
Showing 3 changed files with 47 additions and 9 deletions.
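
Why the new lbeta is the big win: the old normalizer computed the log binomial
coefficient as a difference of three lgamma values, and for large total_count
the two huge terms nearly cancel, spending almost all floating-point digits on
the cancelled bulk. A minimal sketch (mine, not part of the commit) mirroring
the old and new bodies of `_log_normalization`:

    import tensorflow as tf
    import tensorflow_probability as tfp

    n = tf.constant(1e15, dtype=tf.float64)  # total_count
    k = tf.constant(2., dtype=tf.float64)    # counts

    # Old form: three lgammas.  lgamma(1 + n) and lgamma(1 + n - k) are each
    # ~ n * log(n) ~ 3.3e16, so their near-cancellation discards almost all
    # of float64's ~16 significant digits.
    naive = (tf.math.lgamma(1. + n - k) + tf.math.lgamma(1. + k)
             - tf.math.lgamma(1. + n))

    # New form from this commit: lbeta never forms the huge intermediates.
    stable = tfp.math.lbeta(1. + k, 1. + n - k) + tf.math.log(1. + n)

    # Both equal -log C(n, k); for k = 2 the exact value is
    # log(2) - log(n) - log(n - 1) ~ -68.39.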
2 changes: 1 addition & 1 deletion tensorflow_probability/python/distributions/BUILD
@@ -231,7 +231,7 @@ multi_substrate_py_library(
         "//tensorflow_probability/python/internal:samplers",
         "//tensorflow_probability/python/internal:tensor_util",
         "//tensorflow_probability/python/internal:tensorshape_util",
-        "//tensorflow_probability/python/math:random_ops",
+        "//tensorflow_probability/python/math",
     ],
 )

27 changes: 19 additions & 8 deletions tensorflow_probability/python/distributions/binomial.py
@@ -21,6 +21,7 @@
 import numpy as np
 import tensorflow.compat.v2 as tf
 
+from tensorflow_probability.python import math as tfp_math
 from tensorflow_probability.python.distributions import distribution
 from tensorflow_probability.python.distributions import exponential
 from tensorflow_probability.python.internal import assert_util
@@ -32,7 +33,6 @@
 from tensorflow_probability.python.internal import samplers
 from tensorflow_probability.python.internal import tensor_util
 from tensorflow_probability.python.internal import tensorshape_util
-from tensorflow_probability.python.math.random_ops import random_rademacher
 
 
 _binomial_sample_note = """
@@ -157,7 +157,7 @@ def proposal(seed):
 
     unsigned_offsets = tf.where(on_top_mask, top_offsets, exponential_offsets)
     offsets = tf.round(
-        random_rademacher(
+        tfp_math.random_rademacher(
             mode_shape, seed=random_rademacher_seed, dtype=dtype) *
         unsigned_offsets)

@@ -333,9 +333,11 @@ def _event_shape(self):
 
   @distribution_util.AppendDocstring(_binomial_sample_note)
   def _log_prob(self, counts):
-    logits = self._logits_parameter_no_checks()
     total_count = tf.convert_to_tensor(self.total_count)
-    unnorm = _log_unnormalized_prob(logits, counts, total_count)
+    if self._logits is not None:
+      unnorm = _log_unnormalized_prob_logits(self._logits, counts, total_count)
+    else:
+      unnorm = _log_unnormalized_prob_probs(self._probs, counts, total_count)
     norm = _log_normalization(counts, total_count)
     return unnorm - norm

@@ -445,16 +447,25 @@ def _sample_control_dependencies(self, counts):
     return assertions
 
 
-def _log_unnormalized_prob(logits, counts, total_count):
-  """Log unnormalized probability."""
+def _log_unnormalized_prob_logits(logits, counts, total_count):
+  """Log unnormalized probability from logits."""
   logits = tf.convert_to_tensor(logits)
   return (-tf.math.multiply_no_nan(tf.math.softplus(-logits), counts) -
           tf.math.multiply_no_nan(
               tf.math.softplus(logits), total_count - counts))
 
 
+def _log_unnormalized_prob_probs(probs, counts, total_count):
+  """Log unnormalized probability from probs."""
+  probs = tf.convert_to_tensor(probs)
+  return (tf.math.multiply_no_nan(tf.math.log(probs), counts) +
+          tf.math.multiply_no_nan(
+              tf.math.log1p(-probs), total_count - counts))
+
+
 def _log_normalization(counts, total_count):
-  return (tf.math.lgamma(1. + total_count - counts) +
-          tf.math.lgamma(1. + counts) - tf.math.lgamma(1. + total_count))
+  return (tfp_math.lbeta(1. + counts, 1. + total_count - counts) +
+          tf.math.log(1. + total_count))
 
 
 def _maybe_broadcast(a, b):
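
The `_log_prob` split above is the commit message's second gain: for a
probs-parameterized distribution, the old code round-tripped through
logits-space, so `log(probs)` and `log1p(-probs)` each picked up extra
roundings that large counts then amplify. A rough NumPy sketch of the effect
(mine, not commit code), using the counts = 0 term:

    import numpy as np

    p = np.float32(1e-6)   # probs
    n = np.float32(1e6)    # total_count

    # New path: log1p keeps log(1 - p) to nearly full float32 precision.
    direct = n * np.log1p(-p)                  # ~ -1.0

    # Old path: probs -> logits, then log(1 - p) recovered as
    # -softplus(logits), stacking several extra roundings.
    logits = np.log(p) - np.log1p(-p)
    roundtrip = -n * np.log1p(np.exp(logits))  # -n * softplus(logits)

    print(direct, roundtrip)  # agree to only ~6 significant digits

The gap is modest at these values but grows with |log(probs)| and with the
size of the counts that multiply each log term.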
27 changes: 27 additions & 0 deletions tensorflow_probability/python/distributions/binomial_test.py
@@ -319,6 +319,33 @@ def testSampleExtremeValues(self):
                                           total_count * probs))
     self.assertAllEqual(samples, expected_samples)
 
+  def testSampleWithLargeCountAndSmallProbs(self):
+    total_count = tf.constant(1e6, dtype=tf.float32)
+    probs = tf.constant(1e-6, dtype=tf.float32)
+    dist = tfd.Binomial(total_count=total_count,
+                        probs=probs, validate_args=True)
+    small_probs = self.evaluate(dist.prob([0., 1., 2.]))
+    # Approximate analytic probabilities for very small counts for this
+    # binomial:
+    # - p(0) = (1 - 1e-6)**1e6 \approx 1/e
+    # - p(1) = p(0) * 1e6 * 1e-6 / (1 - 1e-6)
+    #   \approx 1/e * 1/(1 - 1e-6) \approx 1/e
+    # - p(2) = p(0) * choose(1e6, 2) * (1e-6)**2 / (1 - 1e-6)**2
+    #        = p(0) * (1 - 1e-6) / 2 / (1 - 1e-6)**2
+    #        \approx 1/(2*e)
+    expected_probs = [np.exp(-1.), np.exp(-1.), np.exp(-1.) / 2.]
+    self.assertAllClose(small_probs, expected_probs)
+    small_log_probs = self.evaluate(dist.log_prob([0., 1., 2.]))
+    expected_log_probs = [-1., -1., -1. - np.log(2)]
+    self.assertAllClose(small_log_probs, expected_log_probs, rtol=5e-6)
+
+    n = 1000
+    sample_avg = self.evaluate(
+        tf.reduce_mean(dist.sample(n, seed=test_util.test_seed())))
+    # `sample_avg` has mean 1 and variance ~1/n, so it should be within
+    # `4.9 / sqrt(n)` of 1 with probability `1 - 1e-6` (a ~4.9-sigma bound).
+    self.assertAllClose(sample_avg, 1., atol=4.9/np.sqrt(n))
+
   def testSampleWithZeroCountsAndNanSuccessProbability(self):
     # With zero counts, should sample 0 successes even if success probability is
     # nan; and should not interfere with the rest of the batch.
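
The expected values in the new test are the Poisson limit of the binomial:
Binomial(n, p) approaches Poisson(n*p) as n grows, and here n*p = 1. As an
independent cross-check (not part of the commit), SciPy gives the same pmf:

    import numpy as np
    from scipy.stats import poisson

    # Poisson(1) pmf: exp(-1) / k!  ->  [1/e, 1/e, 1/(2e)] for k = 0, 1, 2.
    print(poisson.pmf([0, 1, 2], mu=1.))
    print([np.exp(-1.), np.exp(-1.), np.exp(-1.) / 2.])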
