-
Notifications
You must be signed in to change notification settings - Fork 22
/
Copy path test_builder.py
88 lines (63 loc) · 2.3 KB
/
test_builder.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
# pylint: disable=missing-docstring
import pytest
from nengo.exceptions import BuildError
from nengo_dl.builder import Builder, NengoModel, OpBuilder
from nengo_dl.tests import dummies
from nengo_dl.utils import NullProgressBar
def test_custom_builder():
    """Exercise ``Builder`` registration, lookup and the build_pre/step/post cycle.

    Covers: the error for an op type with no registered builder, the warnings
    emitted by ``Builder.register`` (class not subclassing ``OpBuilder``, and
    overwriting an existing registration), the order in which the three build
    phases reach the op builder, and the error raised when a registered
    builder does not implement ``build_step``.
    """
    # pylint: disable=unused-variable

    # Dummy operator type with sets/incs/reads/updates all set to None; no
    # builder is registered for it at this point.
    class TestOp:
        sets = None
        incs = None
        reads = None
        updates = None

    # ``Builder`` is given a list containing this one op tuple; the tuple is
    # also the key used to look up its builder instance in
    # ``builder.op_builds`` below.
    ops = (TestOp(),)
    progress = NullProgressBar()

    # error if no builder registered
    with pytest.raises(BuildError, match="No registered builder"):
        Builder([ops])

    # warning if builder doesn't subclass OpBuilder
    with pytest.warns(UserWarning):

        @Builder.register(TestOp)
        class TestOpBuilder0:
            pass

    # warning when overwriting a registered builder
    with pytest.warns(UserWarning):

        @Builder.register(TestOp)
        class TestOpBuilder(OpBuilder):
            pre_built = False
            post_built = False

            def build_pre(self, signals, config):
                super().build_pre(signals, config)
                self.pre_built = True

            def build_step(self, signals):
                # build_step must run after build_pre and before build_post
                assert self.pre_built
                assert not self.post_built
                return 0, 1

            def build_post(self, signals):
                self.post_built = True

    builder = Builder([ops])
    builder.build_pre(signals=None, config=None, progress=progress)
    result = builder.build_step(signals=None, progress=progress)
    assert len(result) == 2
    assert result[0] == 0
    assert result[1] == 1
    builder.build_post(signals=None, progress=progress)
    assert builder.op_builds[ops].post_built

    # error if builder doesn't define build_step.
    # Registering TestOpBuilder2 replaces TestOpBuilder as the builder for
    # TestOp, so it triggers the same overwrite warning checked above; assert
    # it here instead of letting an unexpected warning escape the test.
    with pytest.warns(UserWarning):

        @Builder.register(TestOp)
        class TestOpBuilder2(OpBuilder):
            pass

    builder = Builder([ops])
    builder.build_pre(signals=None, config=None, progress=progress)
    with pytest.raises(BuildError, match="must implement a `build_step` function"):
        builder.build_step(signals=None, progress=progress)
@pytest.mark.parametrize("fail_fast", (True, False))
def test_custom_model(fail_fast):
    """Check that ``NengoModel.add_op`` raises ``NotImplementedError`` for a
    ``dummies.Op`` exactly when the model was constructed with ``fail_fast``."""
    nengo_model = NengoModel(fail_fast=fail_fast)

    got_not_implemented = False
    try:
        nengo_model.add_op(dummies.Op())
    except NotImplementedError:
        got_not_implemented = True

    # With fail_fast, add_op must raise NotImplementedError; without it,
    # add_op must complete without raising it.
    assert got_not_implemented == bool(fail_fast)