Skip to content

Commit

Permalink
Merge pull request tensorflow#795 from jeffpollock9:master
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 299920215
  • Loading branch information
tensorflower-gardener committed Mar 9, 2020
2 parents 3c698e9 + b8527d5 commit 85a1874
Show file tree
Hide file tree
Showing 3 changed files with 96 additions and 16 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -108,7 +108,6 @@
# batch slicing.
INSTANTIABLE_BUT_NOT_SLICABLE = (
'BatchReshape',
'OrderedLogistic', # b/149597503
)

EXTRA_TENSOR_CONVERSION_DISTS = {
Expand All @@ -131,7 +130,6 @@
'DirichletMultinomial', # No converter for TensorListFromTensor
'Gamma', # No converter for While
'Multinomial', # No converter for TensorListFromTensor
'OrderedLogistic', # No converter for SparseSoftmaxCrossEntropyWithLogits
'PlackettLuce', # No converter for TopKV2
'TruncatedNormal', # No converter for ParameterizedTruncatedNormal
'VonMises', # No converter for While
Expand All @@ -140,7 +138,6 @@
]

LOGPROB_AUTOVECTORIZATION_IS_BROKEN = [
'OrderedLogistic', # No converter for SparseSoftmaxCrossEntropyWithLogits
'StudentT', # Numerical problem: b/149785284
'HalfStudentT', # Numerical problem: b/149785284
'TruncatedNormal', # Numerical problem: b/150811273
Expand Down
69 changes: 63 additions & 6 deletions tensorflow_probability/python/distributions/ordered_logistic.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,30 @@
# QuantizedDistribution.


def _broadcast_cat_event_and_params(event, params, base_dtype):
  """Makes `event` integer-valued and mutually broadcasts it with `params`.

  `params` carries one extra trailing axis relative to `event`; all of its
  leading axes are broadcast against the shape of `event`.

  Args:
    event: Tensor of outcomes. Floating values are cast to `tf.int32`; integer
      values are used unchanged.
    params: Tensor of parameters with one more (trailing) axis than `event`.
    base_dtype: dtype named in the error message raised for an unsupported
      `event` dtype.

  Returns:
    event: The integer-valued `event`, broadcast to `params.shape[:-1]` when
      the two shapes are not statically known to be equal.
    params: `params`, broadcast along its leading axes to the shape of `event`
      under the same condition.

  Raises:
    TypeError: If `event.dtype` is neither floating nor integer.
  """
  event_is_float = dtype_util.is_floating(event.dtype)
  if not event_is_float and not dtype_util.is_integer(event.dtype):
    raise TypeError('`value` should have integer `dtype` or '
                    '`self.dtype` ({})'.format(base_dtype))
  if event_is_float:
    # When `validate_args=True` we've already ensured int/float casting
    # is closed.
    event = tf.cast(event, dtype=tf.int32)

  # Broadcasting can only be skipped when both shapes are fully known
  # statically and already agree; in every other case do it dynamically.
  needs_broadcast = True
  if (tensorshape_util.rank(params.shape) is not None and
      tensorshape_util.is_fully_defined(params.shape[:-1]) and
      tensorshape_util.is_fully_defined(event.shape)):
    needs_broadcast = params.shape[:-1] != event.shape

  if needs_broadcast:
    event_with_trailing_axis = event[..., tf.newaxis]
    params = params * tf.ones_like(event_with_trailing_axis,
                                   dtype=params.dtype)
    # Use the already-broadcast `params` to size the broadcast of `event`.
    leading_shape = tf.shape(params)[:-1]
    event = event * tf.ones(leading_shape, dtype=event.dtype)
    if tensorshape_util.rank(params.shape) is not None:
      tensorshape_util.set_shape(event, params.shape[:-1])

  return event, params


class OrderedLogistic(distribution.Distribution):
"""Ordered logistic distribution.
Expand Down Expand Up @@ -280,17 +304,50 @@ def _event_shape(self):
def _log_prob(self, x):
  """Returns the log probability mass at `x`.

  Computed as `log(S(x - 1) - S(x))`, where `S` is the survival function,
  by gathering both log-survival values from a single table indexed along
  its last axis.

  Args:
    x: Tensor of outcomes; floating or integer dtype (see
      `_broadcast_cat_event_and_params`).

  Returns:
    Tensor with the (broadcast) shape of `x`. Entries where
    `x > num_categories - 1` are `-inf`.
  """
  # TODO(b/149334734): Consider using QuantizedDistribution for the log_prob
  # computation for better precision.
  num_categories = self._num_categories()
  x, augmented_log_survival = _broadcast_cat_event_and_params(
      event=x,
      params=tf.math.log_sigmoid(
          self.loc[..., tf.newaxis] - self._augmented_cutpoints()),
      base_dtype=dtype_util.base_dtype(self.dtype))
  # Flatten to one row per event so a single batched gather picks the entry
  # belonging to each event.
  x_flat = tf.reshape(x, [-1, 1])
  augmented_log_survival_flat = tf.reshape(
      augmented_log_survival, [-1, num_categories + 1])
  # Table index `k + 1` is used for the log survival at `k`, so index `x`
  # yields the value at `x - 1`. Indices are clipped so that out-of-range
  # events read the first/last table entry instead of failing the gather.
  log_survival_flat_xm1 = tf.gather(
      params=augmented_log_survival_flat,
      indices=tf.clip_by_value(x_flat, 0, num_categories),
      batch_dims=1)
  log_survival_flat_x = tf.gather(
      params=augmented_log_survival_flat,
      indices=tf.clip_by_value(x_flat + 1, 0, num_categories),
      batch_dims=1)
  log_prob_flat = tfp_math.log_sub_exp(
      log_survival_flat_xm1, log_survival_flat_x)
  # Deal with case where both survival probabilities are -inf, which gives
  # `log_prob_flat = nan` when it should be -inf.
  minus_inf = tf.constant(-np.inf, dtype=log_prob_flat.dtype)
  log_prob_flat = tf.where(
      x_flat > num_categories - 1, minus_inf, log_prob_flat)
  return tf.reshape(log_prob_flat, shape=tf.shape(x))

def _log_cdf(self, x):
  """Returns `log(1 - S(x))`, computed from the log survival function at `x`."""
  log_survival = self._log_survival_function(x)
  return tfp_math.log1mexp(log_survival)

def _log_survival_function(self, x):
  """Returns the log survival function `log P(Y > x)`.

  Args:
    x: Tensor of outcomes; floating or integer dtype (see
      `_broadcast_cat_event_and_params`).

  Returns:
    Tensor with the (broadcast) shape of `x`.
  """
  num_categories = self._num_categories()
  x, augmented_log_survival = _broadcast_cat_event_and_params(
      event=x,
      params=tf.math.log_sigmoid(
          self.loc[..., tf.newaxis] - self._augmented_cutpoints()),
      base_dtype=dtype_util.base_dtype(self.dtype))
  # Flatten to one row per event so a single batched gather picks the entry
  # belonging to each event.
  x_flat = tf.reshape(x, [-1, 1])
  augmented_log_survival_flat = tf.reshape(
      augmented_log_survival, [-1, num_categories + 1])
  # Table index `x + 1` holds the log survival at `x`. Clipping keeps
  # out-of-range events on the first/last table entry instead of failing the
  # gather.
  log_survival_flat = tf.gather(
      params=augmented_log_survival_flat,
      indices=tf.clip_by_value(x_flat + 1, 0, num_categories),
      batch_dims=1)
  return tf.reshape(log_survival_flat, shape=tf.shape(x))

def _entropy(self):
log_probs = self.categorical_log_probs()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -67,18 +67,32 @@ def testBatchShapes(self, test, batch_shape):
self.assertAllEqual(dist.event_shape, [])
self.assertAllEqual(self.evaluate(dist.event_shape_tensor()), [])

log_probs_shape = tf.shape(dist.categorical_log_probs())
self.assertAllEqual(self.evaluate(log_probs_shape), batch_shape + [3])
categorical_probs = dist.categorical_probs()
categorical_probs_shape = tf.shape(categorical_probs)
self.assertAllEqual(
self.evaluate(categorical_probs_shape), batch_shape + [3])

sample_shape = tf.shape(dist.sample(seed=test_util.test_seed()))
sample = dist.sample(seed=test_util.test_seed())
sample_shape = tf.shape(sample)
self.assertAllEqual(self.evaluate(sample_shape), batch_shape)

sample_shape_n = tf.shape(
dist.sample([4, 5], seed=test_util.test_seed()))
self.assertAllEqual(self.evaluate(sample_shape_n), [4, 5] + batch_shape)
prob_sample_shape = tf.shape(dist.prob(sample))
survival_prob_sample_shape = tf.shape(dist.survival_function(sample))
self.assertAllEqual(self.evaluate(prob_sample_shape), batch_shape)
self.assertAllEqual(self.evaluate(survival_prob_sample_shape), batch_shape)

def testProbs(self):
n = [4, 5]
sample_n = dist.sample(n, seed=test_util.test_seed())
sample_n_shape = tf.shape(sample_n)
self.assertAllEqual(self.evaluate(sample_n_shape), n + batch_shape)

prob_sample_n_shape = tf.shape(dist.prob(sample_n))
survival_prob_sample_n_shape = tf.shape(dist.survival_function(sample_n))
self.assertAllEqual(self.evaluate(prob_sample_n_shape), n + batch_shape)
self.assertAllEqual(
self.evaluate(survival_prob_sample_n_shape), n + batch_shape)

def testProbs(self):
# survival functions
# P(Y > 0) = sigmoid(1) = 0.7310586
# P(Y > 1) = sigmoid(0) = 0.5
Expand All @@ -90,6 +104,7 @@ def testProbs(self):
# P(Y = 2) = sigmoid(0) - sigmoid(-1) = 0.23105857
# P(Y = 3) = sigmoid(-1) = 0.26894143
expected_probs = [0.2689414, 0.2310586, 0.23105857, 0.26894143]
expected_survival_probs = 1. - np.cumsum(expected_probs)
dist = tfd.OrderedLogistic(cutpoints=[-1., 0., 1.], loc=0.)

categorical_probs = self.evaluate(dist.categorical_probs())
Expand All @@ -98,6 +113,17 @@ def testProbs(self):
probs = np.flip(self.evaluate(dist.prob([3, 2, 1, 0])))
self.assertAllClose(expected_probs, probs, atol=1e-6)

survival_probs = self.evaluate(dist.survival_function([0, 1, 2, 3]))
self.assertAllClose(expected_survival_probs, survival_probs, atol=1e-6)

zero_probs = self.evaluate(dist.prob([-1, 4]))
self.assertAllClose([0., 0.], zero_probs, atol=1e-6)

out_of_bounds_survival_probs = self.evaluate(
dist.survival_function([-2, -1, 4, 5]))
self.assertAllClose(
[1., 1., 0., 0.], out_of_bounds_survival_probs, atol=1e-6)

def testMode(self):
# 2 cutpoints i.e. 3 possible outcomes. 3 "batched" distributions with the
# logistic distribution location well within the large cutpoint spacing so
Expand Down

0 comments on commit 85a1874

Please sign in to comment.