Make tfd.Deterministic "tape safe."
PiperOrigin-RevId: 253801257
SiegeLordEx authored and tensorflower-gardener committed Jun 18, 2019
1 parent d7476af commit 3ea6662
Showing 3 changed files with 129 additions and 51 deletions.
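What "tape safe" means here: the distribution no longer eagerly converts `loc`, `atol`, and `rtol` to tensors in `__init__`. A tf.Variable passed as a parameter is retained as-is and read fresh inside each method call, so tf.GradientTape can trace gradients through sampling, and `validate_args` assertions re-run after the variable is mutated. A minimal sketch of the effect, assuming TF2 eager mode and this commit's TFP (it mirrors the new testVariableGradients test below):

    import tensorflow.compat.v2 as tf
    import tensorflow_probability as tfp

    tfd = tfp.distributions

    loc = tf.Variable(1.)
    deterministic = tfd.Deterministic(loc=loc)  # `loc` stays a Variable
    with tf.GradientTape() as tape:
      s = deterministic.sample()
    # Before this commit, `loc` was frozen into a constant Tensor at
    # construction time and this gradient came back as None.
    g = tape.gradient(s, deterministic.trainable_variables)
    print(g)  # a single non-None gradient with respect to `loc`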
2 changes: 2 additions & 0 deletions tensorflow_probability/python/distributions/BUILD
@@ -267,6 +267,7 @@ py_library(
         "//tensorflow_probability/python/internal:distribution_util",
         "//tensorflow_probability/python/internal:dtype_util",
         "//tensorflow_probability/python/internal:reparameterization",
+        "//tensorflow_probability/python/internal:tensor_util",
         "//tensorflow_probability/python/internal:tensorshape_util",
     ],
 )
@@ -1425,6 +1426,7 @@ py_test(
         # numpy dep,
         # tensorflow dep,
         "//tensorflow_probability",
+        "//tensorflow_probability/python/internal:test_case",
     ],
 )

99 changes: 58 additions & 41 deletions tensorflow_probability/python/distributions/deterministic.py
@@ -27,9 +27,9 @@
 from tensorflow_probability.python.distributions import distribution
 from tensorflow_probability.python.distributions import kullback_leibler
 from tensorflow_probability.python.internal import assert_util
-from tensorflow_probability.python.internal import distribution_util
 from tensorflow_probability.python.internal import dtype_util
 from tensorflow_probability.python.internal import reparameterization
+from tensorflow_probability.python.internal import tensor_util
 from tensorflow_probability.python.internal import tensorshape_util
 
 __all__ = [
@@ -38,20 +38,6 @@
 ]
 
 
-def _get_tol(tol, dtype, validate_args):
-  """Gets a Tensor of type `dtype`, 0 if `tol` is None, validation optional."""
-  if tol is None:
-    return tf.convert_to_tensor(value=0, dtype=dtype)
-
-  tol = tf.convert_to_tensor(value=tol, dtype=dtype)
-  if validate_args:
-    tol = distribution_util.with_dependencies([
-        assert_util.assert_non_negative(
-            tol, message="Argument 'tol' must be non-negative")
-    ], tol)
-  return tol
-
-
 @six.add_metaclass(abc.ABCMeta)
 class _BaseDeterministic(distribution.Distribution):
   """Base class for Deterministic distributions."""
@@ -103,18 +89,13 @@ def __init__(self,
     """
     with tf.name_scope(name) as name:
       dtype = dtype_util.common_dtype([loc, atol, rtol], dtype_hint=tf.float32)
-      loc = tf.convert_to_tensor(value=loc, name="loc", dtype=dtype)
-      if is_vector and validate_args:
-        msg = "Argument loc must be at least rank 1."
-        if tensorshape_util.rank(loc.shape) is not None:
-          if tensorshape_util.rank(loc.shape) < 1:
-            raise ValueError(msg)
-        else:
-          loc = distribution_util.with_dependencies(
-              [assert_util.assert_rank_at_least(loc, 1, message=msg)], loc)
-      self._loc = loc
-      self._atol = _get_tol(atol, self._loc.dtype, validate_args)
-      self._rtol = _get_tol(rtol, self._loc.dtype, validate_args)
+      self._loc = tensor_util.convert_immutable_to_tensor(
+          loc, dtype_hint=dtype, name="loc")
+      self._atol = tensor_util.convert_immutable_to_tensor(
+          0 if atol is None else atol, dtype=dtype, name="atol")
+      self._rtol = tensor_util.convert_immutable_to_tensor(
+          0 if rtol is None else rtol, dtype=dtype, name="rtol")
+      self._is_vector = is_vector
 
       super(_BaseDeterministic, self).__init__(
           dtype=self._loc.dtype,
@@ -125,11 +106,12 @@ def __init__(self,
           graph_parents=[self._loc, self._atol, self._rtol],
           name=name)
 
-      # Avoid using the large broadcast with self.loc if possible.
-      if rtol is None:
-        self._slack = self.atol
-      else:
-        self._slack = self.atol + self.rtol * tf.abs(self.loc)
+  def _slack(self, loc):
+    # Avoid using the large broadcast with self.loc if possible.
+    if self.parameters["rtol"] is None:
+      return self.atol
+    else:
+      return self.atol + self.rtol * tf.abs(loc)
 
   @property
   def loc(self):
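Two things changed above. First, tensor_util.convert_immutable_to_tensor converts ordinary Python and NumPy values to tensors but passes mutable objects such as tf.Variable through untouched, which is what keeps a variable visible to a later GradientTape. Second, `_slack` turns from a value computed once in `__init__` into a method, so the tolerance is recomputed from the current parameter values on every call; it consults `self.parameters["rtol"]`, the raw constructor argument, to preserve the old fast path that skips the large `loc`-shaped broadcast when no `rtol` was supplied. A small sketch of the conversion behavior (my reading of the helper, not code from the commit):

    import tensorflow.compat.v2 as tf
    from tensorflow_probability.python.internal import tensor_util

    v = tf.Variable(0.5)
    t = tensor_util.convert_immutable_to_tensor(0.5)  # plain value -> Tensor
    w = tensor_util.convert_immutable_to_tensor(v)    # Variable passes through
    assert w is v  # later reads of `w` observe assignments made to `v`
    assert tensor_util.is_mutable(v) and not tensor_util.is_mutable(t)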
@@ -165,6 +147,31 @@ def _sample_n(self, n, seed=None):
         tf.concat([[n], self.batch_shape_tensor(), self.event_shape_tensor()],
                   axis=0))
 
+  def _parameter_control_dependencies(self, is_init):
+    msg = "Argument loc must be at least rank 1."
+    if is_init:
+      if self._is_vector and tensorshape_util.rank(self.loc.shape) is not None:
+        if tensorshape_util.rank(self.loc.shape) < 1:
+          raise ValueError(msg)
+
+    if not self.validate_args:
+      return []
+
+    assertions = []
+
+    if is_init != tensor_util.is_mutable(self.loc) and self._is_vector:
+      assertions.append(
+          assert_util.assert_rank_at_least(self.loc, 1, message=msg))
+    if is_init != tensor_util.is_mutable(self.atol):
+      assertions.append(
+          assert_util.assert_non_negative(
+              self.atol, message="Argument 'atol' must be non-negative"))
+    if is_init != tensor_util.is_mutable(self.rtol):
+      assertions.append(
+          assert_util.assert_non_negative(
+              self.rtol, message="Argument 'rtol' must be non-negative"))
+    return assertions
+
 
 class Deterministic(_BaseDeterministic):
   """Scalar `Deterministic` distribution on the real line.
@@ -258,10 +265,13 @@ def _params_event_ndims(cls):
 
   def _batch_shape_tensor(self):
     return tf.broadcast_dynamic_shape(
-        tf.shape(input=self.loc), tf.shape(input=self._slack))
+        tf.shape(self.loc),
+        tf.broadcast_dynamic_shape(tf.shape(self.atol), tf.shape(self.rtol)))
 
   def _batch_shape(self):
-    return tf.broadcast_static_shape(self.loc.shape, self._slack.shape)
+    return tf.broadcast_static_shape(
+        self.loc.shape,
+        tf.broadcast_static_shape(self.atol.shape, self.rtol.shape))
 
   def _event_shape_tensor(self):
     return tf.constant([], dtype=tf.int32)
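The batch shape is now derived by broadcasting the parameter shapes directly instead of consulting the precomputed `self._slack`, so no `loc`-sized intermediate is materialized and no parameter has to be read at construction time. A shape-only illustration (the shapes are made up for this note):

    import tensorflow.compat.v2 as tf

    loc = tf.zeros([1000, 3])
    atol = tf.zeros([3])
    rtol = tf.zeros([])
    batch_shape = tf.broadcast_static_shape(
        loc.shape, tf.broadcast_static_shape(atol.shape, rtol.shape))
    print(batch_shape)  # (1000, 3), without ever forming atol + rtol*|loc|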
@@ -270,12 +280,14 @@ def _event_shape(self):
     return tf.TensorShape([])
 
   def _prob(self, x):
+    loc = tf.identity(self.loc)
     # Enforces dtype of probability to be float, when self.dtype is not.
     prob_dtype = self.dtype if self.dtype.is_floating else tf.float32
-    return tf.cast(tf.abs(x - self.loc) <= self._slack, dtype=prob_dtype)
+    return tf.cast(tf.abs(x - loc) <= self._slack(loc), dtype=prob_dtype)
 
   def _cdf(self, x):
-    return tf.cast(x >= self.loc - self._slack, dtype=self.dtype)
+    loc = tf.identity(self.loc)
+    return tf.cast(x >= loc - self._slack(loc), dtype=self.dtype)
 
 
 class VectorDeterministic(_BaseDeterministic):
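`_prob` and `_cdf` now read `loc` exactly once, via tf.identity. With a tf.Variable parameter this takes a single snapshot that the comparison and `_slack(loc)` share, and it moves the read inside the method, hence inside any surrounding GradientTape, instead of at construction. A self-contained sketch of the pattern (standalone analogue with a made-up tolerance, not the class's code):

    import tensorflow.compat.v2 as tf

    loc_var = tf.Variable(2.)

    def prob_of(x, atol=0.1):
      loc = tf.identity(loc_var)  # one read; both uses below share it
      return tf.cast(tf.abs(x - loc) <= atol, tf.float32)

    print(prob_of(tf.constant(2.05)))  # 1.0 under the made-up tolerance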
@@ -375,13 +387,17 @@ def _params_event_ndims(cls):
 
   def _batch_shape_tensor(self):
     return tf.broadcast_dynamic_shape(
-        tf.shape(input=self.loc), tf.shape(input=self._slack))[:-1]
+        tf.shape(self.loc),
+        tf.broadcast_dynamic_shape(tf.shape(self.atol),
+                                   tf.shape(self.rtol)))[:-1]
 
   def _batch_shape(self):
-    return tf.broadcast_static_shape(self.loc.shape, self._slack.shape)[:-1]
+    return tf.broadcast_static_shape(
+        self.loc.shape,
+        tf.broadcast_static_shape(self.atol.shape, self.rtol.shape))[:-1]
 
   def _event_shape_tensor(self):
-    return tf.shape(input=self.loc)[-1:]
+    return tf.shape(self.loc)[-1:]
 
   def _event_shape(self):
     return self.loc.shape[-1:]
@@ -391,16 +407,17 @@ def _prob(self, x):
     is_vector_check = assert_util.assert_rank_at_least(x, 1)
     right_vec_space_check = assert_util.assert_equal(
         self.event_shape_tensor(),
-        tf.gather(tf.shape(input=x),
+        tf.gather(tf.shape(x),
                   tf.rank(x) - 1),
         message="Argument 'x' not defined in the same space R^k as this distribution"
     )
     with tf.control_dependencies([is_vector_check]):
       with tf.control_dependencies([right_vec_space_check]):
         x = tf.identity(x)
+    loc = tf.identity(self.loc)
     return tf.cast(
         tf.reduce_all(
-            input_tensor=tf.abs(x - self.loc) <= self._slack, axis=-1),
+            input_tensor=tf.abs(x - loc) <= self._slack(loc), axis=-1),
         dtype=self.dtype)


79 changes: 69 additions & 10 deletions tensorflow_probability/python/distributions/deterministic_test.py
@@ -18,17 +18,19 @@
 
 # Dependency imports
 import numpy as np
-import tensorflow as tf
+import tensorflow.compat.v1 as tf1
+import tensorflow.compat.v2 as tf
 import tensorflow_probability as tfp
 
-tfd = tfp.distributions
+from tensorflow_probability.python.internal import test_case
 from tensorflow.python.framework import test_util  # pylint: disable=g-direct-tensorflow-import,g-import-not-at-top
 
 rng = np.random.RandomState(0)
+tfd = tfp.distributions
 
 
 @test_util.run_all_in_graph_and_eager_modes
-class DeterministicTest(tf.test.TestCase):
+class DeterministicTest(test_case.TestCase):
 
   def testShape(self):
     loc = rng.rand(2, 3, 4)
@@ -178,11 +180,11 @@ def testSampleWithBatchAtol(self):
         self.evaluate(sample))
 
   def testSampleDynamicWithBatchDims(self):
-    loc = tf.compat.v1.placeholder_with_default(input=[0., 0], shape=[2])
+    loc = tf1.placeholder_with_default(input=[0., 0], shape=[2])
 
     deterministic = tfd.Deterministic(loc)
     for sample_shape_ in [(), (4,)]:
-      sample_shape = tf.compat.v1.placeholder_with_default(
+      sample_shape = tf1.placeholder_with_default(
           input=np.array(sample_shape_, dtype=np.int32), shape=None)
       sample_ = self.evaluate(deterministic.sample(sample_shape))
       self.assertAllClose(
@@ -220,8 +222,37 @@ def testDeterministicGammaKL(self):
     expected_kl_, actual_kl_ = self.evaluate([expected_kl, actual_kl])
     self.assertAllEqual(expected_kl_, actual_kl_)
 
+  def testVariableGradients(self):
+    loc = tf.Variable(1.)
+    deterministic = tfd.Deterministic(loc=loc)
+    with tf.GradientTape() as tape:
+      s = deterministic.sample()
+    g = tape.gradient(s, deterministic.trainable_variables)
+    self.assertLen(g, 1)
+    self.assertAllNotNone(g)
+
+  def testVariableAssertions(self):
+    atol = tf.Variable(0.1)
+    rtol = tf.Variable(0.1)
+    deterministic = tfd.Deterministic(
+        loc=0.1, atol=atol, rtol=rtol, validate_args=True)
+
+    self.evaluate(tf1.global_variables_initializer())
+    self.evaluate(deterministic.log_prob(1.))
+
+    self.evaluate(atol.assign(-1.))
+    with self.assertRaisesRegexp((ValueError, tf.errors.InvalidArgumentError),
+                                 "Condition x >= 0"):
+      self.evaluate(deterministic.log_prob(1.))
+
+    self.evaluate(atol.assign(0.1))
+    self.evaluate(rtol.assign(-1.))
+    with self.assertRaisesRegexp((ValueError, tf.errors.InvalidArgumentError),
+                                 "Condition x >= 0"):
+      self.evaluate(deterministic.log_prob(1.))
+
 
-class VectorDeterministicTest(tf.test.TestCase):
+class VectorDeterministicTest(test_case.TestCase):
 
   def testParamBroadcasts(self):
     loc = rng.rand(2, 1, 4)
@@ -246,7 +277,7 @@ def testShape(self):
     self.assertEqual(deterministic.event_shape, tf.TensorShape([4]))
 
   def testShapeUknown(self):
-    loc = tf.compat.v1.placeholder_with_default(np.float32([0]), shape=[None])
+    loc = tf1.placeholder_with_default(np.float32([0]), shape=[None])
     deterministic = tfd.VectorDeterministic(loc)
     self.assertAllEqual(deterministic.event_shape_tensor().shape, [1])

@@ -344,12 +375,11 @@ def testSampleWithBatchDims(self):
         self.evaluate(sample))
 
   def testSampleDynamicWithBatchDims(self):
-    loc = tf.compat.v1.placeholder_with_default(
-        input=[[0.], [0.]], shape=[2, 1])
+    loc = tf1.placeholder_with_default(input=[[0.], [0.]], shape=[2, 1])
 
     deterministic = tfd.VectorDeterministic(loc)
     for sample_shape_ in [(), (4,)]:
-      sample_shape = tf.compat.v1.placeholder_with_default(
+      sample_shape = tf1.placeholder_with_default(
          input=np.array(sample_shape_, dtype=np.int32), shape=None)
       sample_ = self.evaluate(deterministic.sample(sample_shape))
       self.assertAllClose(
@@ -389,6 +419,35 @@ def testVectorDeterministicMultivariateNormalDiagKL(self):
     expected_kl_, actual_kl_ = self.evaluate([expected_kl, actual_kl])
     self.assertAllEqual(expected_kl_, actual_kl_)
 
+  def testVariableGradients(self):
+    loc = tf.Variable([1., 2.])
+    deterministic = tfd.VectorDeterministic(loc=loc)
+    with tf.GradientTape() as tape:
+      s = deterministic.sample()
+    g = tape.gradient(s, deterministic.trainable_variables)
+    self.assertLen(g, 1)
+    self.assertAllNotNone(g)
+
+  def testVariableAssertions(self):
+    atol = tf.Variable(0.1)
+    rtol = tf.Variable(0.1)
+    deterministic = tfd.VectorDeterministic(
+        loc=[0.1], atol=atol, rtol=rtol, validate_args=True)
+
+    self.evaluate(tf1.global_variables_initializer())
+    self.evaluate(deterministic.log_prob([1.]))
+
+    self.evaluate(atol.assign(-1.))
+    with self.assertRaisesRegexp((ValueError, tf.errors.InvalidArgumentError),
+                                 "Condition x >= 0"):
+      self.evaluate(deterministic.log_prob([1.]))
+
+    self.evaluate(atol.assign(0.1))
+    self.evaluate(rtol.assign(-1.))
+    with self.assertRaisesRegexp((ValueError, tf.errors.InvalidArgumentError),
+                                 "Condition x >= 0"):
+      self.evaluate(deterministic.log_prob([1.]))
+
 
 if __name__ == "__main__":
   tf.test.main()
