Fix gradient for softplus at 0 in the TFP JAX backend.
PiperOrigin-RevId: 286582193
jburnim authored and tensorflower-gardener committed Dec 20, 2019
1 parent 9e8e380 commit b0d44b1
Showing 2 changed files with 12 additions and 1 deletion.
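
For context, here is a minimal stand-alone JAX sketch (not part of this commit; the function name is made up for illustration) of the problem being fixed. The backend previously computed softplus as the numerically stable expression log1p(exp(-|x|)) + max(x, 0). Differentiating that expression automatically relies on whatever subgradients JAX assigns to abs and maximum at their kinks, so the gradient at x = 0 need not come out as the true derivative sigmoid(0) = 0.5.

import jax
import jax.numpy as jnp

def softplus_via_abs_max(x):
  # Numerically stable softplus, written the same way as the backend's old lambda.
  return jnp.log1p(jnp.exp(-jnp.abs(x))) + jnp.maximum(x, 0.)

# The true derivative of softplus is sigmoid, so d/dx softplus(0) should be 0.5.
print(jax.grad(softplus_via_abs_max)(0.))  # depends on the abs/maximum subgradients; may not be 0.5
print(jax.nn.sigmoid(0.))                  # 0.5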
tensorflow_probability/python/internal/backend/numpy/BUILD (2 additions, 0 deletions)

@@ -262,7 +262,9 @@ py_library(
    srcs = ["random_generators.py"],
    deps = [
        ":_utils",
        ":numpy_array",
        ":numpy_math",
        ":ops",
        # numpy dep,
        # tensorflow dep,
    ],
Second changed file (10 additions, 1 deletion):

@@ -27,6 +27,7 @@

 from tensorflow_probability.python.internal.backend.numpy import _utils as utils
 from tensorflow_probability.python.internal.backend.numpy.numpy_array import _reverse
+from tensorflow_probability.python.internal.backend.numpy.ops import _custom_gradient

 scipy_special = utils.try_import('scipy.special')

@@ -730,9 +731,17 @@ def _multiply_no_nan(x, y, name=None):  # pylint: disable=unused-argument
     tf.math.softmax,
     _softmax)

+
+@_custom_gradient
+def _softplus(x, name=None):  # pylint: disable=unused-argument
+  def grad(dy):
+    return dy * scipy_special.expit(x)
+  # TODO(b/146563881): Investigate improving numerical accuracy here.
+  return np.log1p(np.exp(-np.abs(x))) + np.maximum(x, 0.), grad
+
 softplus = utils.copy_docstring(
     tf.math.softplus,
-    lambda x, name=None: np.log1p(np.exp(-np.abs(x))) + np.maximum(x, 0.))
+    _softplus)

 softsign = utils.copy_docstring(
     tf.math.softsign,
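
The new _softplus keeps the stable forward expression but wraps it in the backend's _custom_gradient helper so the backward pass uses scipy_special.expit(x), i.e. the sigmoid, which is exactly 0.5 at x = 0. Below is a rough stand-alone analogue of the same technique in plain JAX (a sketch only, using the public jax.custom_jvp API rather than TFP's internal helper):

import jax
import jax.numpy as jnp

@jax.custom_jvp
def softplus(x):
  # Stable forward pass, same expression as the backend implementation.
  return jnp.log1p(jnp.exp(-jnp.abs(x))) + jnp.maximum(x, 0.)

@softplus.defjvp
def _softplus_jvp(primals, tangents):
  x, = primals
  dx, = tangents
  # Override the derivative with sigmoid(x), sidestepping the abs/maximum kinks.
  return softplus(x), dx * jax.nn.sigmoid(x)

print(jax.grad(softplus)(0.))  # 0.5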
