1
+ import argparse
1
2
import json
2
3
import logging
3
4
import os
8
9
from dotenv_azd import load_azd_env
9
10
from openai import AzureOpenAI , OpenAI
10
11
from openai .types .chat import ChatCompletionToolParam
11
- from sqlalchemy import create_engine , select
12
+ from sqlalchemy import create_engine , select , func
12
13
from sqlalchemy .orm import Session
14
+ from dotenv import load_dotenv
15
+ from jinja2 import Environment , FileSystemLoader
16
+ from rich .logging import RichHandler
13
17
14
18
from fastapi_app .postgres_models import Item
15
19
16
20
logger = logging .getLogger ("ragapp" )
17
21
18
22
23
+
19
24
def qa_pairs_tool (num_questions : int = 1 ) -> ChatCompletionToolParam :
20
25
return {
21
26
"type" : "function" ,
@@ -47,26 +52,19 @@ def qa_pairs_tool(num_questions: int = 1) -> ChatCompletionToolParam:
47
52
48
53
49
54
def source_retriever() -> Generator[str, None, None]:
    """Yield the full items table as one RAG-formatted string, forever.

    Each value produced contains every row, freshly shuffled by the
    database, so callers can keep pulling new random orderings via next().
    """
    # Connect to the local database
    host = os.environ["POSTGRES_HOST"]
    user = os.environ["POSTGRES_USERNAME"]
    password = os.environ["POSTGRES_PASSWORD"]
    database = os.environ["POSTGRES_DATABASE"]
    engine = create_engine(f"postgresql://{user}:{password}@{host}/{database}", echo=False)
    with Session(engine) as session:
        while True:
            # Let Postgres shuffle server-side; materialize so we can count them.
            random_rows = list(session.scalars(select(Item).order_by(func.random())))
            logger.info("Fetched %d random rows", len(random_rows))
            sections = [f"## Row ID: [{row.id}]\n" + row.to_str_for_rag() for row in random_rows]
            yield "\n\n".join(sections)
70
68
71
69
72
70
def source_to_text (source ) -> str :
@@ -108,31 +106,36 @@ def get_openai_client() -> tuple[AzureOpenAI | OpenAI, str]:
108
106
return openai_client , model
109
107
110
108
111
- def generate_ground_truth_data (num_questions_total : int , num_questions_per_source : int = 5 ):
109
+ def generate_ground_truth_data (num_questions_total : int , num_questions_per_source ):
112
110
logger .info ("Generating %d questions total" , num_questions_total )
113
111
openai_client , model = get_openai_client ()
114
112
current_dir = Path (__file__ ).parent
115
- generate_prompt = open (current_dir / "generate_prompt.txt" ).read ()
113
+
114
+ # Load the template from the file system
115
+ jinja_file_loader = FileSystemLoader (current_dir )
116
+ jinja_env = Environment (loader = jinja_file_loader )
117
+ prompt_template = jinja_env .get_template ('generate_prompt.jinja2' )
118
+
116
119
output_file = Path (__file__ ).parent / "ground_truth.jsonl"
117
120
118
121
qa : list [dict ] = []
119
- for source in source_retriever ():
120
- if len (qa ) > num_questions_total :
121
- logger .info ("Generated enough questions already, stopping" )
122
- break
122
+ while len (qa ) < num_questions_total :
123
+ sources = next (source_retriever ())
124
+ previous_questions = [qa_pair ["question" ] for qa_pair in qa ]
123
125
result = openai_client .chat .completions .create (
124
126
model = model ,
125
127
messages = [
126
- {"role" : "system" , "content" : generate_prompt },
127
- {"role" : "user" , "content" : json .dumps (source )},
128
+ {"role" : "system" , "content" : prompt_template . render ( num_questions = num_questions_per_source , previous_questions = previous_questions ) },
129
+ {"role" : "user" , "content" : json .dumps (sources )},
128
130
],
129
- tools = [qa_pairs_tool (num_questions = 2 )],
131
+ tools = [qa_pairs_tool (num_questions = num_questions_per_source )],
130
132
)
131
133
if not result .choices [0 ].message .tool_calls :
132
134
logger .warning ("No tool calls found in response, skipping" )
133
135
continue
134
136
qa_pairs = json .loads (result .choices [0 ].message .tool_calls [0 ].function .arguments )["qa_list" ]
135
137
qa_pairs = [{"question" : qa_pair ["question" ], "truth" : qa_pair ["answer" ]} for qa_pair in qa_pairs ]
138
+ logger .info ("Received %d suggested questions" , len (qa_pairs ))
136
139
qa .extend (qa_pairs )
137
140
138
141
logger .info ("Writing %d questions to %s" , num_questions_total , output_file )
@@ -145,8 +148,16 @@ def generate_ground_truth_data(num_questions_total: int, num_questions_per_sourc
145
148
146
149
147
150
if __name__ == "__main__":
    # Route all logging through rich so tracebacks and records render nicely.
    logging.basicConfig(
        level=logging.WARNING, format="%(message)s", datefmt="[%X]", handlers=[RichHandler(rich_tracebacks=True)]
    )
    logger.setLevel(logging.INFO)
    load_dotenv(".env", override=True)

    # CLI: total number of questions, and how many to request per batch of sources.
    arg_parser = argparse.ArgumentParser(description="Run evaluation with OpenAI configuration.")
    arg_parser.add_argument("--numquestions", type=int, help="Specify the number of questions.", default=50)
    arg_parser.add_argument(
        "--persource", type=int, help="Specify the number of questions per retrieved sources.", default=5
    )
    cli_args = arg_parser.parse_args()

    generate_ground_truth_data(num_questions_total=cli_args.numquestions, num_questions_per_source=cli_args.persource)
0 commit comments