Skip to content

Commit

Permalink
[Oryx] Fix bug in reshape ILDJ rule
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 343545188
  • Loading branch information
sharadmv authored and tensorflower-gardener committed Nov 20, 2020
1 parent 794a96c commit 3f0bae3
Show file tree
Hide file tree
Showing 2 changed files with 9 additions and 2 deletions.
8 changes: 8 additions & 0 deletions spinoffs/oryx/oryx/core/interpreters/inverse/inverse_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -249,6 +249,14 @@ def f(x, y):
onp.testing.assert_allclose(y, np.ones(2))
onp.testing.assert_allclose(ildj_, 0., atol=1e-6, rtol=1e-6)

def test_inverse_of_reshape(self):
def f(x):
return np.reshape(x, (4,))
f_inv = core.inverse_and_ildj(f, np.ones((2, 2)))
x, ildj_ = f_inv(np.ones(4))
onp.testing.assert_allclose(x, np.ones((2, 2)))
onp.testing.assert_allclose(ildj_, 0.)

def test_sigmoid_ildj(self):
def naive_sigmoid(x):
# This is the default JAX implementation of sigmoid.
Expand Down
3 changes: 1 addition & 2 deletions spinoffs/oryx/oryx/core/interpreters/inverse/rules.py
Original file line number Diff line number Diff line change
Expand Up @@ -166,9 +166,8 @@ def reshape_ildj(incells, outcells, **params):
))], None
elif outcell.top() and not incell.top():
val = outcell.val
ndslice = NDSlice.new(np.reshape(val, incell.aval.shape)) # pytype: disable=missing-parameter
new_incells = [
InverseAndILDJ(incell.aval, [ndslice])
InverseAndILDJ.new(np.reshape(val, incell.aval.shape))
]
return new_incells, outcells, None
return incells, outcells, None
Expand Down

0 comments on commit 3f0bae3

Please sign in to comment.