Skip to content

Commit

Permalink
cr
Browse files Browse the repository at this point in the history
  • Loading branch information
jerryjliu committed Nov 18, 2023
1 parent 3b2bea2 commit d5a444f
Show file tree
Hide file tree
Showing 7 changed files with 943 additions and 0 deletions.
4 changes: 4 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
.ipynb_checkpoints
.streamlit
.venv
__pycache__
73 changes: 73 additions & 0 deletions 1_🏠_Home.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
import streamlit as st
import os
os.environ["OPENAI_API_KEY"] = st.secrets.openai_key

from streamlit_pills import pills

from agent_utils import (
load_meta_agent_and_tools,
ParamCache
)


####################
#### STREAMLIT #####
####################



# Page chrome: browser-tab metadata, the visible title, and a short banner
# pointing the user at the other pages of the app.
st.set_page_config(
    page_title="Build a RAGs bot, powered by LlamaIndex",
    page_icon="🦙",
    layout="centered",
    initial_sidebar_state="auto",
    menu_items=None,
)
st.title("Build a RAGs bot, powered by LlamaIndex 💬🦙")

intro_text = (
    "Use this page to build your RAG bot over your data! "
    "Once the agent is finished creating, check out the `RAG Config` and `Generated RAG Agent` pages."
)
st.info(intro_text, icon="ℹ️")

#### load builder agent and its tool spec (the agent_builder)
builder_agent, agent_builder = load_meta_agent_and_tools()

# Seed the session with both objects the first time this session runs the
# script; values already present in the session are left untouched.
for state_key, state_value in (
    ("builder_agent", builder_agent),
    ("agent_builder", agent_builder),
):
    if state_key not in st.session_state:
        st.session_state[state_key] = state_value

# add pills: clickable example tasks shown above the chat.
# NOTE(review): `selected` is not read anywhere in the visible part of this
# script, so clicking a pill currently has no effect on the conversation.
task_suggestions = [
    "I want to analyze this PDF file (data/invoices.pdf)",
    "I want to search over my CSV documents.",
]
selected = pills(
    "Outline your task!",
    task_suggestions,
    clearable=True,
    index=None,
)

# Initialize the chat history with an opening assistant message the first
# time the session runs this script.
if "messages" not in st.session_state:
    opening_message = {
        "role": "assistant",
        "content": "What RAG bot do you want to build?",
    }
    st.session_state.messages = [opening_message]

def add_to_message_history(role, content):
    """Append one chat entry to ``st.session_state.messages``.

    Args:
        role: Speaker of the message (the script uses "user" and "assistant").
        content: Message payload; it is stored as ``str(content)``.
    """
    st.session_state.messages.append({"role": role, "content": str(content)})

# Display the prior chat messages, one chat bubble per stored entry.
for entry in st.session_state.messages:
    entry_role, entry_content = entry["role"], entry["content"]
    with st.chat_message(entry_role):
        st.write(entry_content)

# handle user input: record a newly submitted message in the history and echo
# it straight away. `prompt` stays a module-level name because the response
# step further down reads it.
prompt = st.chat_input("Your question")
if prompt:
    add_to_message_history("user", prompt)
    with st.chat_message("user"):
        st.write(prompt)

# If last message is not from assistant, generate a new response.
# The pending message is taken from the history rather than from `prompt`:
# `prompt` is empty on any rerun that was not triggered by the chat input
# (for example after a previous `chat()` call failed before a reply was
# stored), whereas the history always holds the text that still needs an answer.
pending_message = st.session_state.messages[-1]
if pending_message["role"] != "assistant":
    with st.chat_message("assistant"):
        with st.spinner("Thinking..."):
            response = st.session_state.builder_agent.chat(pending_message["content"])
            st.write(str(response))
            add_to_message_history("assistant", response)

# # check cache
# Debug aid: dump the builder's cache to stdout on every script run.
print(st.session_state.agent_builder.cache)
# if "agent" in cache:
# st.session_state.agent = cache["agent"]
308 changes: 308 additions & 0 deletions agent_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,308 @@
from llama_index.llms import OpenAI, ChatMessage
from llama_index.llms.utils import resolve_llm
from pydantic import BaseModel, Field
import os
from llama_index.tools.query_engine import QueryEngineTool
from llama_index.agent import OpenAIAgent
from llama_index import (
VectorStoreIndex,
SummaryIndex,
ServiceContext,
Document
)
from llama_index.prompts import ChatPromptTemplate
from typing import List, cast, Optional
from llama_index import SimpleDirectoryReader
from llama_index.embeddings.utils import resolve_embed_model
from llama_index.tools import QueryEngineTool, ToolMetadata, FunctionTool
from typing import Dict, Tuple
import streamlit as st


