Merge pull request #1 from shinh/fake-asfunc-grad
Fix WrappedFunctionNode.backward
disktnk authored Oct 16, 2019
2 parents bb3b65c + 36ca673 commit 8a6ea46
Showing 2 changed files with 72 additions and 6 deletions.
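
In short, `WrappedFunctionNode` now records the `chainer.Variable` inputs it was called with (`arg_vars`), and its `backward` computes gradients for them with `chainer.grad` instead of calling `chainer.backward`, so gradients propagate through functions faked as function nodes. A minimal usage sketch of what the new tests exercise (the name `'fn'` and the values mirror the tests below; importing `as_funcnode` from `onnx_chainer.replace_func` is an assumption, adjust to the package layout):

import chainer
import numpy as np

from onnx_chainer.replace_func import as_funcnode

# Wrap a plain callable so it is traced as a single FunctionNode.
rfn = as_funcnode('fn')(lambda a, b: a * b)

a = chainer.Variable(np.array(2.3))
b = chainer.Variable(np.array(4.2))
y = rfn(a, b)

# With this fix, gradients flow back through the wrapped function:
# ga.array == 4.2 (dy/da == b), gb.array == 2.3 (dy/db == a).
ga, gb = chainer.grad([y], [a, b])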
18 changes: 12 additions & 6 deletions onnx_chainer/replace_func.py
@@ -11,21 +11,24 @@ class WrappedFunctionNode(chainer.FunctionNode):
         func (func): the target function
         args (list): args for the function
         kwargs (dict): kwargs for the function
+        arg_vars (list): list of `chainer.Variable`s in `args` and `kwargs`
         attributes (list): parameters to be set node's attributes
     """
 
-    def __init__(self, name, func, args, kwargs, attributes=None):
+    def __init__(self, name, func, args, kwargs, arg_vars, attributes=None):
         self.custom_function_node_name = name
         self.func = func
         self.args = args
         self.kwargs = kwargs
+        self.arg_vars = arg_vars
         self.internal_results = None
 
         if attributes is not None:
             for k, v in attributes.items():
                 setattr(self, k, v)
 
     def forward(self, xs):
+        assert len(xs) == len(self.arg_vars)
         self.xs = xs
         results = self.func(*self.args, **self.kwargs)
         if isinstance(results, (tuple, list)):
@@ -50,10 +53,13 @@ def forward(self, xs):
     def backward(self, target_input_indexes, grad_outputs):
         if self.internal_results is None:
             raise ValueError(
-                'the target function does not support backward, propagation is'
-                'failed')
-        chainer.backward(self.internal_results, grad_outputs)
-        return super().backward(target_input_indexes, grad_outputs)
+                'the target function does not support backward, propagation '
+                'is failed')
+        grad_inputs = chainer.grad(self.internal_results, self.arg_vars,
+                                   grad_outputs=grad_outputs)
+        assert len(self.arg_vars) == len(grad_inputs)
+        return tuple(grad_input if i in target_input_indexes else None
+                     for i, grad_input in enumerate(grad_inputs))
 
 
 def fake_as_funcnode(alt_func, name, rename_attributes=None):
@@ -156,7 +162,7 @@ def expand_args(args_iter):
                 '{}'.format(name))
 
         wrapped = WrappedFunctionNode(
-            name, alt_func, args, kwargs, attributes=attributes)
+            name, alt_func, args, kwargs, inputs, attributes=attributes)
         ret = wrapped.apply(inputs)
         if len(ret) > 1:
             return ret
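The core of the change above: `backward` no longer calls `chainer.backward`, which only accumulates gradients into `.grad` attributes and returns nothing, but uses `chainer.grad` to obtain the gradients of the internally recorded results with respect to `arg_vars`, then returns one entry per input as `FunctionNode.backward` expects, with `None` for inputs that were not requested. A standalone sketch of that masking step (the helper name is illustrative, not part of the module):

def mask_grads(grad_inputs, target_input_indexes):
    # Keep gradients only for the requested input indexes; the rest
    # become None, matching FunctionNode.backward's return contract.
    return tuple(g if i in target_input_indexes else None
                 for i, g in enumerate(grad_inputs))

# Example: gradients for inputs 0 and 2 requested out of three inputs.
assert mask_grads(['ga', 'gb', 'gc'], (0, 2)) == ('ga', None, 'gc')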
60 changes: 60 additions & 0 deletions tests/onnx_chainer_tests/test_replace_func.py
@@ -148,6 +148,66 @@ def load_tensor(path):
             model, x, output_grad=True, custom_model_test_func=gradient_check)
 
 
+class TestReplaceFuncBackward(ONNXModelTest):
+
+    def _test_replace_func(self, fn, xs, set_grad=False):
+        def make_list(v):
+            if isinstance(v, (list, tuple)):
+                return list(v)
+            else:
+                return [v]
+
+        xvs = [x for x in xs if isinstance(x, chainer.Variable)]
+        rfn = as_funcnode('fn')(fn)
+        eys = make_list(fn(*xs))
+        egxs = chainer.grad(eys, xvs, set_grad=set_grad)
+        ays = make_list(rfn(*xs))
+        agxs = chainer.grad(ays, xvs, set_grad=set_grad)
+        assert len(eys) == len(ays)
+        for ay, ey in zip(ays, eys):
+            np.testing.assert_allclose(ay.array, ey.array)
+        assert len(egxs) == len(agxs)
+        for agx, egx in zip(agxs, egxs):
+            if egx is None:
+                assert agx is None
+            else:
+                np.testing.assert_allclose(agx.array, egx.array)
+
+    def test_backward_simple(self):
+        self._test_replace_func(lambda a, b: a * b,
+                                [chainer.Variable(np.array(2.3)),
+                                 chainer.Variable(np.array(4.2))])
+
+    def test_backward_partially_differentiable(self):
+        self._test_replace_func(lambda a, b: a * b.array,
+                                [chainer.Variable(np.array(2.3)),
+                                 chainer.Variable(np.array(4.2))])
+
+    def test_backward_multi_outputs(self):
+        self._test_replace_func(lambda a, b, c: (a * b, a / b, a * b * c),
+                                [chainer.Variable(np.array(2.3)),
+                                 chainer.Variable(np.array(4.2)),
+                                 5])
+
+    def test_backward_no_side_effect(self):
+        a = chainer.Variable(np.array(2.3))
+        b = chainer.Variable(np.array(4.2))
+        x0 = a * b
+        x1 = chainer.Variable(np.array(3.7))
+        self._test_replace_func(lambda a, b: a * b, [x0, x1])
+        # No side-effect to `grad`.
+        assert x0.grad is None
+        assert x1.grad is None
+        assert a.grad is None
+        assert b.grad is None
+        # Gradient computation must stop at `x0` and `x1`.
+        self._test_replace_func(lambda a, b: a * b, [x0, x1], set_grad=True)
+        assert x0.grad is not None
+        assert x1.grad is not None
+        assert a.grad is None
+        assert b.grad is None
+
+
 @testing.parameterize(
     {'func_kind': 'list', 'in_shape': (2, 3, 4), 'op_type': 'Add'},
     {'func_kind': 'list_kwargs', 'in_shape': (2, 3, 4), 'op_type': 'Add'},
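For `test_backward_partially_differentiable` above, the second argument is used as a bare array (`b.array`), so it is detached from the graph and its expected gradient is `None`; that is why `_test_replace_func` accepts `None` entries from `chainer.grad`. A small illustration of the expected-side computation, using the values from the test:

import chainer
import numpy as np

a = chainer.Variable(np.array(2.3))
b = chainer.Variable(np.array(4.2))

# b.array detaches b from the computational graph.
y = a * b.array
ga, gb = chainer.grad([y], [a, b])
# ga holds dy/da (== 4.2); gb is None because b is unreachable from y.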
