Make Poisson Lognormal distribution TF2 tape-safe.
PiperOrigin-RevId: 260014556
sharadmv authored and tensorflower-gardener committed Jul 25, 2019
1 parent 7a77df7 commit 1fdd1ff
Showing 4 changed files with 126 additions and 63 deletions.
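
The gist of the change: the constructor previously materialized the quadrature grid and the component Poisson/Categorical distributions eagerly, so a `tf.GradientTape` active at call time never saw the variable reads and gradients to `tf.Variable` parameters came back `None`. The diff below defers those conversions to call time. A minimal sketch of that pattern, assuming stock TF2 eager mode (the class and names here are illustrative, not part of this commit):

import tensorflow as tf

class TapeSafeScaled(object):
  """Toy class that defers tensor conversion until call time."""

  def __init__(self, scale):
    self._scale = scale  # May be a tf.Variable; do not convert here.

  def transform(self, x):
    # Converting under the tape lets gradients flow back to the variable.
    scale = tf.convert_to_tensor(self._scale)
    return scale * x

scale = tf.Variable(2.)
obj = TapeSafeScaled(scale)
with tf.GradientTape() as tape:
  loss = obj.transform(3.)
print(tape.gradient(loss, scale))  # 3.0, rather than None.
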
2 changes: 2 additions & 0 deletions tensorflow_probability/python/distributions/BUILD
@@ -959,7 +959,9 @@ py_library(
"//tensorflow_probability/python/bijectors:exp",
"//tensorflow_probability/python/internal:distribution_util",
"//tensorflow_probability/python/internal:dtype_util",
"//tensorflow_probability/python/internal:prefer_static",
"//tensorflow_probability/python/internal:reparameterization",
"//tensorflow_probability/python/internal:tensor_util",
"//tensorflow_probability/python/internal:tensorshape_util",
],
)
@@ -80,6 +80,8 @@
'OneHotCategorical',
'Pareto',
'Poisson',
# 'PoissonLogNormalQuadratureCompound' TODO(b/137956955): Add support
# for hypothesis testing
'ProbitBernoulli',
'StudentT',
'Triangular',
157 changes: 95 additions & 62 deletions tensorflow_probability/python/distributions/poisson_lognormal.py
@@ -31,14 +31,17 @@
from tensorflow_probability.python.distributions import transformed_distribution
from tensorflow_probability.python.internal import distribution_util
from tensorflow_probability.python.internal import dtype_util
from tensorflow_probability.python.internal import prefer_static
from tensorflow_probability.python.internal import reparameterization
from tensorflow_probability.python.internal import tensor_util
from tensorflow_probability.python.internal import tensorshape_util
from tensorflow.python.util import deprecation # pylint: disable=g-direct-tensorflow-import


__all__ = [
"PoissonLogNormalQuadratureCompound",
"quadrature_scheme_lognormal_gauss_hermite",
"quadrature_scheme_lognormal_quantiles",
'PoissonLogNormalQuadratureCompound',
'quadrature_scheme_lognormal_gauss_hermite',
'quadrature_scheme_lognormal_quantiles',
]


@@ -70,13 +73,13 @@ def quadrature_scheme_lognormal_gauss_hermite(
weight associated with each `grid` value.
"""
with tf.name_scope(
name or "vector_diffeomixture_quadrature_gauss_hermite"):
name or 'vector_diffeomixture_quadrature_gauss_hermite'):
grid, probs = np.polynomial.hermite.hermgauss(deg=quadrature_size)
npdt = dtype_util.as_numpy_dtype(loc.dtype)
grid = grid.astype(npdt)
probs = probs.astype(npdt)
probs /= np.linalg.norm(probs, ord=1, keepdims=True)
probs = tf.convert_to_tensor(probs, name="probs", dtype=loc.dtype)
probs = tf.convert_to_tensor(probs, name='probs', dtype=loc.dtype)
# The following maps the broadcast of `loc` and `scale` to each grid
# point, i.e., we are creating several log-rates that correspond to the
# different Gauss-Hermite quadrature points and (possible) batches of
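
For reference, a self-contained sketch of what the Gauss-Hermite scheme computes, assuming scalar `loc` and `scale` and NumPy only (the helper name is illustrative): the standard nodes and weights for the weight function exp(-x**2) are mapped onto log-rates via loc + sqrt(2) * scale * x, and the weights are normalized to sum to one.

import numpy as np

def gauss_hermite_lognormal_grid(loc, scale, quadrature_size):
  # Standard Gauss-Hermite nodes/weights for weight function exp(-x**2).
  x, w = np.polynomial.hermite.hermgauss(quadrature_size)
  probs = w / np.linalg.norm(w, ord=1)    # Normalize weights to sum to 1.
  grid = loc + np.sqrt(2.) * scale * x    # Log-rates at the quadrature nodes.
  return grid, probs

# Sanity check: E[exp(Z)] for Z ~ Normal(0, 1) is exp(0.5) ~= 1.6487.
grid, probs = gauss_hermite_lognormal_grid(loc=0., scale=1., quadrature_size=10)
print(np.sum(probs * np.exp(grid)))
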
@@ -109,7 +112,7 @@ def quadrature_scheme_lognormal_quantiles(
probs: (Batch of) length-`quadrature_size` vectors representing the
weight associated with each `grid` value.
"""
with tf.name_scope(name or "quadrature_scheme_lognormal_quantiles"):
with tf.name_scope(name or 'quadrature_scheme_lognormal_quantiles'):
# Create a LogNormal distribution.
dist = transformed_distribution.TransformedDistribution(
distribution=normal.Normal(loc=loc, scale=scale),
@@ -222,7 +225,7 @@ def __init__(self,
quadrature_fn=quadrature_scheme_lognormal_quantiles,
validate_args=False,
allow_nan_stats=True,
name="PoissonLogNormalQuadratureCompound"):
name='PoissonLogNormalQuadratureCompound'):
"""Constructs the PoissonLogNormalQuadratureCompound`.
Note: `probs` returned by (optional) `quadrature_fn` are presumed to be
@@ -240,14 +243,13 @@ def __init__(self,
quadrature_fn: Python callable taking `loc`, `scale`,
`quadrature_size`, `validate_args` and returning `tuple(grid, probs)`
representing the LogNormal grid and corresponding normalized weight.
normalized) weight.
Default value: `quadrature_scheme_lognormal_quantiles`.
validate_args: Python `bool`, default `False`. When `True` distribution
parameters are checked for validity despite possibly degrading runtime
performance. When `False` invalid inputs may silently render incorrect
outputs.
allow_nan_stats: Python `bool`, default `True`. When `True`,
statistics (e.g., mean, mode, variance) use the value "`NaN`" to
statistics (e.g., mean, mode, variance) use the value '`NaN`' to
indicate the result is undefined. When `False`, an exception is raised
if one or more of the statistic's batch members are undefined.
name: Python `str` name prefixed to Ops created by this class.
@@ -259,52 +261,68 @@
parameters = dict(locals())
with tf.name_scope(name) as name:
dtype = dtype_util.common_dtype([loc, scale], tf.float32)
if loc is not None:
loc = tf.convert_to_tensor(loc, name="loc", dtype=dtype)
if scale is not None:
scale = tf.convert_to_tensor(scale, dtype=dtype, name="scale")
self._quadrature_grid, self._quadrature_probs = tuple(quadrature_fn(
loc, scale, quadrature_size, validate_args))

dt = self._quadrature_grid.dtype
if not dtype_util.base_equal(dt, self._quadrature_probs.dtype):
raise TypeError("Quadrature grid dtype ({}) does not match quadrature "
"probs dtype ({}).".format(
dtype_util.name(dt),
dtype_util.name(self._quadrature_probs.dtype)))

self._distribution = poisson.Poisson(
log_rate=self._quadrature_grid,
validate_args=validate_args,
allow_nan_stats=allow_nan_stats)

self._mixture_distribution = categorical.Categorical(
logits=tf.math.log(self._quadrature_probs),
validate_args=validate_args,
allow_nan_stats=allow_nan_stats)
self._loc = tensor_util.convert_nonref_to_tensor(
loc, name='loc', dtype=dtype)
self._scale = tensor_util.convert_nonref_to_tensor(
scale, name='scale', dtype=dtype)
self._quadrature_fn = quadrature_fn
dtype_util.assert_same_float_dtype([self._loc, self._scale])

self._loc = loc
self._scale = scale
self._quadrature_size = quadrature_size

super(PoissonLogNormalQuadratureCompound, self).__init__(
dtype=dt,
dtype=dtype,
reparameterization_type=reparameterization.NOT_REPARAMETERIZED,
validate_args=validate_args,
allow_nan_stats=allow_nan_stats,
parameters=parameters,
graph_parents=[loc, scale],
name=name)

def poisson_and_mixture_distributions(self):
"""Returns the Poisson and Mixture distribution parameterized by the quadrature grid and weights."""
loc = tf.convert_to_tensor(self.loc)
scale = tf.convert_to_tensor(self.scale)
quadrature_grid, quadrature_probs = tuple(self._quadrature_fn(
loc, scale, self.quadrature_size, self.validate_args))
dt = quadrature_grid.dtype
if not dtype_util.base_equal(dt, quadrature_probs.dtype):
raise TypeError('Quadrature grid dtype ({}) does not match quadrature '
'probs dtype ({}).'.format(
dtype_util.name(dt),
dtype_util.name(quadrature_probs.dtype)))

dist = poisson.Poisson(
log_rate=quadrature_grid,
validate_args=self.validate_args,
allow_nan_stats=self.allow_nan_stats)

mixture_dist = categorical.Categorical(
logits=tf.math.log(quadrature_probs),
validate_args=self.validate_args,
allow_nan_stats=self.allow_nan_stats)
return dist, mixture_dist
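
This new method is the heart of the tape-safety fix: instead of caching `Poisson` and `Categorical` objects built in `__init__`, the components are rebuilt from the current values of `loc` and `scale` on every call, so an active `tf.GradientTape` traces the reads. A usage sketch mirroring the new test further down (variable values illustrative):

import tensorflow as tf
import tensorflow_probability as tfp
tfd = tfp.distributions

loc = tf.Variable([0., -0.5])
scale = tf.Variable([1., 0.9])
pln = tfd.PoissonLogNormalQuadratureCompound(
    loc=loc, scale=scale, quadrature_size=10, validate_args=True)
with tf.GradientTape() as tape:
  loss = -tf.reduce_sum(pln.log_prob([1., 2.]))
grads = tape.gradient(loss, [loc, scale])  # Neither gradient is None.
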

@property
@deprecation.deprecated(
'2019-11-01',
('The `mixture_distribution` property will be removed. '
'Use `poisson_and_mixture_distributions` instead.'),
warn_once=True)
def mixture_distribution(self):
"""Distribution which randomly selects a Poisson with quadrature param."""
return self._mixture_distribution
_, mixture_dist = self.poisson_and_mixture_distributions()
return mixture_dist

@property
@deprecation.deprecated(
'2019-11-01',
('The `distribution` property will be removed. '
'Use `poisson_and_mixture_distributions` instead.'),
warn_once=True)
def distribution(self):
"""Base Poisson parameterized by a quadrature grid."""
return self._distribution
dist, _ = self.poisson_and_mixture_distributions()
return dist

@property
def loc(self):
@@ -320,40 +338,47 @@ def scale(self):
def quadrature_size(self):
return self._quadrature_size

def _batch_shape_tensor(self):
def _batch_shape_tensor(self, distributions=None):
if distributions is None:
distributions = self.poisson_and_mixture_distributions()
dist, mixture_dist = distributions
return tf.broadcast_dynamic_shape(
self.distribution.batch_shape_tensor(),
tf.shape(self.mixture_distribution.logits))[:-1]
dist.batch_shape_tensor(),
prefer_static.shape(mixture_dist.logits))[:-1]

def _batch_shape(self):
dist, mixture_dist = self.poisson_and_mixture_distributions()
return tf.broadcast_static_shape(
self.distribution.batch_shape,
self.mixture_distribution.logits.shape)[:-1]
dist.batch_shape,
mixture_dist.logits.shape)[:-1]

def _event_shape(self):
return tf.TensorShape([])
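
Shape bookkeeping for the methods above: the quadrature grid occupies the trailing axis of both components, so the compound's batch shape is the broadcast of the two component batch shapes with that axis dropped. A quick check, assuming `loc` and `scale` of shape [3]:

import tensorflow as tf
import tensorflow_probability as tfp
tfd = tfp.distributions

pln = tfd.PoissonLogNormalQuadratureCompound(
    loc=tf.zeros([3]), scale=tf.ones([3]), quadrature_size=7)
print(pln.batch_shape)  # [3]: the quadrature axis of size 7 is dropped.
print(pln.event_shape)  # []: scalar events, like the underlying Poisson.
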

def _sample_n(self, n, seed=None):
# Get ids as a [n, batch_size]-shaped matrix, unless batch_shape=[], in
# which case get ids as a [n]-shaped vector.
distributions = self.poisson_and_mixture_distributions()
dist, mixture_dist = distributions
batch_size = tensorshape_util.num_elements(self.batch_shape)
if batch_size is None:
batch_size = tf.reduce_prod(self.batch_shape_tensor())
# We need to "sample extra" from the mixture distribution if it doesn't
batch_size = tf.reduce_prod(
self._batch_shape_tensor(distributions=distributions))
# We need to 'sample extra' from the mixture distribution if it doesn't
# already specify a probs vector for each batch coordinate.
# We only support this kind of reduced broadcasting, i.e., there is exactly
# one probs vector for all batch dims or one for each.
stream = seed_stream.SeedStream(
seed, salt="PoissonLogNormalQuadratureCompound")
ids = self._mixture_distribution.sample(
seed, salt='PoissonLogNormalQuadratureCompound')
ids = mixture_dist.sample(
sample_shape=concat_vectors(
[n],
distribution_util.pick_vector(
self.mixture_distribution.is_scalar_batch(),
mixture_dist.is_scalar_batch(),
[batch_size],
np.int32([]))),
seed=stream())
# We need to flatten batch dims in case mixture_distribution has its own
# We need to flatten batch dims in case mixture_dist has its own
# batch dims.
ids = tf.reshape(
ids,
@@ -369,20 +394,25 @@ def _sample_n(self, n, seed=None):
delta=self._quadrature_size,
dtype=ids.dtype)
ids += offset
rate = tf.gather(tf.reshape(self.distribution.rate, shape=[-1]), ids)
rate = tf.gather(tf.reshape(dist.rate, shape=[-1]), ids)
rate = tf.reshape(
rate, shape=concat_vectors([n], self.batch_shape_tensor()))
rate, shape=concat_vectors([n], self._batch_shape_tensor(
distributions=distributions)))
return tf.random.poisson(lam=rate, shape=[], dtype=self.dtype, seed=seed)
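
The offset-and-gather dance above flattens the per-batch rate grids so a single `tf.gather` can pick each sample's component rate. A didactic NumPy restatement for a scalar batch (helper name illustrative, not the TFP code path):

import numpy as np

def sample_compound(n, grid, probs, rng=np.random.default_rng()):
  ids = rng.choice(len(probs), size=n, p=probs)  # Component index per draw.
  rates = np.exp(grid)[ids]                      # Rate at each chosen node.
  return rng.poisson(lam=rates)                  # Poisson count per sample.
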

def _log_prob(self, x):
return tf.reduce_logsumexp((self.mixture_distribution.logits +
self.distribution.log_prob(x[..., tf.newaxis])),
dist, mixture_dist = self.poisson_and_mixture_distributions()
return tf.reduce_logsumexp((mixture_dist.logits +
dist.log_prob(x[..., tf.newaxis])),
axis=-1)
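
The `log_prob` above is the standard mixture identity, log p(x) = logsumexp_d(log probs[d] + log Poisson(x; rate[d])), reduced over the trailing quadrature axis. The same quantity in plain NumPy/SciPy, assuming `grid` holds log-rates (helper name illustrative):

import numpy as np
from scipy.special import gammaln, logsumexp

def mixture_log_prob(x, grid, probs):
  # log Poisson(x; rate=exp(g)) = x*g - exp(g) - log(x!)
  comp = x * grid - np.exp(grid) - gammaln(x + 1.)
  return logsumexp(np.log(probs) + comp)
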

def _mean(self):
def _mean(self, distributions=None):
if distributions is None:
distributions = self.poisson_and_mixture_distributions()
dist, mixture_dist = distributions
return tf.exp(
tf.reduce_logsumexp(
self.mixture_distribution.logits + self.distribution.log_rate,
mixture_dist.logits + dist.log_rate,
axis=-1))

def _variance(self):
@@ -398,25 +428,28 @@ def _log_variance(self):
#
# where,
#
# Z|v ~ interpolate_affine[v](distribution)
# V ~ mixture_distribution
# Z|v ~ interpolate_affine[v](dist)
# V ~ mixture_dist
#
# thus,
#
# E[Var[Z | V]] = sum{ prob[d] Var[d] : d=0, ..., deg-1 }
# Var[E[Z | V]] = sum{ prob[d] (Mean[d] - Mean)**2 : d=0, ..., deg-1 }
distributions = self.poisson_and_mixture_distributions()
dist, mixture_dist = distributions
v = tf.stack(
[
# log(self.distribution.variance()) = log(Var[d]) = log(rate[d])
self.distribution.log_rate,
# log(dist.variance()) = log(Var[d]) = log(rate[d])
dist.log_rate,
# log((Mean[d] - Mean)**2)
2. * tf.math.log(
tf.abs(self.distribution.mean() -
self._mean()[..., tf.newaxis])),
tf.abs(
dist.mean() -
self._mean(distributions=distributions)[..., tf.newaxis])),
],
axis=-1)
return tf.reduce_logsumexp(
self.mixture_distribution.logits[..., tf.newaxis] + v, axis=[-2, -1])
mixture_dist.logits[..., tf.newaxis] + v, axis=[-2, -1])
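
The stacked terms implement the law of total variance in log space: Var[Z] = E[Var[Z|V]] + Var[E[Z|V]] = sum_d probs[d]*rate[d] + sum_d probs[d]*(rate[d] - mean)**2, since a Poisson's variance equals its rate. A NumPy restatement (helper name illustrative):

import numpy as np

def mixture_variance(grid, probs):
  rates = np.exp(grid)
  mean = np.sum(probs * rates)                # E[Z] = sum_d probs[d]*rate[d].
  e_var = np.sum(probs * rates)               # E[Var[Z|V]]: Poisson var = rate.
  var_e = np.sum(probs * (rates - mean)**2)   # Var[E[Z|V]].
  return e_var + var_e
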


def concat_vectors(*args):
@@ -96,6 +96,32 @@ def testMeanVarianceBroadcastBoth(self):
self.run_test_sample_consistent_mean_variance(
self.evaluate, pln, rtol=0.1, atol=0.01)

def testGradientThroughParams(self):
pln = tfd.PoissonLogNormalQuadratureCompound(
loc=tf.Variable([0., -0.5], shape=[2] if self.static_shape
else None),
scale=tf.Variable([1., 0.9], shape=[2] if self.static_shape
else None),
quadrature_size=10, validate_args=True)
with tf.GradientTape() as tape:
loss = -pln.log_prob([1., 2.])
grad = tape.gradient(loss, pln.trainable_variables)
self.assertLen(grad, 2)
self.assertNotIn(None, grad)

def testGradientThroughNonVariableParams(self):
pln = tfd.PoissonLogNormalQuadratureCompound(
loc=tf.convert_to_tensor([0., -0.5]),
scale=tf.convert_to_tensor([1., 0.9]),
quadrature_size=10, validate_args=True)
with tf.GradientTape() as tape:
tape.watch(pln.loc)
tape.watch(pln.scale)
loss = -pln.log_prob([1., 2.])
grad = tape.gradient(loss, [pln.loc, pln.scale])
self.assertLen(grad, 2)
self.assertNotIn(None, grad)


@test_util.run_all_in_graph_and_eager_modes
class PoissonLogNormalQuadratureCompoundStaticShapeTest(
@@ -115,5 +141,5 @@ def static_shape(self):
return False


if __name__ == "__main__":
if __name__ == '__main__':
tf.test.main()
