forked from HIPS/autograd
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathtest_core.py
43 lines (36 loc) · 1.61 KB
/
test_core.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
"""This file doesn't import the numpy wrapper, to check if core works
on basic operations even without numpy."""
from __future__ import absolute_import
import warnings
from autograd.core import make_vjp
from autograd.wrap_util import unary_to_nary
@unary_to_nary
def grad(fun, x):
    """Return the result of pulling the scalar ``1.0`` back through ``fun`` at ``x``.

    Builds a vector-Jacobian-product callable with ``make_vjp(fun, x)`` and
    applies it to ``1.0``.  The ``unary_to_nary`` decorator is what lets the
    tests call this as ``grad(fun)(x, y)`` or ``grad(fun, 1)(x, y)``.
    """
    # make_vjp yields a pair; only the first element (the VJP callable) is
    # used here.  NOTE(review): the second element is presumably the forward
    # value of fun(x) -- confirm against autograd.core.
    pullback, _unused = make_vjp(fun, x)
    return pullback(1.0)
# Non-numpy gradient checking functions.
def nd(f, x, eps=1e-4):
    """Approximate the derivative of ``f`` at ``x`` with a central difference.

    ``f`` is evaluated at ``x + eps/2`` and ``x - eps/2``; the difference of
    the two values is divided by ``eps``.
    """
    half_step = eps / 2
    # Keep the evaluation order of the two probes: forward first, then backward.
    forward_value = f(x + half_step)
    backward_value = f(x - half_step)
    return (forward_value - backward_value) / eps
def check_close(a, b, atol=1e-4, rtol=1e-4):
    """Assert that ``a`` is within a mixed absolute/relative tolerance of ``b``.

    The check passes when ``abs(a - b) < atol + rtol * abs(b)``; otherwise an
    ``AssertionError`` is raised whose message reports the signed difference.
    """
    difference = a - b
    allowed = atol + rtol*abs(b)
    assert abs(difference) < allowed, "Diffs are: {0}".format(difference)
def check_binary_func(fun, independent=False):
    """Compare ``grad`` against a finite difference for a two-argument ``fun``.

    At the fixed point ``(0.7, 1.8)`` the value produced by ``grad`` for each
    argument position is checked against ``nd`` applied to ``fun`` with the
    other argument held constant.

    ``independent`` is forwarded as ``record=`` to
    ``warnings.catch_warnings``, so when it is true any warnings raised
    during the checks are captured instead of being shown.
    """
    x, y = 0.7, 1.8
    with warnings.catch_warnings(record=independent):
        # First argument: default argnum of the decorated ``grad``.
        analytic_first = grad(fun)(x, y)
        numeric_first = nd(lambda first: fun(first, y), x)
        check_close(analytic_first, numeric_first)

        # Second argument: argnum passed positionally as 1.
        analytic_second = grad(fun, 1)(x, y)
        numeric_second = nd(lambda second: fun(x, second), y)
        check_close(analytic_second, numeric_second)
def test_add():
    """Gradient check for ``x + y``."""
    check_binary_func(lambda lhs, rhs: lhs + rhs)


def test_sub():
    """Gradient check for ``x - y``."""
    check_binary_func(lambda lhs, rhs: lhs - rhs)


def test_div():
    """Gradient check for ``x / y``."""
    check_binary_func(lambda lhs, rhs: lhs / rhs)


def test_mul():
    """Gradient check for ``x * y``."""
    check_binary_func(lambda lhs, rhs: lhs * rhs)


def test_pow():
    """Gradient check for ``x ** y``."""
    check_binary_func(lambda lhs, rhs: lhs ** rhs)


def test_mod():
    """Gradient check for ``x % y``."""
    check_binary_func(lambda lhs, rhs: lhs % rhs)
# Comparison operators: each test passes independent=True, which
# check_binary_func forwards to warnings.catch_warnings(record=...).

def test_eq():
    """Gradient check for ``x == y``."""
    check_binary_func(lambda lhs, rhs: lhs == rhs, independent=True)


def test_neq():
    """Gradient check for ``x != y``."""
    check_binary_func(lambda lhs, rhs: lhs != rhs, independent=True)


def test_leq():
    """Gradient check for ``x <= y``."""
    check_binary_func(lambda lhs, rhs: lhs <= rhs, independent=True)


def test_geq():
    """Gradient check for ``x >= y``."""
    check_binary_func(lambda lhs, rhs: lhs >= rhs, independent=True)


def test_lt():
    """Gradient check for ``x < y``."""
    check_binary_func(lambda lhs, rhs: lhs < rhs, independent=True)


def test_gt():
    """Gradient check for ``x > y``."""
    check_binary_func(lambda lhs, rhs: lhs > rhs, independent=True)