Add output_gradients argument to value_and_gradient.
PiperOrigin-RevId: 262958665
parsiad authored and tensorflower-gardener committed Aug 12, 2019
1 parent efdadf5 commit 0666f73
Showing 2 changed files with 26 additions and 6 deletions.
21 changes: 15 additions & 6 deletions tensorflow_probability/python/math/gradient.py
@@ -27,7 +27,11 @@
 ]
 
 
-def value_and_gradient(f, xs, use_gradient_tape=False, name=None):
+def value_and_gradient(f,
+                       xs,
+                       output_gradients=None,
+                       use_gradient_tape=False,
+                       name=None):
   """Computes `f(*xs)` and its gradients wrt to `*xs`.
 
   Args:
@@ -37,10 +41,15 @@ def value_and_gradient(f, xs, use_gradient_tape=False, name=None):
       a single scalar. If desired, the tensors can be elementwise multiplied by
       the tensors passed as the `dy` keyword argument to the returned gradient
       function.
-    xs: Python list of parameters of f for which to differentiate. (Can also
+    xs: Python list of parameters of `f` for which to differentiate. (Can also
       be single `Tensor`.)
-    use_gradient_tape: Python `bool` indicating that `tf.GradientTape`
-      should be used regardless of `tf.executing_eagerly()` status.
+    output_gradients: A `Tensor` or list of `Tensor`s the same size as the
+      result `ys = f(*xs)` and holding the gradients computed for each `y` in
+      `ys`. This argument is forwarded to the underlying gradient implementation
+      (i.e., either the `grad_ys` argument of `tf.gradients` or the
+      `output_gradients` argument of `tf.GradientTape.gradient`).
+    use_gradient_tape: Python `bool` indicating that `tf.GradientTape` should be
+      used regardless of `tf.executing_eagerly()` status.
       Default value: `False`.
     name: Python `str` name prefixed to ops created by this function.
       Default value: `None` (i.e., `'value_and_gradient'`).
@@ -62,10 +71,10 @@
         for x in xs:
           tape.watch(x)
         y = f(*xs)
-      dydx = tape.gradient(y, xs)
+      dydx = tape.gradient(y, xs, output_gradients=output_gradients)
     else:
       y = f(*xs)
-      dydx = tf.gradients(ys=y, xs=xs)
+      dydx = tf.gradients(ys=y, xs=xs, grad_ys=output_gradients)
     if not is_xs_list_like:
       dydx = dydx[0]
     return y, dydx
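A minimal usage sketch of the new argument (not part of the commit; values are illustrative, mirroring the test below): with `output_gradients=u`, the returned `dydx` is the vector-Jacobian product `u^T J` rather than the gradient of the summed outputs.

# Sketch only: weights the gradient contribution of each output y_i by u_i.
import numpy as np
import tensorflow as tf
import tensorflow_probability as tfp

jacobian = np.float32([[1., 2.], [3., 4.]])  # J: the (constant) Jacobian of f
f = lambda x: tf.squeeze(tf.matmul(jacobian, x[:, tf.newaxis]))
x = np.ones([2], dtype=np.float32)
u = np.float32([1., 2.])                     # passed as output_gradients

# dydx equals np.dot(u, jacobian) == [7., 10.], i.e. u^T J.
y, dydx = tfp.math.value_and_gradient(f, x, output_gradients=u)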
11 changes: 11 additions & 0 deletions tensorflow_probability/python/math/gradient_test.py
Expand Up @@ -47,6 +47,17 @@ def test_list(self):
self.assertAllClose(f(*args), y, atol=1e-6, rtol=1e-6)
self.assertAllClose(g(*args), dydx, atol=1e-6, rtol=1e-6)

def test_output_gradients(self):
jacobian = np.float32([[1., 2., 3.], [4., 5., 6.], [7., 8., 9.]])
f = lambda x: tf.squeeze(tf.matmul(jacobian, x[:, tf.newaxis]))
x = np.ones([3], dtype=np.float32)
output_gradients = np.float32([1., 2., 3.])
y, dydx = self.evaluate(
tfp.math.value_and_gradient(f, x, output_gradients=output_gradients))
self.assertAllClose(f(x), y, atol=1e-6, rtol=1e-6)
self.assertAllClose(
np.dot(output_gradients, jacobian), dydx, atol=1e-6, rtol=1e-6)


if __name__ == '__main__':
tf.test.main()

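Worked out for the test's values: with u = output_gradients = [1., 2., 3.] and J the 3x3 Jacobian above, the asserted identity dydx = u^T J gives [1+8+21, 2+10+24, 3+12+27] = [30., 36., 42.].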