Skip to content

Commit

Permalink
Ensure owens_t special function gracefully handles partially dynamic …
Browse files Browse the repository at this point in the history
…shapes.

ps.broadcast_shape is "all-or-nothing" in the presence of partially unknown shape. Owens T entails some while_loops, some of whose inputs were initially constants based on a ps.broadcast_shape. These were combined with other loop vars in the loop body, losing shape info along the way. Rewriting the broadcasts using tensor multiplication, while less "elegant", preserves the requisite shape info.

PiperOrigin-RevId: 413683996
  • Loading branch information
csuter authored and tensorflower-gardener committed Dec 2, 2021
1 parent b5f0348 commit c2299ff
Show file tree
Hide file tree
Showing 2 changed files with 28 additions and 10 deletions.
26 changes: 16 additions & 10 deletions tensorflow_probability/python/math/special.py
Original file line number Diff line number Diff line change
Expand Up @@ -1385,8 +1385,6 @@ def series_evaluation(
should_stop = index >= m
return should_stop, index + 1., new_ai, new_di, new_gi, new_series_sum

broadcast_shape = prefer_static.broadcast_shape(
prefer_static.shape(h), prefer_static.shape(a))
initial_ai = a / numpy_dtype(2 * np.pi)
initial_di = tf.math.expm1(neg_half_h_squared)
initial_gi = neg_half_h_squared * tf.math.exp(neg_half_h_squared)
Expand All @@ -1397,7 +1395,12 @@ def series_evaluation(
cond=lambda stop, *_: tf.reduce_any(~stop),
body=series_evaluation,
loop_vars=(
tf.zeros(broadcast_shape, dtype=tf.bool),
# Use constant-tensor multiplication rather than static or dynamic
# shape broadcasting logic, since the former will be robust to
# partially-static shapes.
tf.cast(
tf.zeros_like(h) * tf.zeros_like(a),
dtype=tf.bool),
tf.cast(2., dtype=dtype),
initial_ai,
initial_di,
Expand Down Expand Up @@ -1430,8 +1433,6 @@ def series_evaluation(
should_stop = index >= num_iterations
return should_stop, index + 2., new_summand, new_term, new_series_sum

broadcast_shape = prefer_static.broadcast_shape(
prefer_static.shape(h), prefer_static.shape(a))
initial_summand = -0.5 * tf.math.erf(a * h) / h
initial_sum = initial_summand
initial_term = a * tf.math.exp(
Expand All @@ -1441,7 +1442,12 @@ def series_evaluation(
cond=lambda stop, *_: tf.reduce_any(~stop),
body=series_evaluation,
loop_vars=(
tf.zeros(broadcast_shape, dtype=tf.bool),
# Use constant-tensor multiplication rather than static or dynamic
# shape broadcasting logic, since the former will be robust to
# partially-static shapes.
tf.cast(
tf.zeros_like(h) * tf.zeros_like(a),
dtype=tf.bool),
tf.cast(1., dtype=dtype),
initial_summand,
initial_term,
Expand Down Expand Up @@ -1522,8 +1528,6 @@ def series_evaluation(
should_stop = index >= num_iterations
return should_stop, index + 2., new_term, new_coeff, new_series_sum

broadcast_shape = prefer_static.broadcast_shape(
prefer_static.shape(h), prefer_static.shape(a))
initial_term = a * tf.math.exp(
-0.5 * h_squared * (1 - nega_squared)) / (2 * np.pi)
initial_sum = initial_term
Expand All @@ -1532,10 +1536,12 @@ def series_evaluation(
cond=lambda stop, *_: tf.reduce_any(~stop),
body=series_evaluation,
loop_vars=(
tf.zeros(broadcast_shape, dtype=tf.bool),
tf.cast(
tf.zeros_like(h) * tf.zeros_like(a),
dtype=tf.bool),
tf.cast(3., dtype=dtype),
initial_term,
tf.ones(broadcast_shape, dtype=dtype),
tf.ones_like(h) * tf.ones_like(a),
initial_sum))
return series_sum

Expand Down
12 changes: 12 additions & 0 deletions tensorflow_probability/python/math/special_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
from absl.testing import parameterized
import numpy as np
from scipy import special as scipy_special
import tensorflow.compat.v1 as tf1
import tensorflow.compat.v2 as tf
import tensorflow_probability as tfp

Expand Down Expand Up @@ -462,6 +463,17 @@ def testOwensTGradient(self):
lambda x: tfp.math.owens_t(x, -a), [h])
self.assertLess(err, 2e-4)

@test_util.disable_test_for_backend(
disable_numpy=True, disable_jax=True,
reason="Only relevant for dynamic shapes in TF.")
def testOwensPartiallyKnownShape(self):
h = tf1.placeholder_with_default(np.array([1.]).reshape([1, 1]),
shape=(None, 1))
a = tf1.placeholder_with_default(np.array([1.]).reshape([1, 1]),
shape=(None, 1))
# We simply verify that this runs without an Exception.
_ = tfp.math.owens_t(h, a)


@test_util.test_graph_and_eager_modes
class SpecialTest(test_util.TestCase):
Expand Down

0 comments on commit c2299ff

Please sign in to comment.