Skip to content

Commit

Permalink
Fix flaky test
Browse files Browse the repository at this point in the history
- Flakyness in optimizer is due to maxcor too low
- Flakyness in linear regression test is due to incorrect logprob function (actually all algorithm in the same test returns bias result)
  • Loading branch information
junpenglao authored and rlouf committed Sep 3, 2022
1 parent 2345416 commit 72f1952
Show file tree
Hide file tree
Showing 3 changed files with 6 additions and 9 deletions.
3 changes: 0 additions & 3 deletions blackjax/kernels.py
Original file line number Diff line number Diff line change
Expand Up @@ -1080,9 +1080,6 @@ def one_step(carry, rng_key):
return ((state, adaptation_state), (state, info, adaptation_state.da_state))

def run(rng_key: PRNGKey, position: PyTree):

rng_key_init, rng_key_chain = jax.random.split(rng_key, 2)

init_warmup_state, init_position = init(rng_key, position, initial_step_size)
init_state = algorithm.init(init_position, logprob_fn)

Expand Down
2 changes: 1 addition & 1 deletion tests/test_optimizers.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ def test_dual_averaging(self):

@chex.all_variants(with_pmap=False)
@parameterized.parameters(
[(5, 10), (10, 1), (10, 20)],
[(5, 10), (10, 2), (10, 20)],
)
def test_minimize_lbfgs(self, maxiter, maxcor):
"""Test if dot product between approximate inverse hessian and gradient is
Expand Down
10 changes: 5 additions & 5 deletions tests/test_sampling.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,12 +70,12 @@ def setUp(self):

def regression_logprob(self, scale, coefs, preds, x):
"""Linear regression"""
logpdf = 0
logpdf += stats.expon.logpdf(scale, 1, 1)
logpdf += stats.norm.logpdf(coefs, 3 * jnp.ones(x.shape[-1]), 2)
scale_prior = stats.expon.logpdf(scale, 1, 1)
coefs_prior = stats.norm.logpdf(coefs, 0, 5)
y = jnp.dot(x, coefs)
logpdf += stats.norm.logpdf(preds, y, scale)
return jnp.sum(logpdf)
logpdf = stats.norm.logpdf(preds, y, scale)
# reduce sum otherwise broacasting will make the logprob biased.
return sum(x.sum() for x in [scale_prior, coefs_prior, logpdf])

@parameterized.parameters(itertools.product(regression_test_cases, [True, False]))
def test_window_adaptation(self, case, is_mass_matrix_diagonal):
Expand Down

0 comments on commit 72f1952

Please sign in to comment.