forked from HIPS/autograd
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathtest_logic.py
51 lines (44 loc) · 1.53 KB
/
test_logic.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
from __future__ import division
from contextlib import contextmanager
import pytest
import warnings
import autograd.numpy as np
from autograd import grad, deriv
from autograd.extend import primitive
from autograd.test_util import check_grads
from autograd.core import primitive_vjps
def test_assert():
    """Gradient-check a function that contains an ``assert`` on its input."""

    # from https://github.com/HIPS/autograd/issues/43
    def checked_sum(x):
        round_trip = (x * 3.0) / 3.0
        assert np.allclose(x, round_trip)
        return np.sum(x)

    sample = np.array([1.0, 2.0, 3.0])
    check_grads(checked_sum)(sample)
def test_nograd():
    """Taking ``grad`` of an ``np.allclose`` result must raise ``TypeError``."""

    def is_self_consistent(x):
        return np.allclose(x, (x * 3.0) / 3.0)

    # We want this to raise a non-differentiability error; warnings issued
    # while evaluating are recorded by catch_warnings instead of being shown.
    with pytest.raises(TypeError), warnings.catch_warnings(record=True):
        grad(is_self_consistent)(np.array([1.0, 2.0, 3.0]))
def test_no_vjp_def():
    """``grad`` of a fresh primitive must raise ``NotImplementedError``.

    The primitive is created here via ``primitive(...)`` and nothing else is
    defined for it before it is differentiated.
    """
    doubler = primitive(lambda x: 2.0 * x)
    with pytest.raises(NotImplementedError):
        grad(doubler)(1.0)
def test_no_jvp_def():
    """``deriv`` of a fresh primitive must raise ``NotImplementedError``.

    The primitive is created here via ``primitive(...)`` and nothing else is
    defined for it before it is differentiated.
    """
    doubler = primitive(lambda x: 2.0 * x)
    with pytest.raises(NotImplementedError):
        deriv(doubler)(1.0)
def test_falseyness():
    """Gradient-check a function that branches on ``np.iscomplex(x)``.

    Exercised once with a real input and once with a complex input so both
    branches of the conditional are taken.
    """

    def branchy(x):
        if np.iscomplex(x):
            inner = x ** 2
        else:
            inner = np.sum(x)
        return np.real(inner)

    check_grads(branchy)(2.0)
    check_grads(branchy)(2.0 + 1j)
def test_unimplemented_falseyness():
    """Gradient-check the ``np.iscomplex`` branch with its VJP entry removed.

    While the checks run, ``np.iscomplex`` is taken out of ``primitive_vjps``;
    the entry is put back afterwards whether or not the checks pass.
    """
    missing = object()  # sentinel: distinguishes "no entry" from a falsy entry

    @contextmanager
    def remove_grad_definitions(fun):
        """Temporarily remove ``fun``'s entry from ``primitive_vjps``."""
        vjpmaker = primitive_vjps.pop(fun, missing)
        try:
            yield
        finally:
            # Restore in ``finally`` so a failing check inside the ``with``
            # body cannot leave the shared registry mutated for later tests.
            if vjpmaker is not missing:
                primitive_vjps[fun] = vjpmaker

    def fun(x):
        return np.real(x ** 2 if np.iscomplex(x) else np.sum(x))

    with remove_grad_definitions(np.iscomplex):
        check_grads(fun)(5.0)
        check_grads(fun)(2.0 + 1j)