Skip to content

Commit

Permalink
Return non-cumulated leapfrogs_taken in nuts kernel_result.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 272003178
  • Loading branch information
junpenglao authored and tensorflower-gardener committed Sep 30, 2019
1 parent 389557d commit 24d4a46
Show file tree
Hide file tree
Showing 2 changed files with 1 addition and 6 deletions.
3 changes: 0 additions & 3 deletions tensorflow_probability/python/mcmc/nuts.py
Original file line number Diff line number Diff line change
Expand Up @@ -380,10 +380,7 @@ def _copy(v):
new_step_metastate.energy_diff_sum /
tf.cast(new_step_metastate.leapfrog_count,
dtype=new_step_metastate.energy_diff_sum.dtype)),
# TODO(junpenglao): return non-cumulated leapfrogs_taken once
# benchmarking is done.
leapfrogs_taken=(
previous_kernel_results.leapfrogs_taken +
new_step_metastate.leapfrog_count * self.unrolled_leapfrog_steps
),
is_accepted=new_step_metastate.is_accepted,
Expand Down
4 changes: 1 addition & 3 deletions tensorflow_probability/python/mcmc/nuts_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -312,14 +312,12 @@ def trace_fn(_, pkr):
trace_fn=trace_fn,
parallel_iterations=1)

leapfrogs_taken_ = leapfrogs_taken[1:] - leapfrogs_taken[:-1]

return (
tf.shape(x),
# We'll average over samples (dim=0) and chains (dim=1).
tf.reduce_mean(x, axis=[0, 1]),
tfp.stats.covariance(x, sample_axis=[0, 1]),
leapfrogs_taken_[is_accepted[1:]])
leapfrogs_taken[is_accepted])

sample_shape, sample_mean, sample_cov, leapfrogs_taken = self.evaluate(
run_chain_and_get_summary(
Expand Down

0 comments on commit 24d4a46

Please sign in to comment.