Skip to content

Commit

Permalink
rework 09
Browse files Browse the repository at this point in the history
  • Loading branch information
Konsti-s committed May 27, 2024
1 parent e698708 commit b3027f7
Showing 1 changed file with 149 additions and 84 deletions.
233 changes: 149 additions & 84 deletions 09_agent_with_sql_toolkit.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
"id": "1",
"metadata": {},
"source": [
"Zuerst enpacken wir unsere Demo Datenbank mit der wir gleich arfbeiten werden."
"Zuerst entpacken wir unsere Demo Datenbank, mit der wir gleich arbeiten werden.\n"
]
},
{
Expand Down Expand Up @@ -41,14 +41,14 @@
"metadata": {},
"outputs": [],
"source": [
"from langchain_core.prompts.chat import ChatPromptTemplate\n",
"from langchain_community.agent_toolkits.sql.base import SQLDatabaseToolkit\n",
"from langchain_community.utilities.sql_database import SQLDatabase\n",
"from helpers import llm\n",
"\n",
"model = llm(temperature=0)\n",
"db = SQLDatabase.from_uri(\"sqlite:///northwind.db\")\n",
"toolkit = SQLDatabaseToolkit(db=db, llm=model)"
"toolkit = SQLDatabaseToolkit(db=db, llm=model)\n",
"tools = toolkit.get_tools()"
]
},
{
Expand All @@ -58,28 +58,43 @@
"metadata": {},
"outputs": [],
"source": [
"from langgraph.prebuilt import create_agent_executor\n",
"from langchain.agents import create_openai_functions_agent\n",
"from langchain import hub\n",
"from langchain.agents import create_react_agent\n",
"from langchain.schema import AIMessage, SystemMessage\n",
"from langchain_core.runnables import chain\n",
"\n",
"system_message_prompt = \"\"\"You are an agent designed to interact with a SQL database.\n",
"Given an input question, create a syntactically correct {dialect} query to run, then look at the results of the query and return the answer.\n",
"Unless the user specifies a specific number of examples they wish to obtain, always limit your query to at most {top_k} results.\n",
"You can order the results by a relevant column to return the most interesting examples in the database.\n",
"Never query for all the columns from a specific table, only ask for the relevant columns given the question.\n",
"You have access to tools for interacting with the database.\n",
"Only use the below tools. Only use the information returned by the below tools to construct your final answer.\n",
"You MUST double check your query before executing it. If you get an error while executing a query, rewrite the query and try again.\n",
"DO NOT make any DML statements (INSERT, UPDATE, DELETE, DROP etc.) to the database.\n",
"If the question does not seem related to the database, just return \"I don't know\" as the answer.\"\"\"\n",
"\n",
"tools = toolkit.get_tools()\n",
"prompt: ChatPromptTemplate = hub.pull(\"reactagent/sql\")\n",
"prompt = prompt.partial(dialect=toolkit.dialect, top_k=10)\n",
"agent_runnable = create_openai_functions_agent(model, tools, prompt)\n",
"agent_excutor = create_agent_executor(agent_runnable, tools)\n",
"\n",
"prompt"
"@chain\n",
"def messages_modifier(messages):\n",
" return [\n",
" SystemMessage(system_message_prompt.format(dialect=toolkit.dialect, top_k=10)),\n",
" messages[0],\n",
" AIMessage(\n",
" \"I should look at the tables in the database to see what I can query. Then I should query the schema of the most relevant tables.\"\n",
" ),\n",
" *messages[1:],\n",
" ]\n",
"\n",
"\n",
"agent_executor = create_react_agent(model, tools, messages_modifier)"
]
},
{
"cell_type": "markdown",
"id": "6",
"metadata": {},
"source": [
"#### Wir definieren eine Funktion, die die gestreamte Ausgabe des Agenten formatiert.\n",
"\n",
"Man muss den folgenden Code nicht durchlesen. Man sollte sich nur merken, dass man solche Code-Brocken üblicherweise selbst erstellen muss.\n"
"#### Looos....\n"
]
},
{
Expand All @@ -89,142 +104,192 @@
"metadata": {},
"outputs": [],
"source": [
"from typing import Any, AsyncIterator, List, Tuple\n",
"from langchain_core.agents import AgentActionMessageLog, AgentFinish\n",
"\n",
"\n",
"async def formatted_output_streamer(stream: AsyncIterator[Any]) -> AsyncIterator[Any]:\n",
" async for chunk in stream:\n",
" output = \"\"\n",
" for key, value in chunk.items():\n",
" if key == \"agent\":\n",
" outcome = value.get(\"agent_outcome\")\n",
" if isinstance(outcome, AgentActionMessageLog):\n",
" output += f\"Agent log:\\n\\n{outcome.log.strip()}\"\n",
" elif isinstance(outcome, AgentFinish):\n",
" output += f\"Agent finished:\\n\\n{outcome.log.strip()}\"\n",
" output += \"\\n\\n----------------------------------------------------------------------------------------\\n\\n\"\n",
" elif key == \"action\":\n",
" steps: List[Tuple[AgentActionMessageLog, str]] = value.get(\n",
" \"intermediate_steps\"\n",
" )\n",
" for index, step in enumerate(steps):\n",
" output += f\"Tool log:\\n\\n{step[1].strip()}\"\n",
" if index < len(steps) - 1:\n",
" print(\"----------------\")\n",
" output += \"\\n\\n----------------------------------------------------------------------------------------\\n\\n\"\n",
" elif key == \"__end__\":\n",
" output = \"Done\"\n",
" yield output"
"from langchain.schema import HumanMessage\n",
"\n",
"\n",
"input_1 = {\"messages\": [HumanMessage(content=\"Where do I find the orders?\")]}\n",
"for chunk in agent_executor.stream(input_1):\n",
" for state in chunk.values():\n",
" for message in state[\"messages\"]:\n",
" message.pretty_print()"
]
},
{
"cell_type": "markdown",
"cell_type": "code",
"execution_count": null,
"id": "8",
"metadata": {},
"outputs": [],
"source": [
"#### Wir pipen (chainen) den Agenten mit dem Formatierer\n"
"input_2 = {\"messages\": [HumanMessage(content=\"Which Employee has the most orders?\")]}\n",
"for chunk in agent_executor.stream(input_2):\n",
" for state in chunk.values():\n",
" for message in state[\"messages\"]:\n",
" message.pretty_print()"
]
},
{
"cell_type": "markdown",
"id": "9",
"metadata": {},
"source": [
"#### Schauen wir mal, ob er das hier hinbekommt.\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "9",
"id": "10",
"metadata": {},
"outputs": [],
"source": [
"app = agent_excutor | formatted_output_streamer"
"input_3 = {\n",
" \"messages\": [\n",
" HumanMessage(\n",
" content=\"Which Customer has had the Order with the highest total cost ever? What was the Order Id?\"\n",
" )\n",
" ]\n",
"}\n",
"for chunk in agent_executor.stream(input_3):\n",
" for state in chunk.values():\n",
" for message in state[\"messages\"]:\n",
" message.pretty_print()"
]
},
{
"cell_type": "markdown",
"id": "10",
"id": "11",
"metadata": {},
"source": [
"#### Looos....\n"
"#### Und noch einmal das batchen (async) demonstrieren.\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "11",
"id": "12",
"metadata": {},
"outputs": [],
"source": [
"inputs = {\"input\": \"Where do i find the orders?\"}\n",
"async for chunk in app.astream(inputs):\n",
" print(chunk)"
"from typing import Dict\n",
"\n",
"\n",
"async def format_output(item: Dict) -> list[str]:\n",
" return [item.get(\"messages\")[0].content, item.get(\"messages\")[-1].content]\n",
"\n",
"\n",
"batcher = agent_executor | format_output"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "12",
"id": "13",
"metadata": {},
"outputs": [],
"source": [
"inputs = {\"input\": \"Which Employee has the most orders?\"}\n",
"async for chunk in app.astream(inputs):\n",
" print(chunk)"
"result = await batcher.abatch([input_1, input_2, input_3])\n",
"\n",
"\n",
"for index, item in enumerate(result):\n",
" print(f\"Query {index+1}:\")\n",
" print(f\"Question: {item[0]}\")\n",
" print(f\"Answer: {item[1]}\\n\\n\")"
]
},
{
"cell_type": "markdown",
"id": "13",
"id": "14",
"metadata": {},
"source": [
"#### Schauen wir mal, ob er das hier hinbekommt.\n"
"## ✅ Aufgabe\n",
"\n",
"### Schaut euch die DB an und stellt eine komplizierte Frage. Schaut mal, wie weit das LLM kommt.\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "14",
"id": "15",
"metadata": {},
"outputs": [],
"source": [
"inputs = {\n",
" \"input\": \"Which Customer has had the Order with the highest total cost ever? What was the Order Id?\"\n",
"}\n",
"async for chunk in app.astream(inputs):\n",
" print(chunk)"
"print(db.get_table_info())"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "16",
"metadata": {},
"outputs": [],
"source": [
"your_input = {\"messages\": [HumanMessage(content=\"\")]}\n",
"for chunk in agent_executor.stream(your_input):\n",
" for state in chunk.values():\n",
" for message in state[\"messages\"]:\n",
" message.pretty_print()"
]
},
{
"cell_type": "markdown",
"id": "15",
"id": "17",
"metadata": {},
"source": [
"#### Und noch einmal das batchen demonstrieren.\n"
"### Schaut euch den Prompt an und spielt damit herum\n",
"\n",
"Vielleicht bringt ihr das LLM dazu, Fehler zu machen oder sich komisch zu verhalten? Vielleicht kreiert ihr lustige Ergebnisse?\n",
"Dies ist ein \"echter\" Prompt, der schon relativ gut ist in dem, was er tut.\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "16",
"id": "18",
"metadata": {},
"outputs": [],
"source": [
"async def formatted_output_batcher(item: Any) -> str:\n",
" return [item.get(\"input\"), item.get(\"agent_outcome\").return_values.get(\"output\")]\n",
"\n",
"custom_system_message_prompt = \"\"\"You are an agent designed to interact with a SQL database.\n",
"Given an input question, create a syntactically correct {dialect} query to run, then look at the results of the query and return the answer.\n",
"Unless the user specifies a specific number of examples they wish to obtain, always limit your query to at most {top_k} results.\n",
"You can order the results by a relevant column to return the most interesting examples in the database.\n",
"Never query for all the columns from a specific table, only ask for the relevant columns given the question.\n",
"You have access to tools for interacting with the database.\n",
"Only use the below tools. Only use the information returned by the below tools to construct your final answer.\n",
"You MUST double check your query before executing it. If you get an error while executing a query, rewrite the query and try again.\n",
"DO NOT make any DML statements (INSERT, UPDATE, DELETE, DROP etc.) to the database.\n",
"If the question does not seem related to the database, just return \"I don't know\" as the answer.\"\"\"\n",
"\n",
"batcher = agent_excutor | formatted_output_batcher\n",
"\n",
"result = await batcher.abatch(\n",
" [\n",
" {\"input\": \"Where do i find the orders?\"},\n",
" {\"input\": \"Which Employee has the most orders?\"},\n",
" {\n",
" \"input\": \"Which Customer has had the Order with the highest total cost ever? What was the Order Id?\"\n",
" },\n",
"@chain\n",
"def custom_messages_modifier(messages):\n",
" return [\n",
" SystemMessage(\n",
" custom_system_message_prompt.format(dialect=toolkit.dialect, top_k=10)\n",
" ),\n",
" messages[0],\n",
" AIMessage(\n",
" \"I should look at the tables in the database to see what I can query. Then I should query the schema of the most relevant tables.\"\n",
" ),\n",
" *messages[1:],\n",
" ]\n",
")\n",
"\n",
"for index, item in enumerate(result):\n",
" print(f\"Query {index+1}:\")\n",
" print(f\"Question: {item[0]}\")\n",
" print(f\"Answer: {item[1]}\\n\\n\")"
"\n",
"agent_executor = create_react_agent(model, tools, custom_messages_modifier)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "19",
"metadata": {},
"outputs": [],
"source": [
"your_input = {\"messages\": [HumanMessage(content=\"\")]}\n",
"for chunk in agent_executor.stream(your_input):\n",
" for state in chunk.values():\n",
" for message in state[\"messages\"]:\n",
" message.pretty_print()"
]
}
],
Expand All @@ -244,7 +309,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.12.3"
"version": "3.11.9"
}
},
"nbformat": 4,
Expand Down

0 comments on commit b3027f7

Please sign in to comment.