-
Notifications
You must be signed in to change notification settings - Fork 166
/
Copy pathturbo_main.py
139 lines (114 loc) · 4.36 KB
/
turbo_main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
from postgres_da_ai_agent.agents.turbo4 import Turbo4
from postgres_da_ai_agent.types import Chat, TurboTool
from typing import List, Callable
import os
from postgres_da_ai_agent.agents.instruments import PostgresAgentInstruments
from postgres_da_ai_agent.modules import llm
from postgres_da_ai_agent.modules import rand
from postgres_da_ai_agent.modules import embeddings
import argparse
# Connection string for the Postgres database, taken from the environment.
# Evaluates to None when DATABASE_URL is not set.
DB_URL = os.getenv("DATABASE_URL")

# Label under which the table definitions are attached to the prompt in main().
POSTGRES_TABLE_DEFINITIONS_CAP_REF = "TABLE_DEFINITIONS"
# Function-tool definition describing ``store_fact`` (defined in this file).
# ``fact`` is declared required because ``store_fact`` has no default for it;
# this mirrors how ``run_sql_tool_config`` declares its ``sql`` argument.
custom_function_tool_config = {
    "type": "function",
    "function": {
        "name": "store_fact",
        "description": "A function that stores a fact.",
        "parameters": {
            "type": "object",
            "properties": {
                "fact": {
                    "type": "string",
                    "description": "The fact to store",
                }
            },
            "required": ["fact"],
        },
    },
}
# JSON-schema style description of the single ``sql`` argument of ``run_sql``.
_RUN_SQL_PARAMETERS = {
    "type": "object",
    "properties": {
        "sql": {"type": "string", "description": "The SQL query to run"},
    },
    "required": ["sql"],
}

# Function-tool definition for ``run_sql``; main() wraps it in a TurboTool.
run_sql_tool_config = {
    "type": "function",
    "function": {
        "name": "run_sql",
        "description": "Run a SQL query against the postgres database",
        "parameters": _RUN_SQL_PARAMETERS,
    },
}
def store_fact(fact: str) -> str:
    """Echo ``fact`` to stdout and return a fixed confirmation message.

    Nothing is persisted: the function only prints a trace line containing
    the fact it was given.

    Args:
        fact: The fact text to report.

    Returns:
        The literal string ``"Fact stored."``.
    """
    print(f"------store_fact({fact})------")
    return "Fact stored."
def main() -> None:
    """Run one ``--prompt`` through the Turbo4 assistant with a ``run_sql`` tool.

    Steps, in order:

    1. Parse ``--prompt`` from the command line; if it is missing or empty,
       print a hint and return without doing anything else.
    2. Open ``PostgresAgentInstruments`` for ``DB_URL`` and a session id
       derived from the assistant name plus the raw prompt.
    3. Fetch table definitions for the prompt via ``DatabaseEmbedder`` and
       attach them to the prompt with ``llm.add_cap_ref``.
    4. Drive the assistant: one thread run on the prompt, a second run that
       asks it to call ``run_sql``, then validation, chat logging and cost
       reporting through the instruments object.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--prompt", help="The prompt for the AI")
    args = parser.parse_args()

    # Guard clause: nothing to do without a prompt (covers None and "").
    if not args.prompt:
        print("Please provide a prompt")
        return

    raw_prompt = args.prompt
    prompt = f"Fulfill this database query: {raw_prompt}. "

    assistant_name = "Turbo4"
    assistant = Turbo4()
    session_id = rand.generate_session_id(assistant_name + raw_prompt)

    with PostgresAgentInstruments(DB_URL, session_id) as (agent_instruments, db):
        # Look up table definitions using the raw user prompt, not the
        # wrapped "Fulfill this database query" version.
        database_embedder = embeddings.DatabaseEmbedder(db)
        table_definitions = database_embedder.get_similar_table_defs_for_prompt(
            raw_prompt
        )

        prompt = llm.add_cap_ref(
            prompt,
            f"Use these {POSTGRES_TABLE_DEFINITIONS_CAP_REF} to satisfy the database query.",
            POSTGRES_TABLE_DEFINITIONS_CAP_REF,
            table_definitions,
        )

        tools = [
            TurboTool("run_sql", run_sql_tool_config, agent_instruments.run_sql),
        ]

        # The first run_thread() handles the prompt; the second is given the
        # run_sql tool's name as its toolbox (presumably enabling that tool
        # call for the run -- confirm against Turbo4.run_thread).
        (
            assistant.get_or_create_assistant(assistant_name)
            .set_instructions(
                "You're an elite SQL developer. You generate the most concise and performant SQL queries."
            )
            .equip_tools(tools)
            .make_thread()
            .add_message(prompt)
            .run_thread()
            .add_message(
                "Use the run_sql function to run the SQL you've just generated.",
            )
            .run_thread(toolbox=[tools[0].name])
            .run_validation(agent_instruments.validate_run_sql)
            .spy_on_assistant(agent_instruments.make_agent_chat_file(assistant_name))
            .get_costs_and_tokens(
                agent_instruments.make_agent_cost_file(assistant_name)
            )
        )

        print("✅ Turbo4 Assistant finished.")

        # ---------- Simple Prompt Solution - Same thing, only 2 api calls instead of 8+ ------------
        # sql_response = llm.prompt(
        #     prompt,
        #     model="gpt-4-1106-preview",
        #     instructions="You're an elite SQL developer. You generate the most concise and performant SQL queries.",
        # )
        # llm.prompt_func(
        #     "Use the run_sql function to run the SQL you've just generated: "
        #     + sql_response,
        #     model="gpt-4-1106-preview",
        #     instructions="You're an elite SQL developer. You generate the most concise and performant SQL queries.",
        #     turbo_tools=tools,
        # )
        # agent_instruments.validate_run_sql()

        # ----------- Example use case of Turbo4 and the Assistants API ------------
        # (
        #     assistant.get_or_create_assistant(assistant_name)
        #     .make_thread()
        #     .equip_tools(tools)
        #     .add_message("Generate 10 random facts about LLM technology.")
        #     .run_thread()
        #     .add_message("Use the store_fact function to 1 fact.")
        #     .run_thread(toolbox=["store_fact"])
        # )
# Run the CLI only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()