Skip to content

Commit

Permalink
Add Chi2-Chi2 KL divergence unit tests.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 227063838
  • Loading branch information
Googler authored and tensorflower-gardener committed Dec 27, 2018
1 parent 11d13b6 commit 661432e
Showing 1 changed file with 30 additions and 0 deletions.
30 changes: 30 additions & 0 deletions tensorflow_probability/python/distributions/chi2_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,36 @@ def testChi2WithAbsDf(self):
self.assertAllClose(
self.evaluate(tf.floor(tf.abs(df_v))), self.evaluate(chi2.df))

def testChi2Chi2KL(self):
a_df = np.arange(1.0, 10.0)
b_df = np.arange(1.0, 10.0)

# This reshape is intended to expand the number of test cases.
a_df = a_df.reshape((len(a_df), 1))
b_df = b_df.reshape((1, len(b_df)))

a = tfd.Chi2(df=a_df)
b = tfd.Chi2(df=b_df)

# Consistent with
# http://www.mast.queensu.ca/~communications/Papers/gil-msc11.pdf, page 110
# Evaluating this in numpy is complicated by the fact that there is no
# native implementation of the digamma function.
true_kl = (tf.lgamma(b_df / 2.0) - tf.lgamma(a_df / 2.0) +
(a_df - b_df) / 2.0 * tf.digamma(a_df / 2.0))

kl = tfd.kl_divergence(a, b)

x = a.sample(int(1e5), seed=0)
kl_sample = tf.reduce_mean(a.log_prob(x) - b.log_prob(x), axis=0)

true_kl_, kl_, kl_sample_ = self.evaluate([true_kl, kl, kl_sample])
self.assertAllClose(true_kl_, kl_, atol=0., rtol=1e-14)
self.assertAllClose(true_kl_, kl_sample_, atol=0., rtol=5e-2)

zero_kl = tfd.kl_divergence(a, a)
true_zero_kl_, zero_kl_ = self.evaluate([tf.zeros_like(zero_kl), zero_kl])
self.assertAllEqual(true_zero_kl_, zero_kl_)

if __name__ == "__main__":
tf.test.main()

0 comments on commit 661432e

Please sign in to comment.