forked from run-llama/rags
Merge pull request run-llama#1 from run-llama/jerry/add_files
add streamlit files
Showing 7 changed files with 943 additions and 0 deletions.
New file (4 lines), evidently the repo's `.gitignore`:
.ipynb_checkpoints
.streamlit
.venv
__pycache__
New file (73 lines), the Streamlit app that hosts the builder chat:
import streamlit as st
import os
os.environ["OPENAI_API_KEY"] = st.secrets.openai_key

from streamlit_pills import pills

from agent_utils import (
    load_meta_agent_and_tools,
    ParamCache
)


####################
#### STREAMLIT #####
####################


st.set_page_config(
    page_title="Build a RAGs bot, powered by LlamaIndex",
    page_icon="🦙",
    layout="centered",
    initial_sidebar_state="auto",
    menu_items=None,
)
st.title("Build a RAGs bot, powered by LlamaIndex 💬🦙")
st.info(
    "Use this page to build your RAG bot over your data! "
    "Once the agent is finished creating, check out the `RAG Config` and "
    "`Generated RAG Agent` pages.",
    icon="ℹ️",
)

#### load builder agent and its tool spec (the agent_builder)
builder_agent, agent_builder = load_meta_agent_and_tools()

if "builder_agent" not in st.session_state.keys():
    st.session_state.builder_agent = builder_agent
if "agent_builder" not in st.session_state.keys():
    st.session_state.agent_builder = agent_builder

# add pills
selected = pills(
    "Outline your task!",
    [
        "I want to analyze this PDF file (data/invoices.pdf)",
        "I want to search over my CSV documents.",
    ],
    clearable=True,
    index=None,
)

if "messages" not in st.session_state.keys():  # Initialize the chat messages history
    st.session_state.messages = [
        {"role": "assistant", "content": "What RAG bot do you want to build?"}
    ]


def add_to_message_history(role, content):
    message = {"role": role, "content": str(content)}
    st.session_state.messages.append(message)  # Add response to message history


for message in st.session_state.messages:  # Display the prior chat messages
    with st.chat_message(message["role"]):
        st.write(message["content"])

# handle user input
if prompt := st.chat_input("Your question"):  # Prompt for user input and save to chat history
    add_to_message_history("user", prompt)
    with st.chat_message("user"):
        st.write(prompt)

# If last message is not from assistant, generate a new response
if st.session_state.messages[-1]["role"] != "assistant":
    with st.chat_message("assistant"):
        with st.spinner("Thinking..."):
            response = st.session_state.builder_agent.chat(prompt)
            st.write(str(response))
            add_to_message_history("assistant", response)

# check cache
print(st.session_state.agent_builder.cache)
# if "agent" in cache:
#     st.session_state.agent = cache["agent"]
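
As a side note, the chat-loop pattern above (seed `st.session_state.messages`, replay the history on each rerun, append on `st.chat_input`) can be tried in isolation. Below is a minimal sketch with a stubbed echo reply standing in for `builder_agent.chat`; everything else mirrors the file above, and only `streamlit` is required.

# Minimal sketch of the same chat-loop pattern with a stubbed "agent",
# so the UI can be tried without the llama_index stack.
# Run with: streamlit run sketch.py (any filename works).
import streamlit as st

if "messages" not in st.session_state:
    st.session_state.messages = [
        {"role": "assistant", "content": "What RAG bot do you want to build?"}
    ]

for message in st.session_state.messages:  # replay prior history on each rerun
    with st.chat_message(message["role"]):
        st.write(message["content"])

if prompt := st.chat_input("Your question"):
    st.session_state.messages.append({"role": "user", "content": prompt})
    with st.chat_message("user"):
        st.write(prompt)
    reply = f"(stub) You said: {prompt}"  # stand-in for builder_agent.chat(prompt)
    with st.chat_message("assistant"):
        st.write(reply)
    st.session_state.messages.append({"role": "assistant", "content": reply})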
New file (308 lines): `agent_utils.py`, the module the app above imports, defining the RAG builder tools and the meta-agent:
from llama_index.llms import OpenAI, ChatMessage
from llama_index.llms.utils import resolve_llm
from pydantic import BaseModel, Field
import os
from llama_index.agent import OpenAIAgent
from llama_index import (
    VectorStoreIndex,
    SummaryIndex,
    ServiceContext,
    Document,
)
from llama_index.prompts import ChatPromptTemplate
from typing import Dict, List, Optional, Tuple, cast
from llama_index import SimpleDirectoryReader
from llama_index.embeddings.utils import resolve_embed_model
from llama_index.tools import QueryEngineTool, ToolMetadata, FunctionTool
import streamlit as st


####################
#### META TOOLS ####
####################


# System prompt tool
GEN_SYS_PROMPT_STR = """\
Task information is given below.
Given the task, please generate a system prompt for an OpenAI-powered bot to solve this task:
{task} \
Make sure the system prompt obeys the following requirements:
- Tells the bot to ALWAYS use tools given to solve the task. NEVER give an answer without using a tool.
- Does not reference a specific data source. The data source is implicit in any queries to the bot,
and telling the bot to analyze a specific data source might confuse it given a
user query.
"""

gen_sys_prompt_messages = [
    ChatMessage(
        role="system",
        content="You are helping to build a system prompt for another bot.",
    ),
    ChatMessage(role="user", content=GEN_SYS_PROMPT_STR),
]

GEN_SYS_PROMPT_TMPL = ChatPromptTemplate(gen_sys_prompt_messages)
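
# Illustrative only (not in the original file): calling
#   GEN_SYS_PROMPT_TMPL.format_messages(task="Q&A over a PDF of invoices")
# fills {task} and returns the system + user ChatMessage pair defined above,
# ready to pass to llm.chat(...).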


class RAGParams(BaseModel):
    """RAG parameters.

    Parameters used to configure a RAG pipeline.

    """

    include_summarization: bool = Field(
        default=False,
        description="Whether to include summarization in the RAG pipeline.",
    )
    top_k: int = Field(
        default=2, description="Number of documents to retrieve from vector store."
    )
    chunk_size: int = Field(default=1024, description="Chunk size for vector store.")
    embed_model: str = Field(
        default="default", description="Embedding model to use (default is OpenAI)"
    )
    llm: str = Field(
        default="gpt-4-1106-preview", description="LLM to use for summarization."
    )
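
# Illustrative only (not in the original file): defaults can be overridden
# field-by-field at construction, e.g. RAGParams(top_k=4, chunk_size=512),
# and pydantic validates the types for you.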


class ParamCache(BaseModel):
    """Cache for RAG agent builder.

    Created a wrapper class around a dict in case we wanted to more explicitly
    type different items in the cache.

    """

    # arbitrary types
    class Config:
        arbitrary_types_allowed = True

    system_prompt: Optional[str] = Field(
        default=None, description="System prompt for RAG agent."
    )
    docs: List[Document] = Field(
        default_factory=list, description="Documents for RAG agent."
    )
    tools: List = Field(
        default_factory=list, description="Additional tools for RAG agent (e.g. web)"
    )
    rag_params: RAGParams = Field(
        default_factory=RAGParams, description="RAG parameters for RAG agent."
    )
    agent: Optional[OpenAIAgent] = Field(default=None, description="RAG agent.")


class RAGAgentBuilder:
    """RAG Agent builder.

    Contains a set of functions to construct a RAG agent, including:
    - setting system prompts
    - loading data
    - adding web search
    - setting parameters (e.g. top-k)

    Must pass in a cache. This cache will be modified as the agent is built.

    """

    def __init__(self, cache: Optional[ParamCache] = None) -> None:
        """Init params."""
        self._cache = cache or ParamCache()

    @property
    def cache(self) -> ParamCache:
        """Cache."""
        return self._cache

    def create_system_prompt(self, task: str) -> str:
        """Create system prompt for another agent given an input task."""
        llm = OpenAI(model="gpt-4-1106-preview")
        fmt_messages = GEN_SYS_PROMPT_TMPL.format_messages(task=task)
        response = llm.chat(fmt_messages)
        self._cache.system_prompt = response.message.content

        return f"System prompt created: {response.message.content}"

    def load_data(
        self,
        file_names: Optional[List[str]] = None,
        urls: Optional[List[str]] = None,
    ) -> str:
        """Load data for a given task.

        Only ONE of file_names or urls should be specified.

        Args:
            file_names (Optional[List[str]]): List of file names to load.
                Defaults to None.
            urls (Optional[List[str]]): List of urls to load.
                Defaults to None.

        """
        if file_names is None and urls is None:
            raise ValueError("Must specify either file_names or urls.")
        elif file_names is not None and urls is not None:
            raise ValueError("Must specify only one of file_names or urls.")
        elif file_names is not None:
            reader = SimpleDirectoryReader(input_files=file_names)
            docs = reader.load_data()
        elif urls is not None:
            from llama_hub.web.simple_web.base import SimpleWebPageReader

            # use simple web page reader from llamahub
            loader = SimpleWebPageReader()
            docs = loader.load_data(urls=urls)
        else:
            raise ValueError("Must specify either file_names or urls.")

        self._cache.docs = docs
        return "Data loaded successfully."

    # NOTE: unused
    def add_web_tool(self) -> str:
        """Add a web tool to enable agent to solve a task."""
        # TODO: make this not hardcoded to a web tool
        # Set up Metaphor tool
        from llama_hub.tools.metaphor.base import MetaphorToolSpec

        # TODO: set metaphor API key
        metaphor_tool = MetaphorToolSpec(
            api_key=os.environ["METAPHOR_API_KEY"],
        )
        metaphor_tool_list = metaphor_tool.to_tool_list()

        self._cache.tools.extend(metaphor_tool_list)
        return "Web tool added successfully."

    def get_rag_params(self) -> Dict:
        """Get parameters used to configure the RAG pipeline.

        Should be called before `set_rag_params` so that the agent is aware of the
        schema.

        """
        rag_params = self._cache.rag_params
        return rag_params.model_dump()

    def set_rag_params(self, **rag_params: Dict) -> str:
        """Set RAG parameters.

        These parameters will then be used to actually initialize the agent.
        Should call `get_rag_params` first to get the schema of the input dictionary.

        Args:
            **rag_params (Dict): dictionary of RAG parameters.

        """
        new_dict = self._cache.rag_params.model_dump()
        new_dict.update(rag_params)
        rag_params_obj = RAGParams(**new_dict)
        self._cache.rag_params = rag_params_obj
        return "RAG parameters set successfully."
|
||
|
||
def create_agent(self) -> None: | ||
"""Create an agent. | ||
There are no parameters for this function because all the | ||
functions should have already been called to set up the agent. | ||
""" | ||
rag_params = cast(RAGParams, self._cache.rag_params) | ||
docs = self._cache.docs | ||
|
||
# first resolve llm and embedding model | ||
embed_model = resolve_embed_model(rag_params.embed_model) | ||
# llm = resolve_llm(rag_params.llm) | ||
# TODO: use OpenAI for now | ||
llm = OpenAI(model=rag_params.llm) | ||
|
||
# first let's index the data with the right parameters | ||
service_context = ServiceContext.from_defaults( | ||
chunk_size=rag_params.chunk_size, | ||
llm=llm, | ||
embed_model=embed_model, | ||
) | ||
vector_index = VectorStoreIndex.from_documents(docs, service_context=service_context) | ||
vector_query_engine = vector_index.as_query_engine(similarity_top_k=rag_params.top_k) | ||
all_tools = [] | ||
vector_tool = QueryEngineTool( | ||
query_engine=vector_query_engine, | ||
metadata=ToolMetadata( | ||
name="vector_tool", | ||
description=("Use this tool to answer any user question over any data."), | ||
), | ||
) | ||
all_tools.append(vector_tool) | ||
if rag_params.include_summarization: | ||
summary_index = SummaryIndex.from_documents(docs, service_context=service_context) | ||
summary_query_engine = summary_index.as_query_engine() | ||
summary_tool = QueryEngineTool( | ||
query_engine=summary_query_engine, | ||
metadata=ToolMetadata( | ||
name="summary_tool", | ||
description=("Use this tool for any user questions that ask for a summarization of content"), | ||
), | ||
) | ||
all_tools.append(summary_tool) | ||
|
||
|
||
# then we add tools | ||
all_tools.extend(self._cache.tools) | ||
|
||
# build agent | ||
if self._cache.system_prompt is None: | ||
return "System prompt not set yet. Please set system prompt first." | ||
|
||
agent = OpenAIAgent.from_tools( | ||
tools=all_tools, | ||
system_prompt=self._cache.system_prompt, | ||
llm=llm, | ||
verbose=True | ||
) | ||
self._cache.agent = agent | ||
return "Agent created successfully." | ||


####################
#### META Agent ####
####################

RAG_BUILDER_SYS_STR = """\
You are helping to construct an agent given a user-specified task.
You should generally use the tools in this rough order to build the agent.
1) Create system prompt tool: to create the system prompt for the agent.
2) Load in user-specified data (based on file paths they specify).
3) Decide whether or not to add additional tools.
4) Set parameters for the RAG pipeline.
This will be a back and forth conversation with the user. You should
continue asking users if there's anything else they want to do until
they say they're done. To help guide them on the process,
you can give suggestions on parameters they can set based on the tools they
have available (e.g. "Do you want to set the number of documents to retrieve?")
"""


# define agent
@st.cache_resource
def load_meta_agent_and_tools() -> Tuple[OpenAIAgent, RAGAgentBuilder]:
    prefix_msgs = [ChatMessage(role="system", content=RAG_BUILDER_SYS_STR)]

    # think of this as tools for the agent to use
    agent_builder = RAGAgentBuilder()

    fns = [
        agent_builder.create_system_prompt,
        agent_builder.load_data,
        # add_web_tool,
        agent_builder.get_rag_params,
        agent_builder.set_rag_params,
        agent_builder.create_agent,
    ]
    fn_tools = [FunctionTool.from_defaults(fn=fn) for fn in fns]

    builder_agent = OpenAIAgent.from_tools(
        tools=fn_tools,
        llm=OpenAI(model="gpt-4-1106-preview"),
        prefix_messages=prefix_msgs,
        verbose=True,
    )
    return builder_agent, agent_builder
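
For orientation, here is a minimal, hypothetical driver script (not part of this commit) that exercises `RAGAgentBuilder` directly, skipping the meta-agent and the Streamlit UI. It assumes `OPENAI_API_KEY` is exported and that `data/invoices.pdf` (the path used in the pills example above) exists.

# Hypothetical usage sketch, not part of the commit. Assumes OPENAI_API_KEY
# is set and data/invoices.pdf exists (the path from the pills example).
from agent_utils import RAGAgentBuilder

builder = RAGAgentBuilder()

# Each method returns a status string; the meta-agent normally calls these
# as tools, but they work the same when invoked directly.
print(builder.create_system_prompt("Answer questions about a PDF of invoices."))
print(builder.load_data(file_names=["data/invoices.pdf"]))
print(builder.set_rag_params(top_k=4, chunk_size=512))
print(builder.create_agent())

# The finished RAG agent lives in the cache.
agent = builder.cache.agent
print(agent.chat("What is the total amount of the first invoice?"))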