
Commit e3771a5
Support tf.eager in tfp.mcmc.HamiltonianMonteCarlo.
PiperOrigin-RevId: 193227343
Joshua V. Dillon authored and Copybara-Service committed Apr 17, 2018
1 parent cdd1291 commit e3771a5
Showing 4 changed files with 496 additions and 488 deletions.
27 changes: 21 additions & 6 deletions tensorflow_probability/python/mcmc/hmc.py
@@ -27,6 +27,7 @@
 from tensorflow_probability.python.mcmc import kernel as kernel_base
 from tensorflow_probability.python.mcmc import metropolis_hastings
 from tensorflow_probability.python.mcmc import util as mcmc_util
+from tensorflow.contrib import eager as tfe
 from tensorflow.python.ops.distributions import util as distributions_util
 
 
@@ -503,9 +504,10 @@ def bootstrap_results(self, init_state):
     if not mcmc_util.is_list_like(init_state):
       init_state = [init_state]
     init_state = [tf.convert_to_tensor(x) for x in init_state]
-    init_target_log_prob = self.target_log_prob_fn(*init_state)
-    init_grads_target_log_prob = tf.gradients(
-        init_target_log_prob, init_state)
+    [
+        init_target_log_prob,
+        init_grads_target_log_prob,
+    ] = _value_and_gradients(self.target_log_prob_fn, *init_state)
     return UncalibratedHamiltonianMonteCarloKernelResults(
         log_acceptance_correction=tf.zeros_like(init_target_log_prob),
         target_log_prob=init_target_log_prob,
@@ -723,12 +725,16 @@ def _leapfrog_step(current_momentums,
                           in zip(current_state_parts,
                                  step_sizes,
                                  proposed_momentums)]
-  proposed_target_log_prob = target_log_prob_fn(*proposed_state_parts)
+
+  [
+      proposed_target_log_prob,
+      proposed_grads_target_log_prob,
+  ] = _value_and_gradients(target_log_prob_fn, *proposed_state_parts)
+
   if not proposed_target_log_prob.dtype.is_floating:
     raise TypeError('`target_log_prob_fn` must produce a `Tensor` '
                     'with `float` `dtype`.')
-  proposed_grads_target_log_prob = tf.gradients(
-      proposed_target_log_prob, proposed_state_parts)
+
   if any(g is None for g in proposed_grads_target_log_prob):
     raise ValueError(
         'Encountered `None` gradient. Does your target `target_log_prob_fn` '
@@ -905,3 +911,12 @@ def maybe_flatten(x):
 def _log_sum_sq(x, axis=None):
   """Computes log(sum(x**2))."""
   return tf.reduce_logsumexp(2. * tf.log(tf.abs(x)), axis)
+
+
+def _value_and_gradients(fn, *args):
+  """Calls `fn` and computes the gradient of the result wrt `args`."""
+  if tfe.executing_eagerly():
+    return tfe.value_and_gradients_function(fn)(*args)
+  result = fn(*args)
+  grads = tf.gradients(result, args)
+  return result, grads
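
For context on the pattern this commit introduces: `tf.gradients` is not supported under eager execution, so the new helper routes through `tfe.value_and_gradients_function` there, keeping a single call site for both modes. Below is a minimal, self-contained sketch of the same dispatch, assuming TF 1.x with `tf.contrib.eager` available (the API this commit targets); the public `value_and_gradients` name here is illustrative, mirroring the private helper.

import tensorflow as tf
from tensorflow.contrib import eager as tfe

def value_and_gradients(fn, *args):
  """Returns fn(*args) and its gradients wrt args, in graph or eager mode."""
  if tfe.executing_eagerly():
    # tf.gradients is not supported under eager execution; build and call an
    # explicit gradient function instead.
    return tfe.value_and_gradients_function(fn)(*args)
  # Under graph construction, symbolic gradients work as before.
  result = fn(*args)
  return result, tf.gradients(result, args)

tfe.enable_eager_execution()
value, grads = value_and_gradients(lambda x: x ** 2., tf.constant(3.))
print(value.numpy(), grads[0].numpy())  # ==> 9.0 6.0

The dispatch keeps one code path for callers: graph-mode users still get symbolic gradients, while eager users get values and gradients computed on the spot.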
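What the change enables, end to end: sampling with HMC under eager execution, with no session or graph to manage. The sketch below is hedged: it assumes a contemporaneous TF 1.x / TFP setup, and `tfp.mcmc.sample_chain` plus the kernel arguments follow the 2018-era tfp.mcmc API rather than anything shown in this diff.

import tensorflow as tf
import tensorflow_probability as tfp
from tensorflow.contrib import eager as tfe

tfe.enable_eager_execution()

def target_log_prob_fn(x):
  # Unnormalized standard-normal log density.
  return -0.5 * x ** 2.

samples, _ = tfp.mcmc.sample_chain(
    num_results=500,
    num_burnin_steps=100,
    current_state=tf.zeros([]),
    kernel=tfp.mcmc.HamiltonianMonteCarlo(
        target_log_prob_fn=target_log_prob_fn,
        step_size=0.5,
        num_leapfrog_steps=3))

# Under eager execution the samples are concrete values immediately.
print(samples.numpy().mean(), samples.numpy().std())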
