# assert_hotpotqa.py
# Forked from stanfordnlp/dspy.
import dspy
from dsp.utils import deduplicate
from dspy.datasets import HotPotQA
from dspy.predict.retry import Retry
from dspy.teleprompt import BootstrapFewShot
# Pipeline configuration: the language model and the retrieval model that
# dspy.settings hands to every module in this script.
turbo = dspy.OpenAI(model="gpt-3.5-turbo")
colbertv2_wiki17_abstracts = dspy.ColBERTv2(
    url="http://20.102.90.50:2017/wiki17_abstracts",
)
dspy.settings.configure(lm=turbo, rm=colbertv2_wiki17_abstracts)

# Load the HotPotQA dataset with fixed seeds; no test split is requested
# (test_size=0).
dataset = HotPotQA(
    train_seed=1,
    train_size=20,
    eval_seed=2023,
    dev_size=50,
    test_size=0,
)

# Call with_inputs("question") on every train/dev example.
trainset = [example.with_inputs("question") for example in dataset.train]
devset = [example.with_inputs("question") for example in dataset.dev]
# signatures of dspy modules
class GenerateAnswer(dspy.Signature):
    """Answer questions with short factoid answers."""

    # NOTE(review): the class docstring and the `desc` strings are presumably
    # read by dspy when it builds the prompt -- confirm before rewording them.

    # Inputs: the accumulated context and the question to answer.
    context = dspy.InputField(desc="may contain relevant facts")
    question = dspy.InputField()
    # Output: the answer text.
    answer = dspy.OutputField(desc="often between 1 and 5 words")
class GenerateSearchQuery(dspy.Signature):
    """Write a simple search query that will help answer a complex question."""

    # NOTE(review): the class docstring and the `desc` strings are presumably
    # read by dspy when it builds the prompt -- confirm before rewording them.

    # Inputs: the context gathered so far and the original question.
    context = dspy.InputField(desc="may contain relevant facts")
    question = dspy.InputField()
    # Output: the search query for the next retrieval hop.
    query = dspy.OutputField()
# Tally of rejected predictions, keyed by the reason the metric rejected them.
# Every counter starts at zero and is incremented in place by the metric.
failure_counts = dict.fromkeys(
    (
        "answer_exact_match",
        "answer_passage_match",
        "max_hop_length_exceeded",
        "hop_query_similarity_exceeded",
        "failed_prog_assertions",
    ),
    0,
)
def validate_context_and_answer_and_hops(example, pred, trace=None):
    """Metric: accept a prediction only if the answer and every hop query pass.

    The checks run in order and the first failing one increments its counter
    in the module-level ``failure_counts`` dict:

    1. ``dspy.evaluate.answer_exact_match(example, pred)``
    2. ``dspy.evaluate.answer_passage_match(example, pred)``
    3. no hop (the question plus each traced ``query``) is longer than 100
       characters
    4. no hop from index 2 onward matches the hops before it according to
       ``dspy.evaluate.answer_exact_match_str(..., frac=0.8)``

    Args:
        example: the gold example; ``example.question`` is used as hop 0.
        pred: the program's prediction for ``example``.
        trace: iterable whose items end with an ``outputs`` object; items
            whose outputs contain ``"query"`` contribute a hop. If left as
            ``None`` and the first two checks pass, iterating it raises
            ``TypeError``, which is recorded as ``"failed_prog_assertions"``.

    Returns:
        ``True`` when every check passes, ``False`` otherwise (including when
        any check raises an ``Exception``).
    """
    # failure_counts is only mutated in place, so no `global` statement is
    # needed to update it.
    try:
        if not dspy.evaluate.answer_exact_match(example, pred):
            failure_counts["answer_exact_match"] += 1
            return False

        if not dspy.evaluate.answer_passage_match(example, pred):
            failure_counts["answer_passage_match"] += 1
            return False

        # Hop 0 is the question itself; later hops are the generated queries.
        hops = [example.question] + [
            outputs.query for *_, outputs in trace if "query" in outputs
        ]

        if max(len(hop) for hop in hops) > 100:
            failure_counts["max_hop_length_exceeded"] += 1
            return False

        if any(
            dspy.evaluate.answer_exact_match_str(hops[idx], hops[:idx], frac=0.8)
            for idx in range(2, len(hops))
        ):
            failure_counts["hop_query_similarity_exceeded"] += 1
            return False
    except Exception:
        # Best-effort metric: any ordinary error counts as a failed example.
        # Catching Exception rather than using a bare `except:` lets
        # KeyboardInterrupt and SystemExit propagate so the run can be stopped.
        failure_counts["failed_prog_assertions"] += 1
        print("failed prog assertions:", example.question)
        return False
    return True
