Skip to content

Commit

Permalink
Return the whole path in PathfinderInfo
Browse files Browse the repository at this point in the history
  • Loading branch information
rlouf committed Nov 12, 2022
1 parent 5670ff4 commit 1ef2544
Show file tree
Hide file tree
Showing 3 changed files with 25 additions and 25 deletions.
4 changes: 2 additions & 2 deletions blackjax/kernels.py
Original file line number Diff line number Diff line change
Expand Up @@ -1246,7 +1246,7 @@ def approximate_fn(
**lbfgs_parameters,
):
return cls.approximate(
rng_key, logprob_fn, position, num_samples, False, **lbfgs_parameters
rng_key, logprob_fn, position, num_samples, **lbfgs_parameters
)

def sample_fn(
Expand Down Expand Up @@ -1316,7 +1316,7 @@ def run(rng_key: PRNGKey, position: PyTree, num_steps: int = 400):

init_key, sample_key, rng_key = jax.random.split(rng_key, 3)

pathfinder_state = vi.pathfinder.approximate(init_key, logprob_fn, position)
pathfinder_state, _ = vi.pathfinder.approximate(init_key, logprob_fn, position)
init_warmup_state = init(
pathfinder_state.alpha,
pathfinder_state.beta,
Expand Down
44 changes: 22 additions & 22 deletions blackjax/vi/pathfinder.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,14 +25,14 @@ class PathfinderState(NamedTuple):
resulting ELBO and all factors needed to sample from the approximated
target density.
elbo:
ELBO of approximation wrt target distribution
position:
position
grad_position:
gradient of target distribution wrt position
alpha, beta, gamma:
factored rappresentation of the inverse hessian
elbo:
ELBO of approximation wrt target distribution
"""

elbo: Array
Expand All @@ -43,14 +43,19 @@ class PathfinderState(NamedTuple):
gamma: Array


class PathfinderInfo(NamedTuple):
"""Extra information returned by the Pathfinder algorithm."""

path: PathfinderState


def approximate(
rng_key: PRNGKey,
logprob_fn: Callable,
initial_position: PyTree,
num_samples: int = 200,
return_path: bool = False,
**lbfgs_parameters
) -> PathfinderState:
) -> Tuple[PathfinderState, PathfinderInfo]:
"""
Pathfinder variational inference algorithm:
pathfinder locates normal approximations to the target density along a
Expand All @@ -72,9 +77,6 @@ def approximate(
starting point of the L-BFGS optimization routine
num_samples
number of samples to draw to estimate ELBO
return_path
if False output only iteration that maximize ELBO, otherwise output
all iterations
lbfgs_parameters:
Parameters passed to the internal call to `lbfgs_minimize`. The
following parameters are available:
Expand All @@ -86,10 +88,9 @@ def approximate(
Returns
-------
if return_path=True a PathfinderState with full information
on the optimization path
if return_path=False a PathfinderState with information on the iteration
in the optimization path whose approximate samples yields the highest ELBO
A PathfinderState with information on the iteration in the optimization path
whose approximate samples yields the highest ELBO, and PathfinderInfo that
contains all the states traversed.
References
----------
Expand Down Expand Up @@ -176,19 +177,18 @@ def path_finder_body_fn(rng_key, S, Z, alpha_l, theta, theta_grad):

unravel_fn_mapped = jax.vmap(unravel_fn)
pathfinder_result = PathfinderState(
elbo=elbo,
position=unravel_fn_mapped(position),
grad_position=unravel_fn_mapped(grad_position),
alpha=alpha,
beta=beta,
gamma=gamma,
elbo,
unravel_fn_mapped(position),
unravel_fn_mapped(grad_position),
alpha,
beta,
gamma,
)

if return_path:
return pathfinder_result
else:
best_i = jnp.argmax(elbo)
return jax.tree_map(lambda x: x[best_i], pathfinder_result)
max_elbo_idx = jnp.argmax(elbo)
return jax.tree_map(lambda x: x[max_elbo_idx], pathfinder_result), PathfinderInfo(
pathfinder_result
)


def sample(
Expand Down
2 changes: 1 addition & 1 deletion tests/test_pathfinder.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ def logp_unnormalized_posterior(x, observed, prior_mu, prior_prec, true_cov):

x0 = jnp.ones(ndim)
pathfinder = blackjax.pathfinder(logp_model)
out = self.variant(pathfinder.approximate)(rng_key_pathfinder, x0)
out, _ = self.variant(pathfinder.approximate)(rng_key_pathfinder, x0)

sim_p, log_p = bfgs_sample(
rng_key_pathfinder,
Expand Down

0 comments on commit 1ef2544

Please sign in to comment.