Skip to content

Commit

Permalink
Fix omnistaging bug in AIS
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 434804127
  • Loading branch information
sharadmv authored and tensorflower-gardener committed Mar 15, 2022
1 parent 1b7edc6 commit 7ff8499
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 4 deletions.
2 changes: 2 additions & 0 deletions tensorflow_probability/python/mcmc/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -381,6 +381,8 @@ multi_substrate_py_library(
# numpy dep,
# tensorflow dep,
"//tensorflow_probability/python/internal:dtype_util",
"//tensorflow_probability/python/internal:prefer_static",
"//tensorflow_probability/python/internal:samplers",
"//tensorflow_probability/python/mcmc/internal",
],
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
import tensorflow.compat.v2 as tf

from tensorflow_probability.python.internal import dtype_util
from tensorflow_probability.python.internal import prefer_static as ps
from tensorflow_probability.python.internal import samplers
from tensorflow_probability.python.mcmc.internal import util as mcmc_util

Expand Down Expand Up @@ -258,7 +259,7 @@ def _bootstrap_results(init_state):

convex_combined_log_prob = mh_results.accepted_results.target_log_prob
dtype = dtype_util.as_numpy_dtype(convex_combined_log_prob.dtype)
shape = tf.shape(convex_combined_log_prob)
shape = ps.shape(convex_combined_log_prob)
proposal_log_prob = tf.fill(shape, dtype(np.nan),
name='bootstrap_proposal_log_prob')
target_log_prob = tf.fill(shape, dtype(np.nan),
Expand All @@ -275,9 +276,9 @@ def _bootstrap_results(init_state):
mh_results = _find_inner_mh_results(inner_results)

ais_weights = tf.zeros(
shape=tf.broadcast_dynamic_shape(
tf.shape(mh_results.proposed_results.target_log_prob),
tf.shape(mh_results.accepted_results.target_log_prob)),
shape=ps.broadcast_shape(
ps.shape(mh_results.proposed_results.target_log_prob),
ps.shape(mh_results.accepted_results.target_log_prob)),
dtype=mh_results.proposed_results.target_log_prob.dtype)

[_, _, ais_weights, current_state, kernel_results] = tf.while_loop(
Expand Down

0 comments on commit 7ff8499

Please sign in to comment.