def validate_query_distinction_local(previous_queries, query):
    """Check whether ``query`` is distinct from the previous queries.

    Args:
        previous_queries: sequence of earlier query strings. Any empty (or
            otherwise falsy) value means there is nothing to compare against.
        query: the candidate query string.

    Returns:
        ``True`` if there are no previous queries, or if
        ``dspy.evaluate.answer_exact_match_str(query, previous_queries,
        frac=0.8)`` reports no match; ``False`` when it reports a match.
    """
    # Truthiness instead of `== []`, so an empty tuple (or any other empty
    # sequence) short-circuits too and never reaches the dspy comparison.
    if not previous_queries:
        return True
    is_match = dspy.evaluate.answer_exact_match_str(
        query, previous_queries, frac=0.8
    )
    return not is_match
# declaration of dspy program
class SimplifiedBaleen(dspy.Module):
    """Multi-hop question answering: generate a query, retrieve, repeat, answer.

    Each hop has its own ``dspy.ChainOfThought(GenerateSearchQuery)`` module.
    Retrieved passages accumulate in a deduplicated context that is finally
    passed to ``dspy.ChainOfThought(GenerateAnswer)``.
    """

    def __init__(self, passages_per_hop=2, max_hops=2):
        """Build the per-hop query generators, the retriever and the answerer.

        Args:
            passages_per_hop: ``k`` passed to ``dspy.Retrieve``.
            max_hops: number of query/retrieve rounds and of query generators.
        """
        super().__init__()
        per_hop_generators = [
            dspy.ChainOfThought(GenerateSearchQuery) for _hop in range(max_hops)
        ]
        self.generate_query = per_hop_generators
        self.retrieve = dspy.Retrieve(k=passages_per_hop)
        self.generate_answer = dspy.ChainOfThought(GenerateAnswer)
        self.max_hops = max_hops

    def forward(self, question):
        """Answer ``question`` after ``self.max_hops`` retrieval hops.

        Returns a ``dspy.Prediction`` carrying the accumulated ``context`` and
        the generated ``answer``.
        """
        gathered = []
        # The question counts as the first "previous query", so generated
        # queries are also compared against it.
        seen_queries = [question]

        for hop_index in range(self.max_hops):
            hop_prediction = self.generate_query[hop_index](
                context=gathered, question=question
            )
            query = hop_prediction.query

            dspy.Suggest(
                len(query) <= 100,
                "Query should be short and less than 100 characters",
            )

            is_distinct = validate_query_distinction_local(seen_queries, query)
            numbered_queries = "; ".join(
                f"{position}) {earlier}"
                for position, earlier in enumerate(seen_queries, start=1)
            )
            dspy.Suggest(
                is_distinct,
                "Query should be distinct from: " + numbered_queries,
            )

            seen_queries.append(query)
            retrieved = self.retrieve(query).passages
            gathered = deduplicate(gathered + retrieved)

        answer_prediction = self.generate_answer(context=gathered, question=question)
        return dspy.Prediction(context=gathered, answer=answer_prediction.answer)
# Compile the program with BootstrapFewShot, scoring candidates with
# validate_context_and_answer_and_hops. Student and teacher both have their
# named predictors mapped through Retry.
teleprompter = BootstrapFewShot(metric=validate_context_and_answer_and_hops)
compiled_baleen = teleprompter.compile(
    student=SimplifiedBaleen().map_named_predictors(Retry),
    teacher=SimplifiedBaleen(passages_per_hop=2).map_named_predictors(Retry),
    trainset=trainset,
)

# Report how many times each validation failure was recorded.
print("=" * 50 + " Validation Failures " + "=" * 50)
for failure_type, count in failure_counts.items():
    print(failure_type, count, sep=": ")