Skip to content

Commit

Permalink
Rename interpolate_nondiscrete flag to `force_probs_to_zero_outside_support` (with the opposite sense) in Zipf.
Browse files Browse the repository at this point in the history

The new name leaves room for other distributions to use the same
flag to control tf.where gates that guard against extrapolating outside
the support, e.g., to negative arguments.

This change does not make the pre-existing tf.where gate in Zipf.cdf
controlled by the new flag, because that would change the behavior of
Zipf when no flags are set.  That can be done in a follow-up change.

PiperOrigin-RevId: 340878676
  • Loading branch information
axch authored and tensorflower-gardener committed Nov 5, 2020
1 parent 0e3eaa3 commit b667ee9
Show file tree
Hide file tree
Showing 2 changed files with 63 additions and 13 deletions.
62 changes: 53 additions & 9 deletions tensorflow_probability/python/distributions/zipf.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
from tensorflow_probability.python.internal import samplers
from tensorflow_probability.python.internal import tensor_util
from tensorflow_probability.python.internal import tensorshape_util
from tensorflow.python.util import deprecation # pylint: disable=g-direct-tensorflow-import


__all__ = [
Expand Down Expand Up @@ -59,9 +60,16 @@ class Zipf(distribution.Distribution):
supported in the current implementation.
"""

@deprecation.deprecated_args(
'2021-02-10',
('The `interpolate_nondiscrete` flag is deprecated; instead use '
'`force_probs_to_zero_outside_support` (with the opposite sense).'),
'interpolate_nondiscrete',
warn_once=True)
def __init__(self,
power,
dtype=tf.int32,
force_probs_to_zero_outside_support=None,
interpolate_nondiscrete=True,
sample_maximum_iterations=100,
validate_args=False,
Expand All @@ -74,7 +82,19 @@ def __init__(self,
strictly greater than `1`.
dtype: The `dtype` of `Tensor` returned by `sample`.
Default value: `tf.int32`.
interpolate_nondiscrete: Python `bool`. When `False`, `log_prob` returns
force_probs_to_zero_outside_support: Python `bool`. When `True`,
non-integer values are evaluated "strictly": `log_prob` returns
`-inf`, `prob` returns `0`, and `cdf` and `sf` correspond. When
`False`, the implementation is free to save computation (and TF graph
size) by evaluating something that matches the Zipf pmf at integer
values `k` but produces an unrestricted result on other inputs. In the
case of Zipf, the `log_prob` formula in this case happens to be the
continuous function `-power log(k) - log(zeta(power))`. Note that this
function is not itself a normalized probability log-density.
Default value: `False`.
interpolate_nondiscrete: Deprecated. Use
`force_probs_to_zero_outside_support` (with the opposite sense) instead.
Python `bool`. When `False`, `log_prob` returns
`-inf` (and `prob` returns `0`) for non-integer inputs. When `True`,
`log_prob` evaluates the continuous function `-power log(k) -
log(zeta(power))` , which matches the Zipf pmf at integer arguments `k`
Expand Down Expand Up @@ -114,6 +134,18 @@ def __init__(self,
'power.dtype ({}) is not a supported `float` type.'.format(
dtype_util.name(self._power.dtype)))
self._interpolate_nondiscrete = interpolate_nondiscrete
if force_probs_to_zero_outside_support is not None:
# `force_probs_to_zero_outside_support` was explicitly set, so it
# controls.
self._force_probs_to_zero_outside_support = (
force_probs_to_zero_outside_support)
elif not self._interpolate_nondiscrete:
# `interpolate_nondiscrete` was explicitly set by the caller, so it
# should control until it is removed.
self._force_probs_to_zero_outside_support = True
else:
# Default.
self._force_probs_to_zero_outside_support = False
self._sample_maximum_iterations = sample_maximum_iterations
super(Zipf, self).__init__(
dtype=dtype,
Expand All @@ -140,10 +172,20 @@ def power(self):
return self._power

@property
@deprecation.deprecated(
'2021-02-10',
('The `interpolate_nondiscrete` property is deprecated; instead use '
'`force_probs_to_zero_outside_support` (with the opposite sense).'),
warn_once=True)
def interpolate_nondiscrete(self):
"""Deprecated: interpolate (log) probs on non-integer inputs.

Returns the `interpolate_nondiscrete` value exactly as it was passed to the
constructor. It is not re-derived from
`force_probs_to_zero_outside_support`, so the two properties can disagree
when the newer flag is set explicitly. Prefer
`force_probs_to_zero_outside_support` (with the opposite sense).
"""
return self._interpolate_nondiscrete

@property
def force_probs_to_zero_outside_support(self):
"""Return 0 probabilities on non-integer inputs.

Python `bool` resolved by the constructor: the explicitly passed
`force_probs_to_zero_outside_support` if it was not `None`; otherwise `True`
when the deprecated `interpolate_nondiscrete` argument was falsy, and
`False` by default.
"""
return self._force_probs_to_zero_outside_support

@property
def sample_maximum_iterations(self):
"""Maximum number of allowable iterations in `sample`."""
Expand All @@ -166,14 +208,15 @@ def _log_prob(self, x, power=None):
# where Z is the normalization constant. For x < 1 and non-integer points,
# the log-probability is -inf.
#
# However, if interpolate_nondiscrete is True, we return the natural
# continuous relaxation for x >= 1 which agrees with the log probability at
# positive integer points.
# However, if force_probs_to_zero_outside_support is False, we return the
# natural continuous relaxation for x >= 1 which agrees with the log
# probability at positive integer points.
power = power if power is not None else tf.convert_to_tensor(self.power)
x = tf.cast(x, power.dtype)
log_normalization = tf.math.log(tf.math.zeta(power, 1.))

safe_x = tf.maximum(x if self.interpolate_nondiscrete else tf.floor(x), 1.)
safe_x = tf.maximum(
tf.floor(x) if self.force_probs_to_zero_outside_support else x, 1.)
y = -power * tf.math.log(safe_x)
log_unnormalized_prob = tf.where(
tf.equal(x, safe_x), y, dtype_util.as_numpy_dtype(y.dtype)(-np.inf))
Expand All @@ -187,11 +230,12 @@ def _cdf(self, x):
# For fractional x, the CDF is equal to the CDF at n = floor(x).
# For x < 1, the CDF is zero.

# If interpolate_nondiscrete is True, we return a continuous relaxation
# which agrees with the CDF at integer points.
# If force_probs_to_zero_outside_support is False, we return a continuous
# relaxation which agrees with the CDF at integer points.
power = tf.convert_to_tensor(self.power)
x = tf.cast(x, power.dtype)
safe_x = tf.maximum(x if self.interpolate_nondiscrete else tf.floor(x), 0.)
safe_x = tf.maximum(
tf.floor(x) if self.force_probs_to_zero_outside_support else x, 0.)

cdf = 1. - (
tf.math.zeta(power, safe_x + 1.) / tf.math.zeta(power, 1.))
Expand Down Expand Up @@ -345,7 +389,7 @@ def _sample_control_dependencies(self, x):
return assertions
assertions.append(assert_util.assert_non_negative(
x, message='samples must be non-negative'))
if not self.interpolate_nondiscrete:
if self.force_probs_to_zero_outside_support:
assertions.append(distribution_util.assert_integer_form(
x, message='samples cannot contain fractional components.'))
return assertions
14 changes: 10 additions & 4 deletions tensorflow_probability/python/distributions/zipf_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,9 +95,11 @@ def testInvalidEventDtype(self):
def testZipfLogPmf_InvalidArgs(self):
power = tf.constant([4.0])
# Non-integer samples are rejected if validate_args is True and
# interpolate_nondiscrete is False.
# force_probs_to_zero_outside_support is True.
zipf = tfd.Zipf(
power=power, interpolate_nondiscrete=False, validate_args=True)
power=power,
force_probs_to_zero_outside_support=True,
validate_args=True)
non_integer_samples = [0.99, 4.5, 5.001, 1e-5]
for x in non_integer_samples:

Expand Down Expand Up @@ -165,7 +167,9 @@ def testZipfLogPmf_NonIntegerArgsNoInterpolation(self):
x = [-3., -0.5, 0., 2., 2.2, 3., 3.1, 4., 5., 5.5, 6., 7.2]

zipf = tfd.Zipf(
power=power, interpolate_nondiscrete=False, validate_args=False)
power=power,
force_probs_to_zero_outside_support=True,
validate_args=False)
log_pmf = zipf.log_prob(x)
self.assertEqual((batch_size,), log_pmf.shape)

Expand Down Expand Up @@ -236,7 +240,9 @@ def testZipfCdf_NonIntegerArgsNoInterpolation(self):
x = [-3.5, -0.5, 0., 1, 1.1, 2.2, 3.1, 4., 5., 5.5, 6.4, 7.8]

zipf = tfd.Zipf(
power=power, interpolate_nondiscrete=False, validate_args=False)
power=power,
force_probs_to_zero_outside_support=True,
validate_args=False)
log_cdf = zipf.log_cdf(x)
self.assertEqual((batch_size,), log_cdf.shape)
self.assertAllClose(self.evaluate(log_cdf), stats.zipf.logcdf(x, power_v))
Expand Down

0 comments on commit b667ee9

Please sign in to comment.