Skip to content

Commit

Permalink
upgrade: add web search! (run-llama#40)
Browse files Browse the repository at this point in the history
  • Loading branch information
jerryjliu authored Nov 30, 2023
1 parent cd0d0c6 commit b028ab7
Show file tree
Hide file tree
Showing 6 changed files with 230 additions and 28 deletions.
2 changes: 2 additions & 0 deletions 1_🏠_Home.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,8 @@
"To build a new agent, please make sure that 'Create a new agent' is selected.",
icon="ℹ️",
)
if "metaphor_key" in st.secrets:
st.info("**NOTE**: The ability to add web search is enabled.")


add_sidebar()
Expand Down
153 changes: 129 additions & 24 deletions agent_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,9 @@
from constants import AGENT_CACHE_DIR
import shutil

from llama_index.callbacks import CallbackManager
from callback_manager import StreamlitFunctionsCallbackHandler


def _resolve_llm(llm_str: str) -> LLM:
"""Resolve LLM."""
Expand Down Expand Up @@ -153,9 +156,25 @@ def load_agent(
"""Load agent."""
extra_kwargs = extra_kwargs or {}
if isinstance(llm, OpenAI) and is_function_calling_model(llm.model):
# TODO: use default msg handler
# TODO: separate this from agent_utils.py...
def _msg_handler(msg: str) -> None:
"""Message handler."""
st.info(msg)
st.session_state.agent_messages.append(
{"role": "assistant", "content": msg, "msg_type": "info"}
)

# add streamlit callbacks (to inject events)
handler = StreamlitFunctionsCallbackHandler(_msg_handler)
callback_manager = CallbackManager([handler])
# get OpenAI Agent
agent: BaseChatEngine = OpenAIAgent.from_tools(
tools=tools, llm=llm, system_prompt=system_prompt, **kwargs
tools=tools,
llm=llm,
system_prompt=system_prompt,
**kwargs,
callback_manager=callback_manager,
)
else:
if "vector_index" not in extra_kwargs:
Expand Down Expand Up @@ -189,8 +208,12 @@ def load_meta_agent(
extra_kwargs = extra_kwargs or {}
if isinstance(llm, OpenAI) and is_function_calling_model(llm.model):
# get OpenAI Agent

agent: BaseAgent = OpenAIAgent.from_tools(
tools=tools, llm=llm, system_prompt=system_prompt, **kwargs
tools=tools,
llm=llm,
system_prompt=system_prompt,
**kwargs,
)
else:
agent = ReActAgent.from_tools(
Expand Down Expand Up @@ -285,6 +308,66 @@ def construct_agent(
return agent, extra_info


def get_web_agent_tool() -> QueryEngineTool:
    """Get web agent tool.

    Builds an OpenAI agent over the Metaphor search tool list and wraps it
    in a QueryEngineTool so it can be handed to another agent as a sub-tool.
    """
    from llama_hub.tools.metaphor.base import MetaphorToolSpec

    # TODO: set metaphor API key
    # NOTE(review): reads `metaphor_key` from st.secrets — callers are expected
    # to gate on its presence (see _get_builder_agent_tools).
    metaphor_tool = MetaphorToolSpec(
        api_key=st.secrets.metaphor_key,
    )
    metaphor_tool_list = metaphor_tool.to_tool_list()

    # TODO: LoadAndSearch doesn't work yet
    # The search_and_retrieve_documents tool is the third in the tool list,
    # as seen above
    # wrapped_retrieve = LoadAndSearchToolSpec.from_defaults(
    #     metaphor_tool_list[2],
    # )

    # NOTE: requires openai right now
    # We don't give the Agent our unwrapped retrieve document tools
    # instead passing the wrapped tools
    web_agent = OpenAIAgent.from_tools(
        # [*wrapped_retrieve.to_tool_list(), metaphor_tool_list[4]],
        metaphor_tool_list,
        llm=BUILDER_LLM,
        verbose=True,
    )

    # return agent as a tool
    # TODO: tune description
    web_agent_tool = QueryEngineTool.from_defaults(
        web_agent,
        name="web_agent",
        description="""
This agent can answer questions by searching the web. \
Use this tool if the answer is ONLY likely to be found by searching \
the internet, especially for queries about recent events.
        """,
    )

    return web_agent_tool


def get_tool_objects(tool_names: List[str]) -> List:
    """Resolve a list of tool names into their tool objects.

    Currently only the "web_search" tool name is supported; any other
    name raises a ValueError.
    """
    tool_objs = []
    for tool_name in tool_names:
        if tool_name != "web_search":
            raise ValueError(f"Tool {tool_name} not recognized.")
        # build web agent
        tool_objs.append(get_web_agent_tool())
    return tool_objs


class ParamCache(BaseModel):
"""Cache for RAG agent builder.
Expand Down Expand Up @@ -338,7 +421,7 @@ def save_to_disk(self, save_dir: str) -> None:
"file_names": self.file_names,
"urls": self.urls,
# TODO: figure out tools
# "tools": [],
"tools": self.tools,
"rag_params": self.rag_params.dict(),
"agent_id": self.agent_id,
}
Expand Down Expand Up @@ -376,11 +459,13 @@ def load_from_disk(
file_names=cache_dict["file_names"], urls=cache_dict["urls"]
)
# load agent from index
additional_tools = get_tool_objects(cache_dict["tools"])
agent, _ = construct_agent(
cache_dict["system_prompt"],
cache_dict["rag_params"],
cache_dict["docs"],
vector_index=vector_index,
additional_tools=additional_tools,
# TODO: figure out tools
)
cache_dict["vector_index"] = vector_index
Expand Down Expand Up @@ -505,20 +590,14 @@ def load_data(
self._cache.urls = urls
return "Data loaded successfully."

# NOTE: unused
def add_web_tool(self) -> str:
"""Add a web tool to enable agent to solve a task."""
# TODO: make this not hardcoded to a web tool
# Set up Metaphor tool
from llama_hub.tools.metaphor.base import MetaphorToolSpec

# TODO: set metaphor API key
metaphor_tool = MetaphorToolSpec(
api_key=os.environ["METAPHOR_API_KEY"],
)
metaphor_tool_list = metaphor_tool.to_tool_list()

self._cache.tools.extend(metaphor_tool_list)
if "web_search" in self._cache.tools:
return "Web tool already added."
else:
self._cache.tools.append("web_search")
return "Web tool added successfully."

def get_rag_params(self) -> Dict:
Expand Down Expand Up @@ -557,11 +636,13 @@ def create_agent(self, agent_id: Optional[str] = None) -> str:
if self._cache.system_prompt is None:
raise ValueError("Must set system prompt before creating agent.")

# construct additional tools
additional_tools = get_tool_objects(self.cache.tools)
agent, extra_info = construct_agent(
cast(str, self._cache.system_prompt),
cast(RAGParams, self._cache.rag_params),
self._cache.docs,
additional_tools=self._cache.tools,
additional_tools=additional_tools,
)

# if agent_id not specified, randomly generate one
Expand All @@ -587,6 +668,7 @@ def update_agent(
chunk_size: Optional[int] = None,
embed_model: Optional[str] = None,
llm: Optional[str] = None,
additional_tools: Optional[List] = None,
) -> None:
"""Update agent.
Expand All @@ -609,7 +691,6 @@ def update_agent(
# We call set_rag_params and create_agent, which will
# update the cache
# TODO: decouple functions from tool functions exposed to the agent

rag_params_dict: Dict[str, Any] = {}
if include_summarization is not None:
rag_params_dict["include_summarization"] = include_summarization
Expand All @@ -623,6 +704,11 @@ def update_agent(
rag_params_dict["llm"] = llm

self.set_rag_params(**rag_params_dict)

# update tools
if additional_tools is not None:
self.cache.tools = additional_tools

# this will update the agent in the cache
self.create_agent()

Expand Down Expand Up @@ -655,6 +741,33 @@ def update_agent(
# please make sure to update the LLM above if you change the function below


def _get_builder_agent_tools(agent_builder: RAGAgentBuilder) -> List[FunctionTool]:
    """Get list of builder agent tools to pass to the builder agent."""
    # TODO: refactor this later
    # base set of builder functions exposed to the meta-agent
    fns: List[Callable] = [
        agent_builder.create_system_prompt,
        agent_builder.load_data,
        agent_builder.get_rag_params,
        agent_builder.set_rag_params,
        agent_builder.create_agent,
    ]
    # only expose the web tool when a metaphor api key is configured
    if "metaphor_key" in st.secrets:
        fns.insert(2, agent_builder.add_web_tool)

    fn_tools: List[FunctionTool] = [FunctionTool.from_defaults(fn=fn) for fn in fns]
    return fn_tools


# define agent
# @st.cache_resource
def load_meta_agent_and_tools(
Expand All @@ -664,15 +777,7 @@ def load_meta_agent_and_tools(
# think of this as tools for the agent to use
agent_builder = RAGAgentBuilder(cache)

fns: List[Callable] = [
agent_builder.create_system_prompt,
agent_builder.load_data,
# add_web_tool,
agent_builder.get_rag_params,
agent_builder.set_rag_params,
agent_builder.create_agent,
]
fn_tools = [FunctionTool.from_defaults(fn=fn) for fn in fns]
fn_tools = _get_builder_agent_tools(agent_builder)

builder_agent = load_meta_agent(
fn_tools, llm=BUILDER_LLM, system_prompt=RAG_BUILDER_SYS_STR, verbose=True
Expand Down
70 changes: 70 additions & 0 deletions callback_manager.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
"""Streaming callback manager."""
from llama_index.callbacks.base_handler import BaseCallbackHandler
from llama_index.callbacks.schema import CBEventType

from typing import Optional, Dict, Any, List, Callable

STORAGE_DIR = "./storage" # directory to cache the generated index
DATA_DIR = "./data" # directory containing the documents to index


class StreamlitFunctionsCallbackHandler(BaseCallbackHandler):
    """Callback handler that surfaces function-call events via a message handler.

    Whenever a FUNCTION_CALL event starts, a human-readable line is passed to
    ``msg_handler``; all other events and traces are ignored.
    """

    def __init__(self, msg_handler: Callable[[str], Any]) -> None:
        """Initialize the base callback handler."""
        # no event types are ignored on either start or end
        super().__init__([], [])
        self.msg_handler = msg_handler

    def on_event_start(
        self,
        event_type: CBEventType,
        payload: Optional[Dict[str, Any]] = None,
        event_id: str = "",
        parent_id: str = "",
        **kwargs: Any,
    ) -> str:
        """Run when an event starts and return id of event."""
        # only function-call events are surfaced; everything else is a no-op
        if event_type == CBEventType.FUNCTION_CALL:
            if payload is None:
                raise ValueError("Payload cannot be None")
            arguments_str = payload["function_call"]
            tool_str = payload["tool"].name
            self.msg_handler(
                f"Calling function: {tool_str} with args: {arguments_str}\n\n"
            )
        return event_id

    def on_event_end(
        self,
        event_type: CBEventType,
        payload: Optional[Dict[str, Any]] = None,
        event_id: str = "",
        **kwargs: Any,
    ) -> None:
        """Run when an event ends."""
        # TODO: currently a no-op; function-call responses and agent steps
        # could be surfaced here later.
        pass

    def start_trace(self, trace_id: Optional[str] = None) -> None:
        """Run when an overall trace is launched."""
        pass

    def end_trace(
        self,
        trace_id: Optional[str] = None,
        trace_map: Optional[Dict[str, List[str]]] = None,
    ) -> None:
        """Run when an overall trace is exited."""
        pass
10 changes: 10 additions & 0 deletions pages/2_⚙️_RAG_Config.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ def update_agent() -> None:
"config_agent_builder" in st.session_state.keys()
and st.session_state.config_agent_builder is not None
):
additional_tools = st.session_state.additional_tools_st.split(",")
agent_builder = cast(RAGAgentBuilder, st.session_state.config_agent_builder)
### Update the agent
agent_builder.update_agent(
Expand All @@ -34,6 +35,7 @@ def update_agent() -> None:
chunk_size=st.session_state.chunk_size_st,
embed_model=st.session_state.embed_model_st,
llm=st.session_state.llm_st,
additional_tools=additional_tools,
)

# Update Radio Buttons: update selected agent to the new id
Expand Down Expand Up @@ -114,6 +116,14 @@ def delete_agent() -> None:
value=rag_params.include_summarization,
key="include_summarization_st",
)

# add web tool
additional_tools_st = st.text_input(
"Additional tools (currently only supports 'web_search')",
value=",".join(agent_builder.cache.tools),
key="additional_tools_st",
)

top_k_st = st.number_input("Top K", value=rag_params.top_k, key="top_k_st")
chunk_size_st = st.number_input(
"Chunk Size", value=rag_params.chunk_size, key="chunk_size_st"
Expand Down
Loading

0 comments on commit b028ab7

Please sign in to comment.