Use concatenate in np.gradient VJP for order>1
Alter tests accordingly
Clemens Schmid committed Apr 8, 2019
1 parent 22b71e3 commit 73f1df0
Showing 3 changed files with 35 additions and 53 deletions.
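Why concatenate: the previous VJP for np.gradient built its result with raw
NumPy calls and in-place writes (out_swap[0] = ...), which autograd cannot
trace, so the VJP itself was not differentiable and gradient checks with
order > 1 failed. Rebuilding the result with anp.concatenate keeps every
step a traceable primitive. A minimal sketch of what the higher order buys
(illustrative only, not part of the commit; f and df are made-up names):

    import autograd.numpy as np
    from autograd import grad

    def f(x):
        return np.sum(np.gradient(x) ** 2)

    def df(x):
        # scalar-valued again, so it can be differentiated once more
        return np.sum(grad(f)(x))

    x = np.linspace(0., 1., 8)
    # second-order reverse mode: this differentiates through the VJP of
    # np.gradient itself, which needs the mutation-free formulation
    print(grad(df)(x))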
76 changes: 29 additions & 47 deletions autograd/numpy/numpy_vjps.py
@@ -163,7 +163,9 @@ def grad_diff(ans, a, n=1, axis=-1):

     def undiff(g):
         if g.shape[axis] > 0:
-            return anp.concatenate((-g[sl1], -anp.diff(g, axis=axis), g[sl2]), axis=axis)
+            return anp.concatenate(
+                (-g[sl1], -anp.diff(g, axis=axis), g[sl2]),
+                axis=axis)
         shape = list(ans_shape)
         shape[axis] = 1
         return anp.zeros(shape)
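This hunk only re-wraps the concatenate call; the expression is unchanged.
For context, it is the transpose of np.diff's differencing stencil, with
sl1 and sl2 selecting the first and last slice along axis. A quick
plain-NumPy check of the 1D identity (illustrative only, not part of the
commit):

    import numpy as onp

    n = 6
    # np.diff is linear, so the columns of its Jacobian are np.diff
    # applied to the standard basis vectors
    J = onp.stack([onp.diff(onp.eye(n)[i]) for i in range(n)], axis=1)

    g = onp.random.randn(n - 1)
    assert onp.allclose(J.T @ g,
                        onp.concatenate((-g[:1], -onp.diff(g), g[-1:])))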
@@ -178,62 +180,42 @@ def helper(g, n):

 def grad_gradient(ans, x, axis=None):
     if axis is None:
-        if ans.ndim == x.ndim:
-            # 1D case (no axis but same shape)
-            axis = [0]
-        else:
-            # gradient along all axes
-            axis = range(x.ndim)
+        axis = range(x.ndim)
     elif type(axis) is int:
-        # not 1D but only along 1 axis
         axis = [axis]
     else:
-        # axis should be already iterable
         axis = list(axis)
 
-    # makes negative indices in axis positive
-    # (needed for axis swap later)
-    for i, a in enumerate(axis):
-        if a < 0: axis[i] = x.ndim + a
+    # make all indices positive
+    axis = [(x.ndim + a) % x.ndim for a in axis]
 
-    if len(axis) == 1:
-        # along one axis only, can be 1D or nD array
-        def vjp(g):
-            # Jacobian of np.gradient is mainly negative gradient
-            out = (-1.) * onp.gradient(g, axis=axis[0])
+    x_dtype = x.dtype
+    x_shape = x.shape
+    nd = x.ndim
 
-            # shift gradient axis to the front
-            out_swap = out.swapaxes(0, axis[0])
-            g_swap = g.swapaxes(0, axis[0])
+    def vjp(g):
+        if g.ndim == nd:
+            # add axis if gradient was along one axis only
+            g = g[anp.newaxis]
 
-            # border handling
-            out_swap[0] = -g_swap[0] - 0.5 * g_swap[1]
-            out_swap[1] = g_swap[0] - 0.5 * g_swap[2]
-            out_swap[-2] = 0.5 * g_swap[-3] - g_swap[-1]
-            out_swap[-1] = 0.5 * g_swap[-2] + g_swap[-1]
+        # accumulate gradient
+        out = anp.zeros(x_shape, dtype=x_dtype)
 
-            return out
+        for i, a in enumerate(axis):
+            # swap gradient axis to the front
+            g_swap = anp.swapaxes(g[i], 0, a)[:, anp.newaxis]
 
-    else:
-        # nd case
-        def vjp(g):
-            out = onp.zeros_like(g)
-            for i, k in enumerate(axis):
-                # Jacobian of np.gradient is mainly negative gradient
-                out[i] = (-1.) * onp.array(onp.gradient(g[i], axis=k))
-
-                # border handling
-                out_swap = out.swapaxes(1, 1+k)
-                g_swap = g.swapaxes(1, 1+k)
-
-                out_swap[i, 0] = -g_swap[i, 0] - 0.5 * g_swap[i, 1]
-                out_swap[i, 1] = g_swap[i, 0] - 0.5 * g_swap[i, 2]
-                out_swap[i, -2] = 0.5 * g_swap[i, -3] - g_swap[i, -1]
-                out_swap[i, -1] = 0.5 * g_swap[i, -2] + g_swap[i, -1]
-
-            return onp.sum(out, axis=0)
+            out_axis = anp.concatenate((
+                -g_swap[0] - 0.5 * g_swap[1],
+                g_swap[0] - 0.5 * g_swap[2],
+                (-1.) * anp.gradient(g_swap, axis=0)[2:-2, 0],
+                0.5 * g_swap[-3] - g_swap[-1],
+                0.5 * g_swap[-2] + g_swap[-1],
+            ), axis=0)
+
+            out = out + anp.swapaxes(out_axis, 0, a)
+
+        return out
 
     return vjp
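The concatenate expression above is the transpose of np.gradient's
stencils: central differences in the interior, one-sided differences at
the edges, which is why two rows at each boundary need special handling.
A quick plain-NumPy check of the 1D identity (illustrative only, not part
of the commit):

    import numpy as onp

    n = 7
    # np.gradient is linear, so the columns of its Jacobian are
    # np.gradient applied to the standard basis vectors
    J = onp.stack([onp.gradient(onp.eye(n)[i]) for i in range(n)], axis=1)

    g = onp.random.randn(n)
    vjp = onp.concatenate((
        [-g[0] - 0.5 * g[1]],
        [g[0] - 0.5 * g[2]],
        (-1.) * onp.gradient(g)[2:-2],
        [0.5 * g[-3] - g[-1]],
        [0.5 * g[-2] + g[-1]],
    ))
    assert onp.allclose(J.T @ g, vjp)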
8 changes: 4 additions & 4 deletions tests/test_numpy.py
@@ -698,9 +698,9 @@ def f(x): return np.sum(np.sin(x.astype('float64')))
     assert grad(f)(x).dtype == np.dtype('float32')
 
 def test_gradient():
-    check_grads(np.gradient, 0, order=1)(npr.randn(10))
-    check_grads(np.gradient, 0, order=1)(npr.randn(10, 10))
-    check_grads(np.gradient, 0, order=1)(npr.randn(10, 10, 10))
+    check_grads(np.gradient, 0)(npr.randn(10))
+    check_grads(np.gradient, 0)(npr.randn(10, 10))
+    check_grads(np.gradient, 0)(npr.randn(10, 10, 10))
 
     for a in [None, 0, 1, -1, (0, 1), (0, -1)]:
-        check_grads(np.gradient, 0, order=1)(npr.randn(10, 10, 10), axis=a)
+        check_grads(np.gradient, 0)(npr.randn(10, 10, 10), axis=a)
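The only change here is dropping order=1: check_grads then falls back to
its default derivative order (2, assuming autograd's test_util default),
so these tests now also differentiate the VJP of np.gradient instead of
checking first derivatives only. Under that assumption the new calls
behave like:

    check_grads(np.gradient, 0, order=2)(npr.randn(10))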
4 changes: 2 additions & 2 deletions tests/test_systematic.py
@@ -95,8 +95,8 @@ def test_diff():
     combo_check(np.diff, [0])([R(1,1), R(3,1)], axis=[1])
 
 def test_gradient():
-    combo_check(np.gradient, [0], order=1)([R(5,5), R(5,5,5)], axis=[0,1,-1])
-    combo_check(np.gradient, [0], order=1)([R(5,5,5)], axis=[(0, 1), (0, -1)])
+    combo_check(np.gradient, [0])([R(5,5), R(5,5,5)], axis=[0,1,-1])
+    combo_check(np.gradient, [0])([R(5,5,5)], axis=[(0, 1), (0, -1)])
 
 def test_tile():
     combo_check(np.tile, [0])([R(2,1,3,1)], reps=[(1, 4, 1, 2)])
