Skip to content

Commit

Permalink
Add test_backprop_shared_parent
Browse files Browse the repository at this point in the history
test_backprop_shared_parent makes sure that the backprop function adds gradients if there are contributions from multiple child nodes.
  • Loading branch information
alex-garvey authored Aug 9, 2023
1 parent e0253aa commit f4df18c
Showing 1 changed file with 13 additions and 3 deletions.
16 changes: 13 additions & 3 deletions chapter0_fundamentals/exercises/part5_backprop/tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -233,7 +233,6 @@ def test_backprop(Tensor):
assert np.allclose(a.grad.array, 1 / b.array / a.array)
print("All tests in `test_backprop` passed!")


def test_backprop_branching(Tensor):
a = Tensor([1, 2, 3], requires_grad=True)
b = Tensor([1, 2, 3], requires_grad=True)
Expand All @@ -243,7 +242,6 @@ def test_backprop_branching(Tensor):
assert np.allclose(b.grad.array, a.array)
print("All tests in `test_backprop_branching` passed!")


def test_backprop_requires_grad_false(Tensor):
a = Tensor([1, 2, 3], requires_grad=True)
b = Tensor([1, 2, 3], requires_grad=False)
Expand All @@ -253,7 +251,6 @@ def test_backprop_requires_grad_false(Tensor):
assert b.grad is None
print("All tests in `test_backprop_requires_grad_false` passed!")


def test_backprop_float_arg(Tensor):
a = Tensor([1, 2, 3], requires_grad=True)
b = 2
Expand All @@ -267,6 +264,19 @@ def test_backprop_float_arg(Tensor):
assert np.allclose(a.grad.array, np.array([4.0, 4.0, 4.0]))
print("All tests in `test_backprop_float_arg` passed!")

def test_backprop_shared_parent(Tensor):
a = 2
b = Tensor([1, 2, 3], requires_grad=True)
c = 3
d = a * b
e = b * c
f = d * e
f.backward(end_grad=np.array([1.0, 1.0, 1.0]))
assert f.grad is None
assert b.grad is not None
assert np.allclose(b.grad.array, np.array([12.0, 24.0, 36.0])), "Multiple nodes may have the same parent."
print("All tests in `test_backprop_shared_parent` passed!")

def test_negative_back(Tensor):
a = Tensor([-1.0, 0.0, 1.0], requires_grad=True)
b = -a
Expand Down

0 comments on commit f4df18c

Please sign in to comment.