####################
#### META TOOLS ####
####################


# System prompt tool
# Prompt text asking an LLM to write the system prompt for the agent being
# built. `{task}` is the only placeholder; it is filled in through
# `GEN_SYS_PROMPT_TMPL.format_messages(task=...)` in
# `RAGAgentBuilder.create_system_prompt`.
# NOTE(review): the backslash after `{task} ` is a line continuation inside
# the literal, so the task text runs directly into "Make sure ..." without a
# newline — confirm this is intended.
GEN_SYS_PROMPT_STR = """\
Task information is given below.
Given the task, please generate a system prompt for an OpenAI-powered bot to solve this task:
{task} \
Make sure the system prompt obeys the following requirements:
- Tells the bot to ALWAYS use tools given to solve the task. NEVER give an answer without using a tool.
- Does not reference a specific data source. The data source is implicit in any queries to the bot,
and telling the bot to analyze a specific data source might confuse it given a
user query.
"""

# Two-message chat template: a fixed system instruction followed by the user
# message that carries the task-specific request text.
_gen_sys_prompt_system_msg = ChatMessage(
    role="system",
    content="You are helping to build a system prompt for another bot.",
)
_gen_sys_prompt_user_msg = ChatMessage(role="user", content=GEN_SYS_PROMPT_STR)
gen_sys_prompt_messages = [_gen_sys_prompt_system_msg, _gen_sys_prompt_user_msg]

GEN_SYS_PROMPT_TMPL = ChatPromptTemplate(gen_sys_prompt_messages)


class RAGParams(BaseModel):
    """RAG parameters.

    Parameters used to configure a RAG pipeline.
    """

    # Adds a summary-index tool next to the vector tool when True.
    include_summarization: bool = Field(
        default=False,
        description="Whether to include summarization in the RAG pipeline.",
    )
    # Passed as `similarity_top_k` when the vector query engine is built.
    top_k: int = Field(
        default=2,
        description="Number of documents to retrieve from vector store.",
    )
    # Passed as `chunk_size` to the service context used for indexing.
    chunk_size: int = Field(
        default=1024,
        description="Chunk size for vector store.",
    )
    # Resolved through `resolve_embed_model` when the agent is created.
    embed_model: str = Field(
        default="default",
        description="Embedding model to use (default is OpenAI)",
    )
    # Model name handed to the OpenAI LLM wrapper when the agent is created.
    llm: str = Field(
        default="gpt-4-1106-preview",
        description="LLM to use for summarization.",
    )


class ParamCache(BaseModel):
    """Cache for RAG agent builder.

    Created a wrapper class around a dict in case we wanted to more explicitly
    type different items in the cache.
    """

    class Config:
        # Needed because some fields hold non-pydantic-validated objects
        # (documents, the agent itself).
        arbitrary_types_allowed = True

    # Written by `RAGAgentBuilder.create_system_prompt`.
    system_prompt: Optional[str] = Field(
        default=None,
        description="System prompt for RAG agent.",
    )
    # Written by `RAGAgentBuilder.load_data`.
    docs: List[Document] = Field(
        default_factory=list,
        description="Documents for RAG agent.",
    )
    # Extended by `RAGAgentBuilder.add_web_tool`.
    tools: List = Field(
        default_factory=list,
        description="Additional tools for RAG agent (e.g. web)",
    )
    # Replaced wholesale by `RAGAgentBuilder.set_rag_params`.
    rag_params: RAGParams = Field(
        default_factory=RAGParams,
        description="RAG parameters for RAG agent.",
    )
    # Written by `RAGAgentBuilder.create_agent`.
    agent: Optional[OpenAIAgent] = Field(
        default=None,
        description="RAG agent.",
    )



