Skip to content

Commit

Permalink
Enable JIT-ting windowed sampling in JAX backend
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 382083866
  • Loading branch information
sharadmv authored and tensorflower-gardener committed Jun 29, 2021
1 parent caa5329 commit 78b3611
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 13 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -741,7 +741,7 @@ def _windowed_adaptive_impl(n_draws,
tf1.get_default_graph())):
# A Tensor num_draws argument breaks XLA, which requires static TensorArray
# trace_fn result allocation sizes.
num_adaptation_steps = tf.convert_to_tensor(num_adaptation_steps)
num_adaptation_steps = ps.convert_to_shape_tensor(num_adaptation_steps)

dual_averaging_kwargs.setdefault('reduce_fn',
generic_math.reduce_log_harmonic_mean_exp)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
from tensorflow_probability.python.internal import test_util
from tensorflow_probability.python.internal import unnest

JAX_MODE = False

tfb = tfp.bijectors
tfd = tfp.distributions
Expand Down Expand Up @@ -597,19 +598,18 @@ def get_joint_distribution(
name='jd')


@test_util.disable_test_for_backend(disable_jax=True,
reason='Only applies to TF')
class PrecompiledTest(test_util.TestCase):

def setUp(self):
super().setUp()
arms = 2
days = 3

strm = test_util.test_seed_stream()
self.trials = tfd.Poisson(100.).sample([arms, days], seed=strm())
seed = test_util.test_seed()
trial_seed, value_seed = tfp.random.split_seed(seed)
self.trials = tfd.Poisson(100.).sample([arms, days], seed=trial_seed)
dist = get_joint_distribution(self.trials)
self.true_values = dist.sample(seed=strm())
self.true_values = dist.sample(seed=value_seed)

def nuts_kwargs(self):
return {'max_tree_depth': 2}
Expand All @@ -622,13 +622,16 @@ def hmc_kwargs(self):
def test_base_kernel(self, kind):
self.skip_if_no_xla()

input_signature = (
tf.TensorSpec(
shape=[None, None], dtype=tf.float32, name='trials'),
tf.TensorSpec(
shape=[None, None], dtype=tf.float32, name='successes'),
tf.TensorSpec(
shape=[2], dtype=tf.int32, name='seed'))
if JAX_MODE:
input_signature = None
else:
input_signature = (
tf.TensorSpec(
shape=[None, None], dtype=tf.float32, name='trials'),
tf.TensorSpec(
shape=[None, None], dtype=tf.float32, name='successes'),
tf.TensorSpec(
shape=[2], dtype=tf.int32, name='seed'))
@tf.function(jit_compile=True, input_signature=input_signature)
def do(trials, successes, seed):
if kind == 'hmc':
Expand Down

0 comments on commit 78b3611

Please sign in to comment.