TruncatedNormal modified to use tf.random.stateless_parameterized_truncated_normal, and jax/scipy implementations. This eliminates the need for custom gradients, since both TF and JAX have derivatives.

PiperOrigin-RevId: 322881803
brianwa84 authored and tensorflower-gardener committed Jul 23, 2020
1 parent ffcf1cb commit d40b2fc
Showing 3 changed files with 36 additions and 83 deletions.
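
A quick illustration of the behavior the commit relies on, as a hedged sketch: assuming a TensorFlow build that ships `tf.random.stateless_parameterized_truncated_normal` (it landed around TF 2.3), the op has registered derivatives for all four parameters, including the truncation bounds, so no `tf.custom_gradient` wrapper is needed. The parameter values below are arbitrary, not from the commit.

```python
import tensorflow as tf

means = tf.constant(0.5)
stddevs = tf.constant(1.2)
minvals = tf.constant(-1.0)
maxvals = tf.constant(2.0)

with tf.GradientTape() as tape:
  tape.watch([means, stddevs, minvals, maxvals])
  # Stateless op: `seed` is a shape-[2] integer tensor.
  samples = tf.random.stateless_parameterized_truncated_normal(
      shape=[1000], seed=[4, 2],
      means=means, stddevs=stddevs, minvals=minvals, maxvals=maxvals)
  loss = tf.reduce_mean(samples)

# Derivatives for all four parameters come from the op itself.
grads = tape.gradient(loss, [means, stddevs, minvals, maxvals])
```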
@@ -57,7 +57,6 @@
 JVP_SAMPLE_BLOCKLIST = (
     'Gamma',
     'GeneralizedNormal',
-    'TruncatedNormal',
     'VonMises',
 )
 JVP_LOGPROB_SAMPLE_BLOCKLIST = ()
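The hunk above removes `'TruncatedNormal'` from the JVP blocklist: forward-mode autodiff through sampling now works because `jax.random.truncated_normal` is itself differentiable in the truncation bounds. A minimal sketch of that claim (the function and values below are illustrative, not part of the commit):

```python
import jax
import jax.numpy as jnp

def mean_sample(bounds):
  # Sample a standard truncated normal on [bounds[0], bounds[1]] with a
  # fixed key, then reduce; the result is differentiable in the bounds.
  key = jax.random.PRNGKey(0)
  samples = jax.random.truncated_normal(key, bounds[0], bounds[1], shape=(1000,))
  return jnp.mean(samples)

bounds = jnp.array([-1.0, 2.0])
# Forward-mode derivative wrt the lower bound (tangent [1, 0]).
primal, tangent = jax.jvp(mean_sample, (bounds,), (jnp.array([1.0, 0.0]),))
```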
91 changes: 9 additions & 82 deletions tensorflow_probability/python/distributions/truncated_normal.py
@@ -31,10 +31,10 @@
 from tensorflow_probability.python.internal import dtype_util
 from tensorflow_probability.python.internal import prefer_static
 from tensorflow_probability.python.internal import reparameterization
+from tensorflow_probability.python.internal import samplers
 from tensorflow_probability.python.internal import special_math
 from tensorflow_probability.python.internal import tensor_util
 from tensorflow_probability.python.math.generic import log_sub_exp as _log_sub_exp
-from tensorflow.python.ops import random_ops  # pylint: disable=g-direct-tensorflow-import
 
 
 __all__ = [
@@ -248,91 +248,18 @@ def _event_shape(self):
     return tf.TensorShape([])
 
   def _sample_n(self, n, seed=None):
+    seed = samplers.sanitize_seed(seed)
     loc, scale, low, high = self._loc_scale_low_high()
     batch_shape = self._batch_shape_tensor(
         loc=loc, scale=scale, low=low, high=high)
     sample_and_batch_shape = tf.concat([[n], batch_shape], 0)
-    flat_batch_and_sample_shape = tf.stack([tf.reduce_prod(batch_shape), n])
-
-    # In order to be reparameterizable we sample on the truncated_normal of
-    # unit variance and mean and scale (but with the standardized
-    # truncation bounds).
-
-    @tf.custom_gradient
-    def _std_samples_with_gradients(lower, upper):
-      """Standard truncated Normal with gradient support for low, high."""
-      # Note: Unlike the convention in TFP, parameterized_truncated_normal
-      # returns a tensor with the final dimension being the sample dimension.
-      std_samples = random_ops.parameterized_truncated_normal(
-          shape=flat_batch_and_sample_shape,
-          means=0.0,
-          stddevs=1.0,
-          minvals=lower,
-          maxvals=upper,
-          dtype=self.dtype,
-          seed=seed)
-
-      def grad(dy):
-        """Computes a derivative for the min and max parameters.
-
-        This function implements the derivative wrt the truncation bounds, which
-        get blocked by the sampler. We use a custom expression for numerical
-        stability instead of automatic differentiation on CDF for implicit
-        gradients.
-
-        Args:
-          dy: output gradients
-
-        Returns:
-          The standard normal samples and the gradients wrt the upper
-          bound and lower bound.
-        """
-        # std_samples has an extra dimension (the sample dimension), expand
-        # lower and upper so they broadcast along this dimension.
-        # See note above regarding parameterized_truncated_normal, the sample
-        # dimension is the final dimension.
-        lower_broadcast = lower[..., tf.newaxis]
-        upper_broadcast = upper[..., tf.newaxis]
-
-        cdf_samples = ((special_math.ndtr(std_samples) -
-                        special_math.ndtr(lower_broadcast)) /
-                       (special_math.ndtr(upper_broadcast) -
-                        special_math.ndtr(lower_broadcast)))
-
-        # tiny, eps are tolerance parameters to ensure we stay away from giving
-        # a zero arg to the log CDF expression.
-        tiny = np.finfo(dtype_util.as_numpy_dtype(self.dtype)).tiny
-        eps = np.finfo(dtype_util.as_numpy_dtype(self.dtype)).eps
-        cdf_samples = tf.clip_by_value(cdf_samples, tiny, 1 - eps)
-
-        du = tf.exp(0.5 * (std_samples**2 - upper_broadcast**2) +
-                    tf.math.log(cdf_samples))
-        dl = tf.exp(0.5 * (std_samples**2 - lower_broadcast**2) +
-                    tf.math.log1p(-cdf_samples))
-
-        # Reduce the gradient across the samples
-        grad_u = tf.reduce_sum(dy * du, axis=-1)
-        grad_l = tf.reduce_sum(dy * dl, axis=-1)
-        return [grad_l, grad_u]
-
-      return std_samples, grad
-
-    std_low, std_high = self._standardized_low_and_high(
-        low=low, high=high, loc=loc, scale=scale)
-    low_high_shp = tf.broadcast_dynamic_shape(
-        tf.shape(std_low), tf.shape(std_high))
-    std_low = tf.broadcast_to(std_low, low_high_shp)
-    std_high = tf.broadcast_to(std_high, low_high_shp)
-
-    std_samples = _std_samples_with_gradients(
-        tf.reshape(std_low, [-1]), tf.reshape(std_high, [-1]))
-
-    # The returned shape is [flat_batch x n]
-    std_samples = tf.transpose(std_samples, perm=[1, 0])
-
-    std_samples = tf.reshape(std_samples, sample_and_batch_shape)
-    return std_samples * scale[tf.newaxis] + loc[tf.newaxis]
+    return tf.random.stateless_parameterized_truncated_normal(
+        shape=sample_and_batch_shape,
+        means=loc,
+        stddevs=scale,
+        minvals=low,
+        maxvals=high,
+        seed=seed)
 
   def _log_prob(self, x):
     loc, scale, low, high = self._loc_scale_low_high()
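For context on what the deleted `grad` computed: differentiating the inverse-CDF construction of a standard truncated-normal sample $s$ on $[l, u]$, with $\Phi$ and $\phi$ the standard normal CDF and density, gives the bound derivatives that appear above as `du` and `dl` (with `cdf_samples` playing the role of $F(s)$):

$$
F(s) = \frac{\Phi(s) - \Phi(l)}{\Phi(u) - \Phi(l)}, \qquad
\frac{\partial s}{\partial u} = \frac{\phi(u)}{\phi(s)}\,F(s) = e^{\frac{1}{2}(s^2 - u^2)}\,F(s), \qquad
\frac{\partial s}{\partial l} = \frac{\phi(l)}{\phi(s)}\,\bigl(1 - F(s)\bigr) = e^{\frac{1}{2}(s^2 - l^2)}\,\bigl(1 - F(s)\bigr).
$$

`tf.random.stateless_parameterized_truncated_normal` supplies equivalent derivatives internally, which is what makes the custom gradient removable.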
@@ -31,6 +31,7 @@
     'gamma',
     'stateless_gamma',
     'stateless_normal',
+    'stateless_parameterized_truncated_normal',
     'stateless_poisson',
     'stateless_shuffle',
     'stateless_uniform',
@@ -222,6 +223,28 @@ def _shuffle_jax(value, seed=None, name=None):  # pylint: disable=unused-argument
   return jaxrand.shuffle(seed, value, axis=0)
 
 
+def _truncated_normal(
+    shape, seed, means=0.0, stddevs=1.0, minvals=-2.0, maxvals=2.0, name=None):  # pylint: disable=unused-argument
+  from scipy import stats  # pylint: disable=g-import-not-at-top
+  rng = np.random if seed is None else np.random.RandomState(seed & 0xffffffff)
+  std_low = (minvals - means) / stddevs
+  std_high = (maxvals - means) / stddevs
+  std_samps = stats.truncnorm.rvs(
+      std_low, std_high, size=shape, random_state=rng)
+  return std_samps * stddevs + means
+
+
+def _truncated_normal_jax(
+    shape, seed, means=0.0, stddevs=1.0, minvals=-2.0, maxvals=2.0, name=None):  # pylint: disable=unused-argument
+  import jax.random as jaxrand  # pylint: disable=g-import-not-at-top
+  if seed is None:
+    raise ValueError('Must provide PRNGKey to sample in JAX.')
+  std_low = (minvals - means) / stddevs
+  std_high = (maxvals - means) / stddevs
+  std_samps = jaxrand.truncated_normal(seed, std_low, std_high, shape)
+  return std_samps * stddevs + means
+
+
 def _uniform(shape, minval=0, maxval=None, dtype=np.float32, seed=None,
              name=None):  # pylint: disable=unused-argument
   """Numpy uniform random sampler."""
@@ -305,6 +328,10 @@ def gamma(shape, alpha, beta=None, dtype=np.float32, seed=None, name=None):
     'tf.random.normal',
     _normal_jax if JAX_MODE else _normal)
 
+stateless_parameterized_truncated_normal = utils.copy_docstring(
+    'tf.random.stateless_parameterized_truncated_normal',
+    _truncated_normal_jax if JAX_MODE else _truncated_normal)
+
 stateless_poisson = utils.copy_docstring(
     'tf.random.poisson',
     _poisson_jax if JAX_MODE else _poisson)
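Both backend shims added above rely on the same affine standardization: sample a unit truncated normal on the standardized bounds, then rescale. A standalone check of that identity against scipy (the values here are arbitrary):

```python
import numpy as np
from scipy import stats

means, stddevs, minvals, maxvals = 0.5, 1.2, -1.0, 2.0
rng = np.random.RandomState(42)

# scipy.stats.truncnorm takes its bounds in standard-deviation units,
# hence the (bound - mean) / stddev transformation before sampling.
std_low = (minvals - means) / stddevs
std_high = (maxvals - means) / stddevs
std_samps = stats.truncnorm.rvs(
    std_low, std_high, size=(1000,), random_state=rng)
samples = std_samps * stddevs + means

# All rescaled samples land inside the requested truncation interval.
assert samples.min() >= minvals and samples.max() <= maxvals
```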

