import textwrap

import pytest

import dspy
from dspy import Example
from dspy.predict import Predict
from dspy.teleprompt import BootstrapFewShot
from dspy.utils.dummies import DummyLM


# Define a simple metric function for testing
def simple_metric(example, prediction, trace=None):
    # Simplified metric for testing: true if prediction matches expected output
    return example.output == prediction.output
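

# A quick sanity check of the metric itself (a sketch added here, not part of
# the original suite); it assumes dspy.Prediction accepts arbitrary keyword
# fields, which is how predictions carry their outputs elsewhere in dspy.
def test_simple_metric():
    example = Example(input="What is the color of the sky?", output="blue")
    assert simple_metric(example, dspy.Prediction(output="blue"))
    assert not simple_metric(example, dspy.Prediction(output="green"))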


examples = [
    Example(input="What is the color of the sky?", output="blue").with_inputs("input"),
    Example(
        input="What does the fox say?", output="Ring-ding-ding-ding-dingeringeding!"
    ),
]
trainset = [examples[0]]
valset = [examples[1]]


def test_bootstrap_initialization():
    # Initialize BootstrapFewShot with a dummy metric and minimal setup
    bootstrap = BootstrapFewShot(
        metric=simple_metric, max_bootstrapped_demos=1, max_labeled_demos=1
    )
    assert bootstrap.metric == simple_metric, "Metric not correctly initialized"


class SimpleModule(dspy.Module):
    def __init__(self, signature):
        super().__init__()
        self.predictor = Predict(signature)

    def forward(self, **kwargs):
        return self.predictor(**kwargs)
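

# Rough usage sketch (added): SimpleModule is a thin dspy.Module wrapper around
# a single Predict, which is what lets teleprompters like BootstrapFewShot
# discover and compile its predictor. Assuming an LM is configured, a call
# looks roughly like:
#
#     module = SimpleModule("input -> output")
#     module(input="What is the color of the sky?").output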


def test_compile_with_predict_instances():
    # Create Predict instances for student and teacher
    # Note that dspy.Predict is not itself a module, so we can't use it directly here
    student = SimpleModule("input -> output")
    teacher = SimpleModule("input -> output")
    lm = DummyLM(["Initial thoughts", "Finish[blue]"])
    dspy.settings.configure(lm=lm)

    # Initialize BootstrapFewShot and compile the student
    bootstrap = BootstrapFewShot(
        metric=simple_metric, max_bootstrapped_demos=1, max_labeled_demos=1
    )
    compiled_student = bootstrap.compile(
        student, teacher=teacher, trainset=trainset
    )

    assert compiled_student is not None, "Failed to compile student"
    assert (
        hasattr(compiled_student, "_compiled") and compiled_student._compiled
    ), "Student compilation flag not set"


def test_bootstrap_effectiveness():
    # This test verifies that the bootstrapping process improves the student's predictions
    student = SimpleModule("input -> output")
    teacher = SimpleModule("input -> output")
    lm = DummyLM(["blue", "Ring-ding-ding-ding-dingeringeding!"], follow_examples=True)
    dspy.settings.configure(lm=lm, trace=[])

    bootstrap = BootstrapFewShot(
        metric=simple_metric, max_bootstrapped_demos=1, max_labeled_demos=1
    )
    compiled_student = bootstrap.compile(
        student, teacher=teacher, trainset=trainset
    )

    # Check that the compiled student has the correct demos
    assert len(compiled_student.predictor.demos) == 1
    assert compiled_student.predictor.demos[0].input == trainset[0].input
    assert compiled_student.predictor.demos[0].output == trainset[0].output

    # Test the compiled student's prediction.
    # We are using a DummyLM with follow_examples=True, which means that
    # even though it would normally reply with "Ring-ding-ding-ding-dingeringeding!"
    # on the second output, if it sees an example that perfectly matches the
    # prompt, it will use that instead. That is why we expect "blue" here.
    prediction = compiled_student(input=trainset[0].input)
    assert prediction.output == trainset[0].output

    # For debugging
    print("Convo")
    print(lm.get_convo(-1))

    assert lm.get_convo(-1) == textwrap.dedent(
        """\
        Given the fields `input`, produce the fields `output`.

        ---

        Follow the following format.

        Input: ${input}
        Output: ${output}

        ---

        Input: What is the color of the sky?
        Output: blue

        ---

        Input: What is the color of the sky?
        Output: blue"""
    )


def test_error_handling_during_bootstrap():
    """Test to verify error handling during the bootstrapping process."""

    class BuggyModule(dspy.Module):
        def __init__(self, signature):
            super().__init__()
            self.predictor = Predict(signature)

        def forward(self, **kwargs):
            raise RuntimeError("Simulated error")

    student = SimpleModule("input -> output")
    teacher = BuggyModule("input -> output")

    # Set up the DummyLM to simulate an error scenario
    lm = DummyLM(
        [
            "Initial thoughts",  # Simulate the teacher's initial prediction
        ]
    )
    dspy.settings.configure(lm=lm)

    bootstrap = BootstrapFewShot(
        metric=simple_metric,
        max_bootstrapped_demos=1,
        max_labeled_demos=1,
        max_errors=1,
    )

    with pytest.raises(RuntimeError, match="Simulated error"):
        bootstrap.compile(student, teacher=teacher, trainset=trainset)
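

# Companion sketch (an assumption added here, not from the original suite): the
# test above implies max_errors is a threshold on the cumulative error count,
# so a larger budget should let compile() absorb the single simulated failure
# instead of re-raising it; the failing example simply yields no bootstrapped demo.
def test_error_tolerance_during_bootstrap():
    class BuggyModule(dspy.Module):
        def __init__(self, signature):
            super().__init__()
            self.predictor = Predict(signature)

        def forward(self, **kwargs):
            raise RuntimeError("Simulated error")

    student = SimpleModule("input -> output")
    teacher = BuggyModule("input -> output")
    dspy.settings.configure(lm=DummyLM(["Initial thoughts"]))

    bootstrap = BootstrapFewShot(
        metric=simple_metric,
        max_bootstrapped_demos=1,
        max_labeled_demos=1,
        max_errors=10,  # comfortably above the single simulated failure
    )
    # Expected to complete without raising, per the threshold assumption above.
    compiled_student = bootstrap.compile(student, teacher=teacher, trainset=trainset)
    assert compiled_student is not None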


def test_validation_set_usage():
    """Test to ensure the validation set is correctly used during bootstrapping."""
    student = SimpleModule("input -> output")
    teacher = SimpleModule("input -> output")
    lm = DummyLM(
        [
            "Initial thoughts",
            "Finish[blue]",  # Expected output for both training and validation
        ]
    )
    dspy.settings.configure(lm=lm)

    bootstrap = BootstrapFewShot(
        metric=simple_metric, max_bootstrapped_demos=1, max_labeled_demos=1
    )
    compiled_student = bootstrap.compile(
        student, teacher=teacher, trainset=trainset
    )

    # Check that validation examples are part of the student's demos after compilation
    assert len(compiled_student.predictor.demos) >= len(
        valset
    ), "Validation set not used in compiled student demos"