class RAGAgentBuilder:
    """RAG Agent builder.

    Contains a set of functions to construct a RAG agent, including:
    - setting system prompts
    - loading data
    - adding web search
    - setting parameters (e.g. top-k)

    A cache may be passed in; otherwise a fresh one is created. The cache is
    modified as the agent is built.
    """

    # NOTE: most public methods below are wrapped with
    # `FunctionTool.from_defaults` in `load_meta_agent_and_tools`, so their
    # signatures and docstrings presumably double as the tool descriptions the
    # builder LLM sees. Keep them short and written for that audience.

    def __init__(self, cache: Optional["ParamCache"] = None) -> None:
        """Init params.

        Args:
            cache: Existing cache to build on; a new empty one is created
                when omitted.
        """
        # Explicit None test so a caller-supplied cache is never swapped out
        # just because it happens to evaluate as falsy.
        self._cache = cache if cache is not None else ParamCache()

    @property
    def cache(self) -> "ParamCache":
        """Cache."""
        return self._cache

    def create_system_prompt(self, task: str) -> str:
        """Create system prompt for another agent given an input task."""
        llm = OpenAI(model="gpt-4-1106-preview")
        fmt_messages = GEN_SYS_PROMPT_TMPL.format_messages(task=task)
        response = llm.chat(fmt_messages)
        system_prompt = response.message.content
        self._cache.system_prompt = system_prompt

        return f"System prompt created: {system_prompt}"

    def load_data(
        self,
        file_names: Optional[List[str]] = None,
        urls: Optional[List[str]] = None,
    ) -> str:
        """Load data for a given task.

        Only ONE of file_names or urls should be specified.

        Args:
            file_names (Optional[List[str]]): List of file names to load.
                Defaults to None.
            urls (Optional[List[str]]): List of urls to load.
                Defaults to None.

        Raises:
            ValueError: If neither or both of file_names and urls are given.
        """
        if file_names is None and urls is None:
            raise ValueError("Must specify either file_names or urls.")
        if file_names is not None and urls is not None:
            raise ValueError("Must specify only one of file_names or urls.")

        # Exactly one source is set from here on.
        if file_names is not None:
            reader = SimpleDirectoryReader(input_files=file_names)
            docs = reader.load_data()
        else:
            # use simple web page reader from llamahub
            from llama_hub.web.simple_web.base import SimpleWebPageReader

            loader = SimpleWebPageReader()
            docs = loader.load_data(urls=urls)

        self._cache.docs = docs
        return "Data loaded successfully."

    # NOTE: unused
    def add_web_tool(self) -> str:
        """Add a web tool to enable agent to solve a task."""
        # TODO: make this not hardcoded to a web tool
        # Set up Metaphor tool
        from llama_hub.tools.metaphor.base import MetaphorToolSpec

        # TODO: set metaphor API key
        # Reads the key from the environment; a missing variable raises KeyError.
        metaphor_tool = MetaphorToolSpec(
            api_key=os.environ["METAPHOR_API_KEY"],
        )
        metaphor_tool_list = metaphor_tool.to_tool_list()

        self._cache.tools.extend(metaphor_tool_list)
        return "Web tool added successfully."

    def get_rag_params(self) -> Dict:
        """Get parameters used to configure the RAG pipeline.

        Should be called before `set_rag_params` so that the agent is aware of the
        schema.
        """
        rag_params = self._cache.rag_params
        return rag_params.model_dump()

    def set_rag_params(self, **rag_params: Dict) -> str:
        """Set RAG parameters.

        These parameters will then be used to actually initialize the agent.
        Should call `get_rag_params` first to get the schema of the input dictionary.

        Args:
            **rag_params (Dict): dictionary of RAG parameters.
        """
        updates = dict(rag_params)
        # NOTE(review): a tool-calling layer may hand the parameters over as a
        # single nested `rag_params={...}` argument instead of flat keywords
        # (assumption — confirm against the tool schema). `RAGParams` has no
        # field of that name, so unwrap the nested dict rather than letting
        # the whole update be dropped as an unknown key. Explicit flat
        # keywords win over nested ones.
        nested = updates.pop("rag_params", None)
        if isinstance(nested, dict):
            updates = {**nested, **updates}

        new_dict = self._cache.rag_params.model_dump()
        new_dict.update(updates)
        self._cache.rag_params = RAGParams(**new_dict)
        return "RAG parameters set successfully."

    def create_agent(self) -> str:
        """Create an agent.

        There are no parameters for this function because all the
        functions should have already been called to set up the agent.
        """
        # Check the precondition before any indexing so a missing system
        # prompt does not cost an embedding/index build that is then discarded.
        if self._cache.system_prompt is None:
            return "System prompt not set yet. Please set system prompt first."

        rag_params = cast(RAGParams, self._cache.rag_params)
        docs = self._cache.docs

        # first resolve llm and embedding model
        embed_model = resolve_embed_model(rag_params.embed_model)
        # TODO: use OpenAI for now (rather than resolve_llm(rag_params.llm))
        llm = OpenAI(model=rag_params.llm)

        # then index the data with the right parameters
        service_context = ServiceContext.from_defaults(
            chunk_size=rag_params.chunk_size,
            llm=llm,
            embed_model=embed_model,
        )
        vector_index = VectorStoreIndex.from_documents(
            docs, service_context=service_context
        )
        vector_query_engine = vector_index.as_query_engine(
            similarity_top_k=rag_params.top_k
        )
        vector_tool = QueryEngineTool(
            query_engine=vector_query_engine,
            metadata=ToolMetadata(
                name="vector_tool",
                description=("Use this tool to answer any user question over any data."),
            ),
        )
        all_tools = [vector_tool]

        if rag_params.include_summarization:
            summary_index = SummaryIndex.from_documents(
                docs, service_context=service_context
            )
            summary_query_engine = summary_index.as_query_engine()
            summary_tool = QueryEngineTool(
                query_engine=summary_query_engine,
                metadata=ToolMetadata(
                    name="summary_tool",
                    description=("Use this tool for any user questions that ask for a summarization of content"),
                ),
            )
            all_tools.append(summary_tool)

        # then we add the additional tools collected in the cache
        all_tools.extend(self._cache.tools)

        # build agent
        agent = OpenAIAgent.from_tools(
            tools=all_tools,
            system_prompt=self._cache.system_prompt,
            llm=llm,
            verbose=True,
        )
        self._cache.agent = agent
        return "Agent created successfully."


