Add a Hypothesis test that log_prob(sample) does not violate sample validation and is never nan.

There are several legitimate exceptions to this invariant:

- Bijectors introduced by TransformedDistribution can introduce nans;
  in at least one case (SoftmaxCentered) this cannot be avoided by
  changing the bijector.  TransformedDistribution is therefore
  excluded.  Testing that Bijectors do not introduce unnecessary nans
  is deferred to a later change.

- The meta-distributions Sample and Independent can combine +-inf
  log densities from the base distribution into nan (see the sketch
  after this list).

- Mixtures (whether same family or not) can violate sample validation
  by showing a sample from mixture component A to mixture component B,
  which may have a different support.

- QuantizedDistribution's log_prob depends on the cdf of the
  underlying distribution.  Testing that cdfs do not introduce
  unnecessary nans is deferred to a later change, whereupon
  QuantizedDistribution can be re-enabled.
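
A minimal NumPy sketch (not part of this commit) of the Sample/Independent
exception above; the per-dimension values are hypothetical stand-ins for a
base distribution's log densities:

```python
import numpy as np

# Hypothetical per-dimension log densities for one draw from a meta-distribution:
# one dimension sits at a pole of the base distribution (+inf), another is far
# enough into the tail that its log density underflows to -inf.
per_dim_log_probs = np.array([np.inf, -np.inf])

# Sample and Independent reduce by summing over the reinterpreted dimensions,
# so the combined log density is inf + (-inf), which is nan.
print(per_dim_log_probs.sum())  # nan
```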

Besides those, this found 19 other violations, of which 8 are fixed by
small changes in this CL, 4 look like small-ish fixes TBD soon, and 7
are filed as bugs in their own right (two of which were fixed while
this CL was in flight).

The fixes:
- Use multiply_no_nan in ProbitBernoulli, same as in Bernoulli
- Mask out edge-case inputs to GammaGamma, GeneralizedExtremeValue,
  LogLogistic, and LogNormal.  Masks register as ok because `log` is
  effectively discontinuous at 0 and +inf.
- Cause VonMisesFisher to emit samples on its support manifold, by
  brute-force normalizing them.
- Cause Pareto to emit samples on its support manifold, with slightly
  more careful numerics in the sampler.
- Adjust the computation of GeneralizedExtremeValue's support to
  numerically agree with the sampler.
- Fix typos in documentation of BetaQuotient.

PiperOrigin-RevId: 355074224
axch authored and tensorflower-gardener committed Feb 2, 2021
1 parent b331bee commit cbbbad0
Showing 15 changed files with 192 additions and 57 deletions.
@@ -109,6 +109,7 @@ def testGermanCreditHMC(self):
num_steps=4000,
num_leapfrog_steps=15,
step_size=0.03,
standard_deviation_fudge_atol=5e-4, # b/179074257
)


15 changes: 11 additions & 4 deletions tensorflow_probability/python/bijectors/gev_cdf.py
@@ -197,11 +197,18 @@ def _maybe_assert_valid_x(self, x, loc=None, scale=None, concentration=None):
concentration = (
tf.convert_to_tensor(self.concentration) if concentration is None else
concentration)
# We intentionally compute the boundary with (1.0 / concentration) * scale
# instead of just scale / concentration.
# Why? The sampler returns loc + (foo / concentration) * scale,
# and at high-ish values of concentration, foo has a decent
# probability of being numerically exactly -1. We therefore mimic
# the pattern of round-off that occurs in the sampler to make sure
# that samples emitted from this distribution will pass its own
# validations. This is sometimes necessary: in TF's float32,
# 0.69314826 / 37.50019 < (1.0 / 37.50019) * 0.69314826
boundary = loc - (1.0 / concentration) * scale
# The support of this bijector depends on the sign of concentration.
is_in_bounds = tf.where(
concentration > 0.,
x >= loc - scale / concentration,
x <= loc - scale / concentration)
is_in_bounds = tf.where(concentration > 0., x >= boundary, x <= boundary)
# For concentration 0, the domain is the whole line.
is_in_bounds = is_in_bounds | tf.math.equal(concentration, 0.)
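
For reference, a standalone NumPy sketch (not part of this diff) of the float32
round-off cited in the comment above; the constants are the ones from that
comment, and NumPy's float32 should follow the same IEEE-754 rounding as TF's:

```python
import numpy as np

scale = np.float32(0.69314826)
conc = np.float32(37.50019)

direct = scale / conc                           # the old boundary computation
sampler_like = (np.float32(1.) / conc) * scale  # mimics the sampler's rounding

# Per the comment above, in float32 the two orderings round differently, with
# `direct` strictly less than `sampler_like`.
print(direct, sampler_like, direct < sampler_like)
```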

@@ -425,7 +425,9 @@ def gev_constraint(loc, scale, conc):
"""Maps `s` to support based on `loc`, `scale` and `conc`."""
def constrain(x):
c = tf.convert_to_tensor(conc)
endpoint = loc - scale / c
# We intentionally compute the endpoint with (1.0 / concentration) * scale,
# for the same reason as in GeneralizedExtremeValueCDF._maybe_assert_valid_x
endpoint = loc - (1.0 / c) * scale
return tf.where(c > 0.,
tf.math.softplus(x) + endpoint,
tf.where(
4 changes: 2 additions & 2 deletions tensorflow_probability/python/distributions/beta_quotient.py
@@ -55,13 +55,13 @@ class BetaQuotient(distribution.Distribution):
```none
X ~ Beta(a0, b0)
Y ~ Beta(a1, b1)
X / Y ~ BetaQuotient(a0, a1, b0, b1)
X / Y ~ BetaQuotient(a0, b0, a1, b1)
```
The distribution is defined over the positive reals, by four parameters
`concentration0_numerator`, `concentration1_numerator`,
`concentration0_denominator` and `concentration1_denominator`
(aka `alpha` and `beta` of the numerator and denominator Beta distribution
(aka `beta` and `alpha` of the numerator and denominator Beta distribution
respectively).
Distribution parameters are automatically broadcast in all functions; see
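
As a sanity check of the corrected parameter order, a small Monte Carlo sketch
(not part of the diff; plain NumPy, so no TFP signature is assumed, and the
concentrations are arbitrary illustrative values):

```python
import numpy as np

rng = np.random.default_rng(0)
a0, b0, a1, b1 = 2.0, 3.0, 4.0, 5.0  # illustrative concentrations

x = rng.beta(a0, b0, size=100_000)   # X ~ Beta(a0, b0)
y = rng.beta(a1, b1, size=100_000)   # Y ~ Beta(a1, b1)
ratio = x / y                        # distributed as BetaQuotient(a0, b0, a1, b1)

# The quotient lives on the positive reals, matching the support described above.
print(ratio.min() > 0.0, ratio.mean())
```
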
@@ -54,11 +54,40 @@
)


NO_NANS_IN_SAMPLE_TEST_BLOCK_LIST = (
'ContinuousBernoulli', # b/169321398
NO_NANS_TEST_BLOCK_LIST = (
'BetaQuotient', # b/178925774
'Dirichlet', # b/169689852
'ExpRelaxedOneHotCategorical', # b/169663302
# Independent log_prob unavoidably emits `nan` if the underlying
# distribution yields a +inf on one sample and a -inf on another.
'Independent',
'InverseGaussian', # TODO(axch): Fix numerics of 1 - sqrt(x + 1)
'LogitNormal', # TODO(axch): Maybe nan problem hints at accuracy problem
# Mixtures of component distributions whose samples have different dtypes
# cannot pass validate_args.
'Mixture',
'MixtureSameFamily', # b/169790025
'OneHotCategorical', # b/169680869
# TODO(axch) Re-enable after CDFs of underlying distributions are tested
# for NaN production.
'QuantizedDistribution',
# Sample log_prob unavoidably emits `nan` if the underlying distribution
# yields a +inf on one sample and a -inf on another. This can happen even
# with iid samples, if one sample is at a pole of the distribution, and the
# other is far enough into the tail to round to +inf. Weibull with a low
# concentration is an example of a distribution that can produce this
# effect.
'Sample',
'TransformedDistribution', # Bijectors may introduce nans
# TODO(axch): Fix numerics of _cauchy_cdf(x + delta) - _cauchy_cdf(x)
'TruncatedCauchy',
# TODO(axch): Edit C++ sampler to reject numerically out-of-bounds samples
'TruncatedNormal',
)

NANS_EVEN_IN_SAMPLE_LIST = (
'Mixture', # b/169847344. Not a nan, but can't always sample from Mixture
'TransformedDistribution', # Bijectors may introduce nans
'Zipf', # b/175929563 triggered by Zipf inside Independent
)

# Batch slicing requires implementing `_params_event_ndims`. Generic
@@ -160,7 +189,7 @@ def testDistribution(self, dist_name, data):


@test_util.test_all_tf_execution_regimes
class NoNansInSampleTest(test_util.TestCase):
class NoNansTest(test_util.TestCase, dhps.TestCase):

@parameterized.named_parameters(
{'testcase_name': dname, 'dist_name': dname}
@@ -169,45 +198,53 @@ class NoNansInSampleTest(test_util.TestCase):
@hp.given(hps.data())
@tfp_hps.tfp_hp_settings()
def testDistribution(self, dist_name, data):
if dist_name in NO_NANS_IN_SAMPLE_TEST_BLOCK_LIST:
if dist_name in NO_NANS_TEST_BLOCK_LIST:
self.skipTest('{} is blocked'.format(dist_name))
def eligibility_filter(name):
return name not in NO_NANS_IN_SAMPLE_TEST_BLOCK_LIST
return name not in NO_NANS_TEST_BLOCK_LIST
dist = data.draw(dhps.distributions(dist_name=dist_name, enable_vars=False,
eligibility_filter=eligibility_filter))
samples = self.check_samples_not_nan(dist)
self.assume_loc_scale_ok(dist)

hp.note('Testing on samples {}'.format(samples))
with tfp_hps.no_tf_rank_errors():
lp = self.evaluate(dist.log_prob(samples))
self.assertAllEqual(np.zeros_like(lp), np.isnan(lp))

@parameterized.named_parameters(
{'testcase_name': dname, 'dist_name': dname}
for dname in sorted(NO_NANS_TEST_BLOCK_LIST)
if dname not in NANS_EVEN_IN_SAMPLE_LIST)
@hp.given(hps.data())
@tfp_hps.tfp_hp_settings()
def testSampleOnly(self, dist_name, data):
def eligibility_filter(name):
# We use this eligibility filter to focus the test's attention
# on sub-distributions that are not already tested by the sample
# and log_prob test. However, we also include some relatively
# widely used distributions to make sure that at least one legal
# sub-distribution exists for every meta-distribution we may be
# testing.
return ((name in NO_NANS_TEST_BLOCK_LIST
or name in dhps.QUANTIZED_BASE_DISTS)
and name not in NANS_EVEN_IN_SAMPLE_LIST)
dist = data.draw(dhps.distributions(dist_name=dist_name, enable_vars=False,
eligibility_filter=eligibility_filter))
self.check_samples_not_nan(dist)

def check_samples_not_nan(self, dist):
hp.note('Trying distribution {}'.format(
self.evaluate_dict(dist.parameters)))
seed = test_util.test_seed(sampler_type='stateless')
with tfp_hps.no_tf_rank_errors():
s1 = self.evaluate(dist.sample(20, seed=seed))
self.assertAllEqual(np.zeros_like(s1), np.isnan(s1))
samples = self.evaluate(dist.sample(20, seed=seed))
self.assertAllEqual(np.zeros_like(samples), np.isnan(samples))
return samples


@test_util.test_all_tf_execution_regimes
class EventSpaceBijectorsTest(test_util.TestCase):

def check_bad_loc_scale(self, dist):
if hasattr(dist, 'distribution'):
# BatchReshape, Independent, TransformedDistribution, and
# QuantizedDistribution
self.check_bad_loc_scale(dist.distribution)
if isinstance(dist, tfd.MixtureSameFamily):
self.check_bad_loc_scale(dist.mixture_distribution)
self.check_bad_loc_scale(dist.components_distribution)
if isinstance(dist, tfd.Mixture):
self.check_bad_loc_scale(dist.cat)
self.check_bad_loc_scale(dist.components)
if hasattr(dist, 'loc') and hasattr(dist, 'scale'):
try:
loc_ = tf.convert_to_tensor(dist.loc)
scale_ = tf.convert_to_tensor(dist.scale)
except (ValueError, TypeError):
# If they're not Tensor-convertible, don't try to check them. This is
# the case, in, for example, multivariate normal, where the scale is a
# `LinearOperator`.
return
loc, scale = self.evaluate([loc_, scale_])
hp.assume(np.all(np.abs(loc / scale) < 1e7))
class EventSpaceBijectorsTest(test_util.TestCase, dhps.TestCase):

def check_event_space_bijector_constrains(self, dist, data):
event_space_bijector = dist.experimental_default_event_space_bijector()
@@ -245,7 +282,7 @@ def testDistributionWithVars(self, dist_name, data):
dist = data.draw(dhps.base_distributions(
dist_name=dist_name, enable_vars=True))
self.evaluate([var.initializer for var in dist.variables])
self.check_bad_loc_scale(dist)
self.assume_loc_scale_ok(dist)
self.check_event_space_bijector_constrains(dist, data)

# TODO(b/146572907): Fix `enable_vars` for metadistributions and
Expand All @@ -263,7 +300,7 @@ def ok(name):
dist_name=dist_name, enable_vars=True,
eligibility_filter=ok))
self.evaluate([var.initializer for var in dist.variables])
self.check_bad_loc_scale(dist)
self.assume_loc_scale_ok(dist)
self.check_event_space_bijector_constrains(dist, data)


7 changes: 7 additions & 0 deletions tensorflow_probability/python/distributions/gamma_gamma.py
@@ -224,6 +224,13 @@ def _log_prob(self, x):
log_unnormalized_prob = (tf.math.xlogy(concentration - 1., x) -
(concentration + mixing_concentration) *
tf.math.log(x + mixing_rate))
# The formula computes `nan` for `x == +inf`. However, it shouldn't be too
# inaccurate for large finite `x`, because `x` only appears as `log(x)`, and
# `log` is effectively discontinuous at `+inf`.
log_unnormalized_prob = tf.where(
x >= np.inf,
tf.constant(-np.inf, dtype=log_unnormalized_prob.dtype),
log_unnormalized_prob)

return log_unnormalized_prob - log_normalization

5 changes: 4 additions & 1 deletion tensorflow_probability/python/distributions/gev.py
@@ -222,7 +222,10 @@ def _log_prob(self, x):
log_t = tf.where(equal_zero, -z,
-tf.math.log1p(z * safe_conc) / safe_conc)

return (conc + 1) * log_t - tf.exp(log_t) - tf.math.log(scale)
result = (conc + 1) * log_t - tf.exp(log_t) - tf.math.log(scale)
return tf.where(z * safe_conc <= -1.0,
tf.constant(-np.inf, dtype=result.dtype),
result)

def _mean(self):
conc = tf.convert_to_tensor(self.concentration)
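
A quick NumPy check (not from the diff) that the new mask condition coincides
with the support boundary; loc, scale, and concentration are illustrative
values:

```python
import numpy as np

# For concentration > 0 the support is x >= loc - scale / concentration.
loc, scale, conc = 0.0, 1.0, 0.5
x = -3.0                      # below the lower endpoint, which is -2.0 here
z = (x - loc) / scale
print(z * conc <= -1.0)       # True, so the new branch returns log_prob = -inf
```
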
64 changes: 58 additions & 6 deletions tensorflow_probability/python/distributions/hypothesis_testlib.py
@@ -181,8 +181,8 @@
# TODO(b/128518790): Eliminate / minimize the fudge factors in here.


def constrain_between_eps_and_one_minus_eps(eps=1e-6):
return lambda x: eps + (1 - 2 * eps) * tf.sigmoid(x)
def constrain_between_eps_and_one_minus_eps(eps0=1e-6, eps1=1e-6):
return lambda x: eps0 + (1 - (eps0 + eps1)) * tf.sigmoid(x)


def ensure_high_gt_low(low, high):
Expand Down Expand Up @@ -300,7 +300,9 @@ def fix_bates(d):
'RelaxedCategorical.probs':
tf.math.softmax,
'Zipf.power':
tfp_hps.softplus_plus_eps(1 + 1e-6), # strictly > 1
# Strictly > 1. See also b/175929563 (rejection sampler
# iterates too much and emits `nan` for powers too close to 1).
tfp_hps.softplus_plus_eps(1 + 1e-4),
'ContinuousBernoulli.probs':
tf.sigmoid,
'Geometric.logits': # TODO(b/128410109): re-enable down to -50
Expand All @@ -311,8 +313,12 @@ def fix_bates(d):
constrain_between_eps_and_one_minus_eps(),
'Binomial.probs':
tf.sigmoid,
# Constrain probs away from 1 to avoid immense samples.
# See b/178842153.
'NegativeBinomial.logits':
lambda x: tf.minimum(x, 15.),
'NegativeBinomial.probs':
tf.sigmoid,
constrain_between_eps_and_one_minus_eps(eps0=0., eps1=1e-6),
'Bernoulli.probs':
tf.sigmoid,
'PlackettLuce.scores':
@@ -324,8 +330,10 @@ def fix_bates(d):
'cutpoints':
# Permit values that aren't too large
lambda x: tfb.Ascending().forward(10 * tf.math.tanh(x)),
# Capping log_rate because of weird semantics of Poisson with very
# large rates (see b/178842153).
'log_rate':
lambda x: tf.maximum(x, -16.),
lambda x: tf.minimum(tf.maximum(x, -16.), 15.),
# Capping log_rate1 and log_rate2 to 15. This is because if both are large
# (meaning the rates are `inf`), then the Skellam distribution is undefined.
'log_rate1':
@@ -770,7 +778,7 @@ def base_distributions(draw,
initialization in slicing_test. If `False`, the returned parameters are
all `tf.Tensor`s and not {`tf.Variable`, `tfp.util.DeferredTensor`
`tfp.util.TransformedVariable`}.
eligibility_filter: Optional Python callable. Blacklists some Distribution
eligibility_filter: Optional Python callable. Blocks some Distribution
class names so they will not be drawn at the top level.
params: An optional set of Distribution parameters. If params are not
provided, Hypothesis will choose a set of parameters.
@@ -1492,3 +1500,47 @@ class names so they will not be drawn.
batch_shape, event_dim, enable_vars,
eligibility_filter, validate_args))
raise ValueError('Unknown Distribution name {}'.format(dist_name))


class TestCase(object):
"""Mixin for TestCase-type classes with Hypothesis-specific utilities."""

def assume_loc_scale_ok(self, dist):
"""Hypothesis assumption that `dist` has reasonable a location and scale.
To wit, `hp.assume` that location / scale < 1e7. Why this check? Because
by assumption we tend to think of samples as being near the location, with
nearness determined by the scale. If the scale is close to machine epsilon
near the location, then the distribution is close to being numerically
degenerate, and therefore not a good test case.
This assumption is currently only checked for (batch) scalar locations and
scales, and only if the distribution is parameterized with parameters named
`loc` and `scale`.
The assumption is applied recursively through meta distributions.
Args:
dist: TFP `Distribution` to check.
"""
if hasattr(dist, 'distribution'):
# BatchReshape, Independent, TransformedDistribution, and
# QuantizedDistribution
self.assume_loc_scale_ok(dist.distribution)
if isinstance(dist, tfd.MixtureSameFamily):
self.assume_loc_scale_ok(dist.mixture_distribution)
self.assume_loc_scale_ok(dist.components_distribution)
if isinstance(dist, tfd.Mixture):
self.assume_loc_scale_ok(dist.cat)
self.assume_loc_scale_ok(dist.components)
if hasattr(dist, 'loc') and hasattr(dist, 'scale'):
try:
loc_ = tf.convert_to_tensor(dist.loc)
scale_ = tf.convert_to_tensor(dist.scale)
except (ValueError, TypeError):
# If they're not Tensor-convertible, don't try to check them. This is
# the case, in, for example, multivariate normal, where the scale is a
# `LinearOperator`.
return
loc, scale = self.evaluate([loc_, scale_])
hp.assume(np.all(np.abs(loc / scale) < 1e7))
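
A minimal sketch (not part of the commit) of how a property test can reuse the
new mixin, assuming the same module aliases the diff already uses (`dhps`,
`test_util`, `hp`, `hps`, `tfp_hps`, `np`):

```python
class MyPropertyTest(test_util.TestCase, dhps.TestCase):

  @hp.given(hps.data())
  @tfp_hps.tfp_hp_settings()
  def testLogProbNotNan(self, data):
    dist = data.draw(dhps.distributions(dist_name='Normal', enable_vars=False))
    # Discard draws whose loc/scale make the distribution numerically degenerate.
    self.assume_loc_scale_ok(dist)
    samples = dist.sample(20, seed=test_util.test_seed())
    lp = self.evaluate(dist.log_prob(samples))
    self.assertAllEqual(np.zeros_like(lp), np.isnan(lp))
```
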
2 changes: 1 addition & 1 deletion tensorflow_probability/python/distributions/logitnormal.py
@@ -48,7 +48,7 @@ def __init__(self,
name='LogitNormal'):
"""Construct a logit-normal distribution.
The LogititNormal distribution models positive-valued random variables whose
The LogitNormal distribution models random variables between 0 and 1 whose
logit (i.e., sigmoid_inverse, i.e., `log(p) - log1p(-p)`) is normally
distributed with mean `loc` and standard deviation `scale`. It is
constructed as the sigmoid transformation, (i.e., `1 / (1 + exp(-x))`) of a
10 changes: 8 additions & 2 deletions tensorflow_probability/python/distributions/loglogistic.py
@@ -137,8 +137,14 @@ def _log_prob(self, x):
scale = tf.convert_to_tensor(self.scale)
loc = tf.convert_to_tensor(self.loc)
log_z = self._log_z(x, loc=loc, scale=scale)
return (-tf.math.log(scale) - loc +
(1. - scale) * log_z - 2 * tf.math.softplus(log_z))
answer = (-tf.math.log(scale) - loc +
(1. - scale) * log_z - 2 * tf.math.softplus(log_z))
# The formula computes `nan` for `x == +inf`. However, it shouldn't be too
# inaccurate for large finite `x`, because `x` only appears as `log(x)`, and
# `log` is effectively discontinuous at `+inf`.
return tf.where(x >= np.inf,
tf.constant(-np.inf, dtype=answer.dtype),
answer)

def _log_cdf(self, x):
return -tf.math.softplus(-self._log_z(x))
12 changes: 12 additions & 0 deletions tensorflow_probability/python/distributions/lognormal.py
@@ -93,6 +93,18 @@ def scale(self):
"""Distribution parameter for the pre-transformed standard deviation."""
return self.distribution.scale

def _log_prob(self, x):
answer = super(LogNormal, self)._log_prob(x)
# The formula inherited from TransformedDistribution computes `nan` for `x
# == 0`. However, there's hope that it's not too inaccurate for small
# finite `x`, because `x` only appears as `log(x)`, and `log` is effectively
# discontinuous at 0. Furthermore, the result should be dominated by the
# `log(x)**2` term, with no higher-order term that needs to be cancelled
# numerically.
return tf.where(tf.equal(x, 0.0),
tf.constant(-np.inf, dtype=answer.dtype),
answer)

def _mean(self):
return tf.exp(self.distribution.mean() + 0.5 * self.distribution.variance())
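
A rough NumPy illustration (not part of the diff) of why the inherited formula
hits `nan` at exactly 0: TransformedDistribution evaluates the base Normal at
`log(x)` and adds the inverse log det Jacobian `-log(x)`, and at `x == 0` those
two terms are -inf and +inf:

```python
import numpy as np

x = 0.0
y = np.log(x)                                           # -inf (divide-by-zero warning)
base_log_prob = -0.5 * y**2 - 0.5 * np.log(2 * np.pi)   # -inf for a standard Normal base
log_det_jacobian = -np.log(x)                           # +inf
print(base_log_prob + log_det_jacobian)                 # nan, hence the tf.where mask above
```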
