Merge pull request tensorflow#933 from emilyfertig/r0.10

R0.10

emilyfertig authored May 14, 2020
2 parents 3244d86 + 73ce8fa commit f051e03
Showing 5 changed files with 57 additions and 5 deletions.
4 changes: 4 additions & 0 deletions tensorflow_probability/python/bijectors/BUILD
@@ -1521,12 +1521,16 @@ multi_substrate_py_test(
name = "softplus_test",
size = "small",
srcs = ["softplus_test.py"],
jax_size = "medium",
deps = [
":bijector_test_util",
":bijectors",
# absl/testing:parameterized dep,
# numpy dep,
# tensorflow dep,
"//tensorflow_probability/python/internal:test_util",
"//tensorflow_probability/python/math",
# tensorflow/compiler/jit dep,
],
)

29 changes: 27 additions & 2 deletions tensorflow_probability/python/bijectors/softplus.py
@@ -33,6 +33,31 @@
]


JAX_MODE = False  # Overwritten by rewrite script.


# TODO(b/155501444): Remove this when tf.nn.softplus is fixed.
if JAX_MODE:
  _stable_grad_softplus = tf.nn.softplus
else:

  @tf.custom_gradient
  def _stable_grad_softplus(x):
    """A (more) numerically stable softplus than `tf.nn.softplus`."""
    x = tf.convert_to_tensor(x)
    if x.dtype == tf.float64:
      cutoff = -20
    else:
      cutoff = -9

    y = tf.where(x < cutoff, tf.math.log1p(tf.exp(x)), tf.nn.softplus(x))

    def grad_fn(dy):
      return dy * tf.where(x < cutoff, tf.exp(x), tf.nn.sigmoid(x))

    return y, grad_fn
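The `log1p` branch matters because a naively evaluated softplus loses all precision in the left tail. A minimal NumPy sketch of that failure mode (illustrative only; it is not a reproduction of the `tf.nn.softplus` bug tracked in b/155501444):

import numpy as np

x = np.float32(-30.)

naive = np.log(np.float32(1.) + np.exp(x))  # 1 + 9.4e-14 rounds to exactly 1 in float32.
stable = np.log1p(np.exp(x))                # ~9.4e-14; softplus(x) ~ exp(x) in this regime.

print(naive)           # 0.0 -- the naive softplus underflows to zero...
print(np.log(naive))   # -inf -- ...so downstream log(softplus(x)) blows up.
print(np.log(stable))  # ~-30.0, the expected value.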


class Softplus(bijector.Bijector):
"""Bijector which computes `Y = g(X) = Log[1 + exp(X)]`.
@@ -101,9 +126,9 @@ def _is_increasing(cls):

  def _forward(self, x):
    if self.hinge_softness is None:
-      return tf.math.softplus(x)
+      return _stable_grad_softplus(x)
    hinge_softness = tf.cast(self.hinge_softness, x.dtype)
-    return hinge_softness * tf.math.softplus(x / hinge_softness)
+    return hinge_softness * _stable_grad_softplus(x / hinge_softness)

  def _inverse(self, y):
    if self.hinge_softness is None:
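For context on the `hinge_softness` branch above: the bijector computes `c * softplus(x / c)`, so the stable helper is simply applied to the rescaled argument. A quick usage sketch against the public `tfp.bijectors` API (assumed available; not part of this diff):

import tensorflow.compat.v2 as tf
import tensorflow_probability as tfp

tfb = tfp.bijectors

# hinge_softness = c rescales the bijector to y = c * softplus(x / c);
# small c sharpens the hinge toward relu(x) while keeping y > 0.
b = tfb.Softplus(hinge_softness=0.5)
x = tf.constant([-2., 0., 2.])
y = b.forward(x)
print(y.numpy())             # 0.5 * softplus(x / 0.5)
print(b.inverse(y).numpy())  # recovers x up to float error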
21 changes: 21 additions & 0 deletions tensorflow_probability/python/bijectors/softplus_test.py
@@ -20,9 +20,11 @@

# Dependency imports

from absl.testing import parameterized
import numpy as np
import tensorflow.compat.v2 as tf
from tensorflow_probability.python import bijectors as tfb
from tensorflow_probability.python import math as tfp_math
from tensorflow_probability.python.bijectors import bijector_test_util
from tensorflow_probability.python.internal import test_util

@@ -149,6 +151,25 @@ def testVariableHingeSoftness(self):
      with tf.control_dependencies([hinge_softness.assign(0.)]):
        self.evaluate(b.forward(0.5))

  @parameterized.named_parameters(
      ('32bitGraph', np.float32, False),
      ('64bitGraph', np.float64, False),
      ('32bitXLA', np.float32, True),
      ('64bitXLA', np.float64, True),
  )
  @test_util.numpy_disable_gradient_test
  def testLeftTailGrad(self, dtype, do_compile):
    x = np.linspace(-50., -8., 1000).astype(dtype)

    @tf.function(autograph=False, experimental_compile=do_compile)
    def fn(x):
      return tf.math.log(tfb.Softplus().forward(x))

    _, grad = tfp_math.value_and_gradient(fn, x)

    true_grad = 1 / (1 + np.exp(-x)) / np.log1p(np.exp(x))
    self.assertAllClose(true_grad, self.evaluate(grad), atol=1e-3)


if __name__ == '__main__':
  tf.test.main()
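For reference, `true_grad` in `testLeftTailGrad` is the chain rule applied to `fn`: d/dx log(softplus(x)) = sigmoid(x) / softplus(x). A finite-difference spot-check (illustrative only, not part of the test):

import numpy as np

x, eps = -10., 1e-6

f = lambda t: np.log(np.log1p(np.exp(t)))         # log(softplus(t))
numeric = (f(x + eps) - f(x - eps)) / (2. * eps)  # central difference
analytic = (1. / (1. + np.exp(-x))) / np.log1p(np.exp(x))

print(numeric, analytic)  # agree to ~6 decimals, both just below 1.0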
@@ -287,5 +287,7 @@ def _convert_to_dict(x):
  if isinstance(x, collections.OrderedDict):
    return x
  if hasattr(x, '_asdict'):
-    return x._asdict()
+    # Wrap with `OrderedDict` to indicate that namedtuples have a well-defined
+    # order (by default, they convert to just `dict` in Python 3.8+).
+    return collections.OrderedDict(x._asdict())
  return dict(x)
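A small sketch (not from the commit) of the Python 3.8+ behavior the new comment refers to:

import collections

Point = collections.namedtuple('Point', ['x', 'y'])
p = Point(1., 2.)

# On Python 3.8+, _asdict() returns a plain dict (it returned an OrderedDict
# on 3.1-3.7); wrapping it keeps the "order is meaningful" signal explicit.
print(type(p._asdict()))                     # <class 'dict'> on 3.8+
print(collections.OrderedDict(p._asdict()))  # an OrderedDict preserving field order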
4 changes: 2 additions & 2 deletions tensorflow_probability/python/layers/distribution_layer.py
@@ -25,7 +25,7 @@
import pickle

# Dependency imports
-from cloudpickle import CloudPickler
+from cloudpickle.cloudpickle import CloudPickler
import numpy as np
import six
import tensorflow.compat.v2 as tf
@@ -47,7 +47,7 @@
from tensorflow_probability.python.distributions import variational_gaussian_process as variational_gaussian_process_lib
from tensorflow_probability.python.internal import distribution_util as dist_util
from tensorflow_probability.python.layers.internal import distribution_tensor_coercible as dtc
-from tensorflow_probability.python.layers.internal import tensor_tuple as tensor_tuple
+from tensorflow_probability.python.layers.internal import tensor_tuple
from tensorflow.python.keras.utils import tf_utils as keras_tf_utils # pylint: disable=g-direct-tensorflow-import

