TruncatedNormal sampling in XLA was broken by the change to stateless sampling.
This partially rolls back that change, in cases where the graph will run under XLA.

PiperOrigin-RevId: 324071396
brianwa84 authored and tensorflower-gardener committed Jul 30, 2020
1 parent 6f798ff commit 10b821f
Showing 3 changed files with 106 additions and 8 deletions.
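For context, the failure this commit addresses can be illustrated with a snippet along the lines of the testSampleXLA test added below (the function name draw and the integer seed here are illustrative, and an XLA-enabled TensorFlow build is assumed):

import tensorflow.compat.v2 as tf
import tensorflow_probability as tfp

tfd = tfp.distributions

# Compile the sampling graph with XLA, as the new test does.
@tf.function(experimental_compile=True)
def draw(loc):
  return tfd.TruncatedNormal(loc=loc, scale=1., low=-1., high=1.).sample(
      3, seed=17)

# Per the commit message, this broke when sampling always went through the
# stateless op; with this change it falls back to the stateful sampler
# inside XLA contexts.
draw(tf.constant(0.2))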
1 change: 1 addition & 0 deletions tensorflow_probability/python/distributions/BUILD
@@ -3129,6 +3129,7 @@ multi_substrate_py_test(
# tensorflow dep,
"//tensorflow_probability",
"//tensorflow_probability/python/internal:test_util",
# tensorflow/compiler/jit dep,
],
)

104 changes: 96 additions & 8 deletions tensorflow_probability/python/distributions/truncated_normal.py
@@ -23,6 +23,7 @@
# Dependency imports
import numpy as np

import tensorflow.compat.v1 as tf1
import tensorflow.compat.v2 as tf

from tensorflow_probability.python.bijectors import sigmoid as sigmoid_bijector
@@ -35,6 +36,8 @@
from tensorflow_probability.python.internal import special_math
from tensorflow_probability.python.internal import tensor_util
from tensorflow_probability.python.math.generic import log_sub_exp as _log_sub_exp
from tensorflow.python.ops import control_flow_util # pylint: disable=g-direct-tensorflow-import
from tensorflow.python.ops import random_ops # pylint: disable=g-direct-tensorflow-import


__all__ = [
@@ -248,18 +251,103 @@ def _event_shape(self):
return tf.TensorShape([])

def _sample_n(self, n, seed=None):
seed = samplers.sanitize_seed(seed)
loc, scale, low, high = self._loc_scale_low_high()
batch_shape = self._batch_shape_tensor(
loc=loc, scale=scale, low=low, high=high)
sample_and_batch_shape = tf.concat([[n], batch_shape], 0)
return tf.random.stateless_parameterized_truncated_normal(
shape=sample_and_batch_shape,
means=loc,
stddevs=scale,
minvals=low,
maxvals=high,
seed=seed)
flat_batch_and_sample_shape = tf.stack([tf.reduce_prod(batch_shape), n])

# TODO(b/162522020): Use this behavior unconditionally.
if (tf.executing_eagerly() or
not control_flow_util.GraphOrParentsInXlaContext(
tf1.get_default_graph())):
return tf.random.stateless_parameterized_truncated_normal(
shape=sample_and_batch_shape,
means=loc,
stddevs=scale,
minvals=low,
maxvals=high,
seed=samplers.sanitize_seed(seed))

# In order to be reparameterizable we sample from a truncated normal with
# zero mean and unit variance (but with the standardized truncation
# bounds) and then shift and scale the samples afterwards.

@tf.custom_gradient
def _std_samples_with_gradients(lower, upper):
"""Standard truncated Normal with gradient support for low, high."""
# Note: Unlike the convention in TFP, parameterized_truncated_normal
# returns a tensor with the final dimension being the sample dimension.
std_samples = random_ops.parameterized_truncated_normal(
shape=flat_batch_and_sample_shape,
means=0.0,
stddevs=1.0,
minvals=lower,
maxvals=upper,
dtype=self.dtype,
seed=seed)

def grad(dy):
"""Computes a derivative for the min and max parameters.
This function implements the derivative wrt the truncation bounds, which
get blocked by the sampler. We use a custom expression for numerical
stability instead of automatic differentiation on CDF for implicit
gradients.
Args:
dy: output gradients
Returns:
The standard normal samples and the gradients wrt the upper
bound and lower bound.
"""
# std_samples has an extra dimension (the sample dimension), so expand
# lower and upper to broadcast along this dimension.
# Per the note above regarding parameterized_truncated_normal, the sample
# dimension is the final dimension.
lower_broadcast = lower[..., tf.newaxis]
upper_broadcast = upper[..., tf.newaxis]

cdf_samples = ((special_math.ndtr(std_samples) -
special_math.ndtr(lower_broadcast)) /
(special_math.ndtr(upper_broadcast) -
special_math.ndtr(lower_broadcast)))

# tiny, eps are tolerance parameters to ensure we stay away from giving
# a zero arg to the log CDF expression.

tiny = np.finfo(dtype_util.as_numpy_dtype(self.dtype)).tiny
eps = np.finfo(dtype_util.as_numpy_dtype(self.dtype)).eps
cdf_samples = tf.clip_by_value(cdf_samples, tiny, 1 - eps)

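# Derivation of du and dl below: holding the realized CDF value
# p = cdf_samples fixed, implicit differentiation of
# ndtr(std_samples) = ndtr(lower) + p * (ndtr(upper) - ndtr(lower))
# gives d(std_samples)/d(upper) = p * pdf(upper) / pdf(std_samples) and
# d(std_samples)/d(lower) = (1 - p) * pdf(lower) / pdf(std_samples),
# where pdf is the standard normal density. The density ratio equals
# exp(0.5 * (std_samples**2 - upper**2)), which is what the log-space
# expressions below compute.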
du = tf.exp(0.5 * (std_samples**2 - upper_broadcast**2) +
tf.math.log(cdf_samples))
dl = tf.exp(0.5 * (std_samples**2 - lower_broadcast**2) +
tf.math.log1p(-cdf_samples))

# Reduce the gradient across the samples
grad_u = tf.reduce_sum(dy * du, axis=-1)
grad_l = tf.reduce_sum(dy * dl, axis=-1)
return [grad_l, grad_u]

return std_samples, grad
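# Gradients wrt low, high, loc and scale all flow through the standardized
# bounds fed to _std_samples_with_gradients (via the custom gradient above);
# loc and scale additionally enter through the final shift and scale of the
# samples below.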

std_low, std_high = self._standardized_low_and_high(
low=low, high=high, loc=loc, scale=scale)
low_high_shp = tf.broadcast_dynamic_shape(
tf.shape(std_low), tf.shape(std_high))
std_low = tf.broadcast_to(std_low, low_high_shp)
std_high = tf.broadcast_to(std_high, low_high_shp)

std_samples = _std_samples_with_gradients(
tf.reshape(std_low, [-1]), tf.reshape(std_high, [-1]))

# The returned shape is [flat_batch x n]
std_samples = tf.transpose(std_samples, perm=[1, 0])

std_samples = tf.reshape(std_samples, sample_and_batch_shape)
return std_samples * scale[tf.newaxis] + loc[tf.newaxis]

def _log_prob(self, x):
loc, scale, low, high = self._loc_scale_low_high()
9 changes: 9 additions & 0 deletions tensorflow_probability/python/distributions/truncated_normal_test.py
@@ -378,6 +378,15 @@ def testSupportBijectorOutsideRange(self):
).inverse(x)
self.assertAllNan(self.evaluate(bijector_inverse_x))

def testSampleXLA(self):
self.skip_if_no_xla()
@tf.function(experimental_compile=True)
def f(loc):
return tfd.TruncatedNormal(
loc=loc, scale=1., low=-1., high=1.).sample(
[3], seed=test_util.test_seed())
self.evaluate(f(tf.constant(0.2)))


# TODO(b/150161911): reconcile graph- and eager-mode handling of denormal floats
# so that we can re-enable eager mode tests.