####################
#### META Agent ####
####################

# System prompt for the builder ("meta") agent; it is passed as the single
# system prefix message in `load_meta_agent_and_tools`.
# NOTE(review): the numbered steps stop at setting parameters and never tell
# the agent to actually build the agent, although `create_agent` is among the
# tools it is given — confirm whether a final "build the agent" step is wanted.
RAG_BUILDER_SYS_STR = """\
You are helping to construct an agent given a user-specified task.
You should generally use the tools in this rough order to build the agent.
1) Create system prompt tool: to create the system prompt for the agent.
2) Load in user-specified data (based on file paths they specify).
3) Decide whether or not to add additional tools.
4) Set parameters for the RAG pipeline.
This will be a back and forth conversation with the user. You should
continue asking users if there's anything else they want to do until
they say they're done. To help guide them on the process,
you can give suggestions on parameters they can set based on the tools they
have available (e.g. "Do you want to set the number of documents to retrieve?")
"""


# define agent
@st.cache_resource
def load_meta_agent_and_tools() -> Tuple[OpenAIAgent, RAGAgentBuilder]:
    """Build the builder ("meta") agent together with its tool provider.

    A fresh `RAGAgentBuilder` is created and its bound methods are exposed to
    the agent as function tools, so every tool call mutates that builder's
    cache.

    NOTE(review): the function is wrapped in `st.cache_resource`, so repeated
    calls are expected to return the same two objects instead of rebuilding
    them — confirm that sharing one builder this way is intended.

    Returns:
        A `(builder_agent, agent_builder)` tuple: the chat agent that drives
        the conversation and the builder whose methods it calls.
    """
    prefix_msgs = [ChatMessage(role="system", content=RAG_BUILDER_SYS_STR)]

    # think of this as tools for the agent to use
    agent_builder = RAGAgentBuilder()

    fns = [
        agent_builder.create_system_prompt,
        agent_builder.load_data,
        # add_web_tool is intentionally left out for now
        agent_builder.get_rag_params,
        agent_builder.set_rag_params,
        agent_builder.create_agent,
    ]
    fn_tools = [FunctionTool.from_defaults(fn=fn) for fn in fns]

    builder_agent = OpenAIAgent.from_tools(
        tools=fn_tools,
        # The model name must be passed as `model=`, as every other OpenAI(...)
        # call in this file does; the previous `llm=` keyword did not select
        # the intended model.
        llm=OpenAI(model="gpt-4-1106-preview"),
        prefix_messages=prefix_msgs,
        verbose=True,
    )
    return builder_agent, agent_builder

Loading

0 comments on commit d5a444f

Please sign in to comment.