Skip to content

Commit

Permalink
Fix convergence test for NUTS.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 267434207
  • Loading branch information
junpenglao authored and tensorflower-gardener committed Sep 5, 2019
1 parent 12ab889 commit 6c85e01
Showing 1 changed file with 15 additions and 8 deletions.
23 changes: 15 additions & 8 deletions tensorflow_probability/python/mcmc/nuts_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@

@tf.function(autograph=False)
def run_nuts_chain(event_size, batch_size, num_steps, initial_state=None):
strm = tfp_test_util.test_seed_stream()
def target_log_prob_fn(event):
with tf.name_scope('nuts_test_target_log_prob'):
return tfd.MultivariateNormalDiag(
Expand All @@ -51,7 +52,7 @@ def target_log_prob_fn(event):
step_size=[0.3],
unrolled_leapfrog_steps=2,
max_tree_depth=4,
seed=1)
seed=strm())

chain_state, leapfrogs_taken = tfp.mcmc.sample_chain(
num_results=num_steps,
Expand Down Expand Up @@ -110,10 +111,11 @@ def run_chain():


def assert_mvn_target_conservation(event_size, batch_size, **kwargs):
strm = tfp_test_util.test_seed_stream()
initialization = tfd.MultivariateNormalFullCovariance(
loc=tf.zeros(event_size),
covariance_matrix=tf.eye(event_size)).sample(
batch_size, seed=4)
batch_size, seed=strm())
samples, _ = run_nuts_chain(
event_size, batch_size, num_steps=1,
initial_state=initialization, **kwargs)
Expand Down Expand Up @@ -284,8 +286,10 @@ def trace_fn(_, pkr):
np.any(np.isin(np.asarray([5, 9, 11, 13]), np.unique(leapfrogs_taken))))

def testCorrelated2dNormalwithinMCError(self):
strm = tfp_test_util.test_seed_stream()
nchains, num_steps = 10, 500
# Test is flaky with 500 samples, set specific seed here.
strm = tfd.SeedStream(1, salt='Correlated2dNormalwithinMCError')
nchains = 100
num_steps = 500
mu = np.asarray([0., 3.], dtype=np.float32)
rho = 0.75
sigma1 = 1.
Expand Down Expand Up @@ -325,11 +329,14 @@ def run_chain_and_get_estimation_error():
scaled_error = (
tf.abs(expected - true_param) / avg_monte_carlo_standard_error)

return tfd.Normal(loc=0., scale=1.).prob(scaled_error)
return tfd.Normal(loc=0., scale=1.).survival_function(scaled_error)

# Probability of getting this error
# Probability of getting an error more extreme than this
error_prob = self.evaluate(run_chain_and_get_estimation_error())
self.assertAllGreater(error_prob, 0.01)

# Check convergence using Markov chain central limit theorem, this is a
# z-test at p=.025
self.assertAllGreater(error_prob, 0.0125)

@parameterized.parameters(
(7, 5, 3, None),
Expand All @@ -354,7 +361,7 @@ def testDynamicShape(self, nsample, batch_size, nd, dynamic_shape):

def testDivergence(self):
"""Neals funnel with large step size."""
strm = tfd.SeedStream(1, salt='DivergenceTest')
strm = tfp_test_util.test_seed_stream()
neals_funnel = tfd.JointDistributionSequential(
[
tfd.Normal(loc=0., scale=3.), # b0
Expand Down

0 comments on commit 6c85e01

Please sign in to comment.