
Commit 5ce7bc7: "Changes to eval for conf data"
Parent: 01a287f

14 files changed: +228 additions, -172 deletions

docs/evaluation.md

Lines changed: 2 additions & 2 deletions
@@ -42,12 +42,12 @@ pip install -r requirements-dev.txt
 
 ## Generate ground truth data
 
-Modify the prompt in `evals/generate.txt` to match your database table and RAG scenario.
+Modify the prompt in `evals/generate_prompt.jinja2` to match your database table and RAG scenario.
 
 Generate ground truth data by running the following command:
 
 ```bash
-python evals/generate_ground_truth_data.py
+python evals/generate_ground_truth.py --numquestions=50 --persource=50
 ```
 
 Review the generated data after running that script, removing any question/answer pairs that don't seem like realistic user input.
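If it helps with that review, here is a minimal sketch for skimming the generated pairs; it assumes the default `evals/ground_truth.jsonl` output path and the `question`/`truth` keys that `generate_ground_truth.py` writes.

```python
import json
from pathlib import Path

# Print each generated question/answer pair for a quick manual pass
# (path and keys assumed from the generation script's defaults).
for line in Path("evals/ground_truth.jsonl").read_text().splitlines():
    pair = json.loads(line)
    print(pair["question"])
    print("    ", pair["truth"])
    print()
```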

evals/eval_config.json

Lines changed: 1 addition & 1 deletion
@@ -1,7 +1,7 @@
 {
     "testdata_path": "ground_truth.jsonl",
     "results_dir": "results/experiment<TIMESTAMP>",
-    "requested_metrics": ["gpt_groundedness", "gpt_relevance", "answer_length", "latency", "citations_matched"],
+    "requested_metrics": ["gpt_groundedness", "gpt_relevance", "f1_score", "answer_length", "latency", "citations_matched"],
     "target_url": "http://127.0.0.1:8000/chat",
     "target_parameters": {
         "overrides": {

evals/evaluate.py

Lines changed: 10 additions & 4 deletions
@@ -24,9 +24,13 @@ def citations_overlap(*, response, ground_truth, **kwargs):
             if response is None:
                 logger.warning("Received response of None, can't compute citation_match metric. Setting to -1.")
                 return {cls.METRIC_NAME: -1}
-            truth_citations = set(re.findall(r"\[(\d+)\]", ground_truth))
-            response_citations = set(re.findall(r"\[(\d+)\]", response))
-            # Count the percentage of citations that are present in the response
+            citation_pattern = r"\[(\d+)\]"
+            truth_citations = set(re.findall(citation_pattern, ground_truth))
+            response_citations = set(re.findall(citation_pattern, response))
+            # Return the percentage of citations that are present in the response
+            if len(truth_citations) == 0:
+                logger.warning("No citations found in ground truth, setting metric to 1.0.")
+                return {cls.METRIC_NAME: 1.0}
             num_citations = len(truth_citations)
             num_matched_citations = len(truth_citations.intersection(response_citations))
             return {cls.METRIC_NAME: num_matched_citations / num_citations}
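To make the metric concrete, here is a standalone sketch of the same regex-based overlap; the sample strings are invented for illustration.

```python
import re

# Standalone illustration of the citations_matched computation above.
citation_pattern = r"\[(\d+)\]"
ground_truth = "There is a morning Python session [12] and another after lunch [5]."
response = "Yes, check out the morning Python session. [12]"

truth_citations = set(re.findall(citation_pattern, ground_truth))    # {"12", "5"}
response_citations = set(re.findall(citation_pattern, response))     # {"12"}
print(len(truth_citations & response_citations) / len(truth_citations))  # 0.5
```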
@@ -74,8 +78,10 @@ def get_openai_config() -> dict:
 
 if __name__ == "__main__":
     logging.basicConfig(
-        level=logging.INFO, format="%(message)s", datefmt="[%X]", handlers=[RichHandler(rich_tracebacks=True)]
+        level=logging.WARNING, format="%(message)s", datefmt="[%X]", handlers=[RichHandler(rich_tracebacks=True)]
     )
+    logging.getLogger("evaltools").setLevel(logging.INFO)
+    logger.setLevel(logging.INFO)
     load_dotenv(".env", override=True)
 
     parser = argparse.ArgumentParser(description="Run evaluation with OpenAI configuration.")

evals/generate_ground_truth.py

Lines changed: 37 additions & 26 deletions
@@ -1,3 +1,4 @@
+import argparse
 import json
 import logging
 import os
@@ -8,14 +9,18 @@
 from dotenv_azd import load_azd_env
 from openai import AzureOpenAI, OpenAI
 from openai.types.chat import ChatCompletionToolParam
-from sqlalchemy import create_engine, select
+from sqlalchemy import create_engine, select, func
 from sqlalchemy.orm import Session
+from dotenv import load_dotenv
+from jinja2 import Environment, FileSystemLoader
+from rich.logging import RichHandler
 
 from fastapi_app.postgres_models import Item
 
 logger = logging.getLogger("ragapp")
 
 
+
 def qa_pairs_tool(num_questions: int = 1) -> ChatCompletionToolParam:
     return {
         "type": "function",
@@ -47,26 +52,19 @@ def qa_pairs_tool(num_questions: int = 1) -> ChatCompletionToolParam:
 
 
 def source_retriever() -> Generator[str, None, None]:
-    # Connect to the database
+    # Connect to the local database
     DBHOST = os.environ["POSTGRES_HOST"]
     DBUSER = os.environ["POSTGRES_USERNAME"]
     DBPASS = os.environ["POSTGRES_PASSWORD"]
     DBNAME = os.environ["POSTGRES_DATABASE"]
     DATABASE_URI = f"postgresql://{DBUSER}:{DBPASS}@{DBHOST}/{DBNAME}"
     engine = create_engine(DATABASE_URI, echo=False)
     with Session(engine) as session:
-        # Fetch all products for a particular type
-        item_types = session.scalars(select(Item.type).distinct())
-        for item_type in item_types:
-            records = list(session.scalars(select(Item).filter(Item.type == item_type).order_by(Item.id)))
-            logger.info(f"Processing database records for type: {item_type}")
-            yield "\n\n".join([f"## Product ID: [{record.id}]\n" + record.to_str_for_rag() for record in records])
-        # Fetch each item individually
-        # records = list(session.scalars(select(Item).order_by(Item.id)))
-        # for record in records:
-        #     logger.info(f"Processing database record: {record.name}")
-        #     yield f"## Product ID: [{record.id}]\n" + record.to_str_for_rag()
-        # await self.openai_chat_client.chat.completions.create(
+        while True:
+            # Fetch all the rows from the database
+            random_rows = list(session.scalars(select(Item).order_by(func.random())))
+            logger.info("Fetched %d random rows", len(random_rows))
+            yield "\n\n".join([f"## Row ID: [{row.id}]\n" + row.to_str_for_rag() for row in random_rows])
 
 
 def source_to_text(source) -> str:
@@ -108,31 +106,36 @@ def get_openai_client() -> tuple[AzureOpenAI | OpenAI, str]:
     return openai_client, model
 
 
-def generate_ground_truth_data(num_questions_total: int, num_questions_per_source: int = 5):
+def generate_ground_truth_data(num_questions_total: int, num_questions_per_source):
     logger.info("Generating %d questions total", num_questions_total)
     openai_client, model = get_openai_client()
     current_dir = Path(__file__).parent
-    generate_prompt = open(current_dir / "generate_prompt.txt").read()
+
+    # Load the template from the file system
+    jinja_file_loader = FileSystemLoader(current_dir)
+    jinja_env = Environment(loader=jinja_file_loader)
+    prompt_template = jinja_env.get_template('generate_prompt.jinja2')
+
     output_file = Path(__file__).parent / "ground_truth.jsonl"
 
     qa: list[dict] = []
-    for source in source_retriever():
-        if len(qa) > num_questions_total:
-            logger.info("Generated enough questions already, stopping")
-            break
+    while len(qa) < num_questions_total:
+        sources = next(source_retriever())
+        previous_questions = [qa_pair["question"] for qa_pair in qa]
         result = openai_client.chat.completions.create(
             model=model,
             messages=[
-                {"role": "system", "content": generate_prompt},
-                {"role": "user", "content": json.dumps(source)},
+                {"role": "system", "content": prompt_template.render(num_questions=num_questions_per_source, previous_questions=previous_questions)},
+                {"role": "user", "content": json.dumps(sources)},
             ],
-            tools=[qa_pairs_tool(num_questions=2)],
+            tools=[qa_pairs_tool(num_questions=num_questions_per_source)],
         )
         if not result.choices[0].message.tool_calls:
             logger.warning("No tool calls found in response, skipping")
             continue
         qa_pairs = json.loads(result.choices[0].message.tool_calls[0].function.arguments)["qa_list"]
         qa_pairs = [{"question": qa_pair["question"], "truth": qa_pair["answer"]} for qa_pair in qa_pairs]
+        logger.info("Received %d suggested questions", len(qa_pairs))
         qa.extend(qa_pairs)
 
     logger.info("Writing %d questions to %s", num_questions_total, output_file)
@@ -145,8 +148,16 @@ def generate_ground_truth_data(num_questions_total: int, num_questions_per_sourc
 
 
 if __name__ == "__main__":
-    logging.basicConfig(level=logging.WARNING)
+    logging.basicConfig(
+        level=logging.WARNING, format="%(message)s", datefmt="[%X]", handlers=[RichHandler(rich_tracebacks=True)]
+    )
     logger.setLevel(logging.INFO)
-    load_azd_env()
+    load_dotenv(".env", override=True)
+
+    parser = argparse.ArgumentParser(description="Generate ground truth data with OpenAI configuration.")
+    parser.add_argument("--numquestions", type=int, help="Specify the number of questions.", default=50)
+    parser.add_argument("--persource", type=int, help="Specify the number of questions per set of retrieved sources.", default=5)
+
+    args = parser.parse_args()
 
-    generate_ground_truth_data(num_questions_total=10)
+    generate_ground_truth_data(num_questions_total=args.numquestions, num_questions_per_source=args.persource)
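For reference, a hedged sketch of the tool-call payload that the generation loop above parses out of `tool_calls[0].function.arguments`; the example values are invented.

```python
import json

# Invented example of the arguments returned by the qa_pairs_tool function call.
tool_arguments = json.dumps({
    "qa_list": [
        {
            "question": "Are there any sessions featuring Python?",
            "answer": "Yes, there is a session on using Python to automate your workflow. [12]",
        }
    ]
})

# generate_ground_truth_data() keeps the question and renames "answer" to "truth":
qa_pairs = json.loads(tool_arguments)["qa_list"]
qa_pairs = [{"question": pair["question"], "truth": pair["answer"]} for pair in qa_pairs]
print(qa_pairs)
```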

evals/generate_prompt.jinja2

Lines changed: 22 additions & 0 deletions
@@ -0,0 +1,22 @@
+Your job is to generate {{ num_questions }} example questions that a customer might ask about sessions at the GitHub Universe conference.
+The conference has *not* yet happened.
+
+You should come up with the {{ num_questions }} questions and answers based on the provided data.
+Each answer should include the row ID in square brackets.
+For example,
+'Are there any sessions featuring Python?'
+with answer:
+'Yes, there is a session on Python at 10:00 AM on the first day, about how to use Python to automate your workflow. [12]
+There is an additional session at 2:00 PM on the second day about how to use Python to build a web application. [5]
+Finally, there is a session at 4:00 PM on the second day about how to use Python to analyze data. [3]
+'
+Your answer should typically be a paragraph or two.
+
+Your questions should NOT be about specific session titles, but instead be more general questions
+that a conference attendee might ask when planning their schedule.
+Your answers should reference specific session titles, however, to help the user pick sessions.
+
+{% if previous_questions %}
+You should NOT suggest any of these questions that have already been asked:
+{{ previous_questions }}
+{% endif %}
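A minimal sketch of how `generate_ground_truth.py` renders this template, with illustrative values for `num_questions` and `previous_questions`:

```python
from jinja2 import Environment, FileSystemLoader

# Render the prompt template the same way generate_ground_truth.py does,
# using invented values for the two template variables.
jinja_env = Environment(loader=FileSystemLoader("evals"))
prompt_template = jinja_env.get_template("generate_prompt.jinja2")
system_prompt = prompt_template.render(
    num_questions=5,
    previous_questions=["Are there any sessions featuring Python?"],
)
print(system_prompt)
```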

evals/generate_prompt.txt

Lines changed: 0 additions & 9 deletions
This file was deleted.
