Fix RationalQuadraticSpline test for JAX by avoiding a vmap+jacobian interaction not yet well supported by tfp.jax.

PiperOrigin-RevId: 318885402
brianwa84 authored and tensorflower-gardener committed Jun 29, 2020
1 parent 257e614 commit c9c479e
Showing 2 changed files with 19 additions and 6 deletions.
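
A minimal JAX sketch (not part of the commit) of the general pattern the commit message alludes to: composing jax.vmap with jax.jacobian so that each batch member gets its own Jacobian, which is the kind of vmap+jacobian interaction the test now sidesteps under tfp.jax. The function forward below is a hypothetical stand-in, not the spline bijector.

import jax
import jax.numpy as jnp

def forward(x):
  # Hypothetical elementwise map whose derivative differs by position,
  # loosely analogous to a piecewise spline transformation.
  return jnp.where(x > 0., 2. * x, 0.5 * x)

xs = jnp.linspace(-1., 1., 5)
# One scalar derivative per batch element: vmap a per-example jacobian.
per_example_jac = jax.vmap(jax.jacobian(forward))(xs)  # shape [5]
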
1 change: 0 additions & 1 deletion tensorflow_probability/python/bijectors/BUILD
@@ -1479,7 +1479,6 @@ multi_substrate_py_test(
name = "rational_quadratic_spline_test",
size = "medium",
srcs = ["rational_quadratic_spline_test.py"],
jax_tags = ["notap"],
numpy_tags = ["notap"],
tags = ["hypothesis"],
deps = [
24 changes: 19 additions & 5 deletions tensorflow_probability/python/bijectors/rational_quadratic_spline_test.py
@@ -20,7 +20,6 @@

import functools

from absl import logging
import hypothesis as hp
from hypothesis import strategies as hps
import numpy as np
@@ -33,6 +32,9 @@
from tensorflow_probability.python.internal import hypothesis_testlib as tfp_hps
from tensorflow_probability.python.internal import test_util


JAX_MODE = False

# pylint: disable=no-value-for-parameter


@@ -44,6 +46,7 @@ def rq_splines(draw, batch_shape=None, dtype=tf.float32):
lo = draw(hps.floats(min_value=-5, max_value=.5))
hi = draw(hps.floats(min_value=-.5, max_value=5))
lo, hi = min(lo, hi), max(lo, hi) + .2
hp.note('lo, hi: {!r}'.format((lo, hi)))

constraints = dict(
bin_widths=functools.partial(
@@ -57,6 +60,7 @@
batch_shape,
params_event_ndims=dict(bin_widths=1, bin_heights=1, knot_slopes=1),
constraint_fn_for=constraints.get))
hp.note('params: {!r}'.format(params))
return tfb.RationalQuadraticSpline(
range_min=lo, validate_args=draw(hps.booleans()), **params)

@@ -165,22 +169,32 @@ def testTheoreticalFldjSimple(self):
atol=1e-5,
rtol=1e-5)

@test_util.numpy_disable_gradient_test
@hp.given(hps.data())
@tfp_hps.tfp_hp_settings(default_max_examples=5)
def testTheoreticalFldj(self, data):
# get_fldj_theoretical test rig requires 1-d batches.
batch_shape = data.draw(tfp_hps.shapes(min_ndims=1, max_ndims=1))
if JAX_MODE: # TODO(b/160167257): Eliminate this workaround.
# get_fldj_theoretical uses tfp.math.batch_jacobian and assumes the
# behavior of the bijector does not vary by position. In this case, it
# can, so we must vmap the result.
batch_shape = [1]
else:
# get_fldj_theoretical test rig requires 1-d batches.
batch_shape = data.draw(tfp_hps.shapes(min_ndims=1, max_ndims=1))
hp.note('batch shape: {}'.format(batch_shape))
bijector = data.draw(rq_splines(batch_shape=batch_shape, dtype=tf.float64))
self.assertEqual(tf.float64, bijector.dtype)
bw, bh, kd = self.evaluate(
[bijector.bin_widths, bijector.bin_heights, bijector.knot_slopes])
logging.info('bw: %s\nbh: %s\nkd: %s', bw, bh, kd)
hp.note('bw: {!r}\nbh: {!r}\nkd: {!r}'.format(bw, bh, kd))
x_shp = ((bw + bh)[..., :-1] + kd).shape[:-1]
if x_shp[-1] == 1: # Possibly broadcast the x dim.
dim = data.draw(hps.integers(min_value=1, max_value=7))
x_shp = x_shp[:-1] + (dim,)
x = np.linspace(-5, 5, np.prod(x_shp), dtype=np.float64).reshape(*x_shp)
x = np.linspace(-4.9, 4.9, np.prod(x_shp), dtype=np.float64).reshape(*x_shp)
hp.note('x: {!r}'.format(x))
y = self.evaluate(bijector.forward(x))
bijector_test_util.assert_bijective_and_finite(
bijector,
x,
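
The workaround comment in testTheoreticalFldj explains that the theoretical-FLDJ helper assumes the bijector behaves identically at every batch position, which a spline with per-position knots does not, so the JAX path pins the batch to size 1 and maps over it. A minimal, hypothetical JAX sketch of that "vmap the result" idea; spline_forward and single_example_fldj are made-up stand-ins, not TFP internals.

import jax
import jax.numpy as jnp

def spline_forward(x):
  # Stand-in for an elementwise transformation whose slope varies with x.
  return x + jnp.sin(x)

def single_example_fldj(x):
  # Forward log-det-Jacobian for one example of an elementwise map:
  # sum of log|diagonal| of that example's Jacobian.
  jac = jax.jacfwd(spline_forward)(x)
  return jnp.sum(jnp.log(jnp.abs(jnp.diag(jac))))

batch_x = jnp.stack([jnp.linspace(-1., 1., 3), jnp.linspace(0., 2., 3)])
# Map the single-example computation over the batch instead of relying on a
# batched-Jacobian helper: one fldj per batch member, shape [2].
batch_fldj = jax.vmap(single_example_fldj)(batch_x)
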
