Skip to content

Commit

Permalink
Bugfix: Improve numerical stability of `tf.contrib.distributions.Nega…
Browse files Browse the repository at this point in the history
…tiveBinomial.log_prob`.

PiperOrigin-RevId: 173474795
  • Loading branch information
Joshua V. Dillon authored and tensorflower-gardener committed Oct 26, 2017
1 parent ee501c6 commit 35ca57d
Show file tree
Hide file tree
Showing 2 changed files with 24 additions and 2 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -241,6 +241,28 @@ def testNegativeBinomialSample(self):
atol=0.,
rtol=.02)

def testLogProbOverflow(self):
with self.test_session() as sess:
logits = np.float32([20., 30., 40.])
total_count = np.float32(1.)
x = np.float32(0.)
nb = negative_binomial.NegativeBinomial(
total_count=total_count, logits=logits)
log_prob_ = sess.run(nb.log_prob(x))
self.assertAllEqual(np.ones_like(log_prob_, dtype=np.bool),
np.isfinite(log_prob_))

def testLogProbUnderflow(self):
with self.test_session() as sess:
logits = np.float32([-90, -100, -110])
total_count = np.float32(1.)
x = np.float32(0.)
nb = negative_binomial.NegativeBinomial(
total_count=total_count, logits=logits)
log_prob_ = sess.run(nb.log_prob(x))
self.assertAllEqual(np.ones_like(log_prob_, dtype=np.bool),
np.isfinite(log_prob_))


if __name__ == "__main__":
test.main()
Original file line number Diff line number Diff line change
Expand Up @@ -167,8 +167,8 @@ def _log_prob(self, x):
def _log_unnormalized_prob(self, x):
if self.validate_args:
x = distribution_util.embed_check_nonnegative_integer_form(x)
return (self.total_count * math_ops.log1p(-self.probs)
+ x * math_ops.log(self.probs))
return (self.total_count * math_ops.log_sigmoid(-self.logits)
+ x * math_ops.log_sigmoid(self.logits))

def _log_normalization(self, x):
if self.validate_args:
Expand Down

0 comments on commit 35ca57d

Please sign in to comment.