Implements tf.nn.moments in the TFP JAX backend.
Based on the implementation in TensorFlow -- see `tensorflow/python/ops/nn_impl.py`.

PiperOrigin-RevId: 286388988
jburnim authored and tensorflower-gardener committed Dec 19, 2019
1 parent c9838bf commit 6a0386c
Showing 2 changed files with 22 additions and 0 deletions.
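For context, tf.nn.moments returns the mean and variance of a tensor along the given axes, with the variance computed as the mean squared deviation from that mean (population variance, ddof=0). A minimal pure-NumPy sketch of the same computation follows; moments_sketch is an illustrative name, not part of this commit:

import numpy as np

def moments_sketch(x, axes, keepdims=False):
  # Mean over the requested axes, kept so it broadcasts against x.
  mean = np.mean(x, axis=tuple(axes), keepdims=True)
  # Variance is the mean squared deviation from that mean (ddof=0).
  variance = np.mean(np.square(x - mean), axis=tuple(axes), keepdims=keepdims)
  if not keepdims:
    mean = np.squeeze(mean, axis=tuple(axes))
  return mean, variance

x = np.array([[1., 2.], [3., 4.]])
mean, var = moments_sketch(x, axes=[0])
# mean -> [2., 3.], var -> [1., 1.]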
2 changes: 2 additions & 0 deletions tensorflow_probability/python/internal/backend/numpy/BUILD
@@ -175,7 +175,9 @@ py_library(
    srcs = ["nn.py"],
    deps = [
        ":_utils",
        ":numpy_array",
        ":numpy_math",
        ":ops",
        # numpy dep,
        # tensorflow dep,
    ],
20 changes: 20 additions & 0 deletions tensorflow_probability/python/internal/backend/numpy/nn.py
@@ -28,14 +28,18 @@
from tensorflow_probability.python.internal.backend.numpy.numpy_math import l2_normalize
from tensorflow_probability.python.internal.backend.numpy.numpy_math import log_softmax
from tensorflow_probability.python.internal.backend.numpy.numpy_math import reduce_logsumexp
from tensorflow_probability.python.internal.backend.numpy.numpy_math import reduce_mean
from tensorflow_probability.python.internal.backend.numpy.numpy_math import softmax
from tensorflow_probability.python.internal.backend.numpy.numpy_math import softplus
from tensorflow_probability.python.internal.backend.numpy.numpy_math import squared_difference
from tensorflow_probability.python.internal.backend.numpy.numpy_math import top_k
from tensorflow_probability.python.internal.backend.numpy.ops import stop_gradient


__all__ = [
    'l2_normalize',
    'log_softmax',
    'moments',
    'relu',
    'softmax',
    'softplus',
@@ -81,6 +85,22 @@ def _sparse_softmax_cross_entropy_with_logits( # pylint: disable=invalid-name,u
l2_normalize)


def _moments(x, axes, shift=None, keepdims=False, name=None):  # pylint: disable=unused-argument
  # NOTE: If x.dtype is float16, we may want to compute in float32.
  mean = reduce_mean(x, axis=axes, keepdims=True)
  # NOTE: The gradient backpropagated to the mean from the variance calculation
  # is zero, so we can safely use `stop_gradient(mean)` for efficiency.
  variance = reduce_mean(squared_difference(x, stop_gradient(mean)),
                         axis=axes, keepdims=keepdims)
  if not keepdims:
    mean = numpy_array.squeeze(mean, axes)
  return (mean, variance)


moments = utils.copy_docstring(
    tf.nn.moments,
    _moments)


relu = utils.copy_docstring(
    tf.nn.relu,
    lambda features, name=None: np.maximum(features, 0))
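
With this change, code written against the NumPy/JAX backend can call moments just as it would tf.nn.moments. A hedged usage sketch — the import path mirrors the file touched above, though how the module is typically aliased in client code may differ:

import numpy as np
from tensorflow_probability.python.internal.backend.numpy import nn

x = np.random.randn(100, 3).astype(np.float32)
mean, variance = nn.moments(x, axes=[0])
# Both results have shape (3,) and agree with NumPy's population statistics.
np.testing.assert_allclose(mean, np.mean(x, axis=0), rtol=1e-5, atol=1e-6)
np.testing.assert_allclose(variance, np.var(x, axis=0), rtol=1e-5, atol=1e-6)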
