Skip to content

Commit

Permalink
fix operations with pytrees
Browse files Browse the repository at this point in the history
  • Loading branch information
rlouf authored and Rémi Louf committed Jan 14, 2021
1 parent 30a82cc commit facc554
Show file tree
Hide file tree
Showing 3 changed files with 4 additions and 2 deletions.
2 changes: 1 addition & 1 deletion blackjax/hmc.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ def kernel(potential_fn: Callable, parameters: HMCParameters) -> Callable:
"""
step_size, num_integration_steps, inv_mass_matrix, divergence_threshold = parameters

if not inv_mass_matrix:
if inv_mass_matrix is None:
raise ValueError(
"Expected a value for `inv_mass_matrix`,"
" got None. Please specify a value when initializing"
Expand Down
2 changes: 1 addition & 1 deletion blackjax/inference/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -152,7 +152,7 @@ def kernel(
)
new_position, new_momentum, new_potential_energy_grad = proposal

flipped_momentum = -1.0 * new_momentum
flipped_momentum = jax.tree_util.tree_multimap(lambda m: -1.0 * m, new_momentum)
new_potential_energy = potential_fn(new_position)
new_energy = new_potential_energy + kinetic_energy(
flipped_momentum, new_position
Expand Down
2 changes: 2 additions & 0 deletions blackjax/inference/metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,7 @@ def momentum_generator(rng_key: jax.random.PRNGKey, position: PyTree) -> PyTree:

def kinetic_energy(momentum: PyTree, *_) -> float:
momentum, _ = tree_flatten(momentum)
momentum = jnp.array(momentum)
velocity = jnp.multiply(inverse_mass_matrix, momentum)
return 0.5 * jnp.dot(velocity, momentum)

Expand All @@ -85,6 +86,7 @@ def momentum_generator(rng_key: jax.random.PRNGKey, position: PyTree) -> PyTree:

def kinetic_energy(momentum: PyTree, *_) -> float:
momentum, _ = tree_flatten(momentum)
momentum = jnp.array(momentum)
velocity = jnp.matmul(inverse_mass_matrix, momentum)
return 0.5 * jnp.dot(velocity, momentum)

Expand Down

0 comments on commit facc554

Please sign in to comment.