From 49b867909c66448c7479186d49409b507b86fbc7 Mon Sep 17 00:00:00 2001 From: Jerry Liu Date: Sun, 26 Nov 2023 01:27:44 -0800 Subject: [PATCH] upgrade RAGs (#27) --- .github/workflows/lint.yml | 32 + "1_\360\237\217\240_Home.py" | 89 ++- Makefile | 10 + README.md | 4 +- agent_utils.py | 568 ++++++++++++++---- builder_config.py | 3 +- constants.py | 4 + .../2_\342\232\231\357\270\217_RAG_Config.py" | 140 +++-- ...3_\360\237\244\226_Generated_RAG_Agent.py" | 90 +-- pyproject.toml | 65 ++ st_utils.py | 58 ++ tests/__init__.py | 0 12 files changed, 829 insertions(+), 234 deletions(-) create mode 100644 .github/workflows/lint.yml create mode 100644 Makefile create mode 100644 constants.py create mode 100644 pyproject.toml create mode 100644 st_utils.py create mode 100644 tests/__init__.py diff --git a/.github/workflows/lint.yml b/.github/workflows/lint.yml new file mode 100644 index 0000000..ae626e7 --- /dev/null +++ b/.github/workflows/lint.yml @@ -0,0 +1,32 @@ +name: Linting + +on: + push: + branches: + - main + pull_request: + +jobs: + build: + runs-on: ubuntu-latest + strategy: + # You can use PyPy versions in python-version. + # For example, pypy-2.7 and pypy-3.8 + matrix: + python-version: ["3.9"] + poetry-version: [1.5.1] + steps: + - uses: actions/checkout@v3 + - name: Set up Python ${{ matrix.python-version }} + uses: actions/setup-python@v4 + with: + python-version: ${{ matrix.python-version }} + - name: Run image + uses: abatilo/actions-poetry@v2.0.0 + with: + poetry-version: ${{ matrix.poetry-version }} + - name: Install deps + run: | + poetry install --with dev + - name: Run Linting + run: poetry run make lint diff --git "a/1_\360\237\217\240_Home.py" "b/1_\360\237\217\240_Home.py" index 2a801e5..672413c 100644 --- "a/1_\360\237\217\240_Home.py" +++ "b/1_\360\237\217\240_Home.py" @@ -3,6 +3,11 @@ from agent_utils import ( load_meta_agent_and_tools, + load_agent_ids_from_directory, +) +from st_utils import add_sidebar +from constants import ( + AGENT_CACHE_DIR, ) @@ -11,23 +16,41 @@ #################### -st.set_page_config(page_title="Build a RAGs bot, powered by LlamaIndex", page_icon="🦙", layout="centered", initial_sidebar_state="auto", menu_items=None) +st.set_page_config( + page_title="Build a RAGs bot, powered by LlamaIndex", + page_icon="🦙", + layout="centered", + initial_sidebar_state="auto", + menu_items=None, +) st.title("Build a RAGs bot, powered by LlamaIndex 💬🦙") st.info( "Use this page to build your RAG bot over your data! 
" - "Once the agent is finished creating, check out the `RAG Config` and `Generated RAG Agent` pages.", - icon="ℹ️" + "Once the agent is finished creating, check out the `RAG Config` and " + "`Generated RAG Agent` pages.\n" + "To build a new agent, please make sure that 'Create a new agent' is selected.", + icon="ℹ️", ) -# TODO: noodle on this -# with st.sidebar: -# openai_api_key_st = st.text_input("OpenAI API Key (optional, not needed if you filled in secrets.toml)", value="", type="password") -# if st.button("Save"): -# # save api key -# st.session_state.openai_api_key = openai_api_key_st -#### load builder agent and its tool spec (the agent_builder) -builder_agent, agent_builder = load_meta_agent_and_tools() +add_sidebar() + + +if ( + "selected_cache" in st.session_state.keys() + and st.session_state.selected_cache is not None +): + # create builder agent / tools from selected cache + builder_agent, agent_builder = load_meta_agent_and_tools( + cache=st.session_state.selected_cache + ) +else: + # create builder agent / tools from new cache + builder_agent, agent_builder = load_meta_agent_and_tools() + + +st.info(f"Currently building/editing agent: {agent_builder.cache.agent_id}", icon="ℹ️") + if "builder_agent" not in st.session_state.keys(): st.session_state.builder_agent = builder_agent @@ -36,27 +59,34 @@ # add pills selected = pills( - "Outline your task!", - ["I want to analyze this PDF file (data/invoices.pdf)", - "I want to search over my CSV documents." - ], clearable=True, index=None + "Outline your task!", + [ + "I want to analyze this PDF file (data/invoices.pdf)", + "I want to search over my CSV documents.", + ], + clearable=True, + index=None, ) -if "messages" not in st.session_state.keys(): # Initialize the chat messages history +if "messages" not in st.session_state.keys(): # Initialize the chat messages history st.session_state.messages = [ {"role": "assistant", "content": "What RAG bot do you want to build?"} ] -def add_to_message_history(role, content): + +def add_to_message_history(role: str, content: str) -> None: message = {"role": role, "content": str(content)} - st.session_state.messages.append(message) # Add response to message history + st.session_state.messages.append(message) # Add response to message history + -for message in st.session_state.messages: # Display the prior chat messages +for message in st.session_state.messages: # Display the prior chat messages with st.chat_message(message["role"]): st.write(message["content"]) # handle user input -if prompt := st.chat_input("Your question"): # Prompt for user input and save to chat history +if prompt := st.chat_input( + "Your question" +): # Prompt for user input and save to chat history add_to_message_history("user", prompt) with st.chat_message("user"): st.write(prompt) @@ -67,9 +97,18 @@ def add_to_message_history(role, content): with st.spinner("Thinking..."): response = st.session_state.builder_agent.chat(prompt) st.write(str(response)) - add_to_message_history("assistant", response) + add_to_message_history("assistant", str(response)) + + # check agent_ids again, if it doesn't match, add to directory and refresh + agent_ids = load_agent_ids_from_directory(str(AGENT_CACHE_DIR)) + # check diff between agent_ids and cur agent ids + diff_ids = list(set(agent_ids) - set(st.session_state.cur_agent_ids)) + if len(diff_ids) > 0: + # clear streamlit cache, to allow you to generate a new agent + st.cache_resource.clear() + + # trigger refresh + st.rerun() -# # check cache 
-print(st.session_state.agent_builder.cache) -# if "agent" in cache: -# st.session_state.agent = cache["agent"] +else: + pass diff --git a/Makefile b/Makefile new file mode 100644 index 0000000..3e84935 --- /dev/null +++ b/Makefile @@ -0,0 +1,10 @@ +.PHONY: format lint + +format: + black . +lint: + mypy . + black --check . + ruff check . +test: + pytest tests \ No newline at end of file diff --git a/README.md b/README.md index d3df21c..59dcfd5 100644 --- a/README.md +++ b/README.md @@ -11,10 +11,10 @@ This project is inspired by [GPTs](https://openai.com/blog/introducing-gpts), la ## Installation and Setup -Clone this project, go into the `rags` project folder. +Clone this project, go into the `rags` project folder. We recommend creating a virtual env for dependencies (`python3 -m venv .venv`). ``` -pip install -r requirements.txt +poetry install --with dev ``` By default, we use OpenAI for both the builder agent as well as the generated RAG agent. diff --git a/agent_utils.py b/agent_utils.py index a61e409..d65f0a8 100644 --- a/agent_utils.py +++ b/agent_utils.py @@ -3,14 +3,15 @@ from llama_index.llms.utils import resolve_llm from pydantic import BaseModel, Field import os -from llama_index.tools.query_engine import QueryEngineTool from llama_index.agent import OpenAIAgent, ReActAgent from llama_index.agent.react.prompts import REACT_CHAT_SYSTEM_HEADER from llama_index import ( VectorStoreIndex, SummaryIndex, ServiceContext, - Document + StorageContext, + Document, + load_index_from_storage, ) from llama_index.prompts import ChatPromptTemplate from typing import List, cast, Optional @@ -18,28 +19,32 @@ from llama_index.embeddings.utils import resolve_embed_model from llama_index.tools import QueryEngineTool, ToolMetadata, FunctionTool from llama_index.agent.types import BaseAgent +from llama_index.chat_engine.types import BaseChatEngine from llama_index.agent.react.formatter import ReActChatFormatter from llama_index.llms.openai_utils import is_function_calling_model from llama_index.chat_engine import CondensePlusContextChatEngine from builder_config import BUILDER_LLM -from typing import Dict, Tuple, Any +from typing import Dict, Tuple, Any, Callable import streamlit as st from pathlib import Path import json +import uuid +from constants import AGENT_CACHE_DIR +import shutil -def _resolve_llm(llm: str) -> LLM: +def _resolve_llm(llm_str: str) -> LLM: """Resolve LLM.""" # TODO: make this less hardcoded with if-else statements # see if there's a prefix # - if there isn't, assume it's an OpenAI model # - if there is, resolve it - tokens = llm.split(":") + tokens = llm_str.split(":") if len(tokens) == 1: os.environ["OPENAI_API_KEY"] = st.secrets.openai_key - llm = OpenAI(model=llm) + llm: LLM = OpenAI(model=llm_str) elif tokens[0] == "local": - llm = resolve_llm(llm) + llm = resolve_llm(llm_str) elif tokens[0] == "openai": os.environ["OPENAI_API_KEY"] = st.secrets.openai_key llm = OpenAI(model=tokens[1]) @@ -50,7 +55,7 @@ def _resolve_llm(llm: str) -> LLM: os.environ["REPLICATE_API_KEY"] = st.secrets.replicate_key llm = Replicate(model=tokens[1]) else: - raise ValueError(f"LLM {llm} not recognized.") + raise ValueError(f"LLM {llm_str} not recognized.") return llm @@ -63,14 +68,17 @@ def _resolve_llm(llm: str) -> LLM: GEN_SYS_PROMPT_STR = """\ Task information is given below. 
-Given the task, please generate a system prompt for an OpenAI-powered bot to solve this task: +Given the task, please generate a system prompt for an OpenAI-powered bot \ +to solve this task: {task} \ Make sure the system prompt obeys the following requirements: -- Tells the bot to ALWAYS use tools given to solve the task. NEVER give an answer without using a tool. -- Does not reference a specific data source. The data source is implicit in any queries to the bot, - and telling the bot to analyze a specific data source might confuse it given a - user query. +- Tells the bot to ALWAYS use tools given to solve the task. \ +NEVER give an answer without using a tool. +- Does not reference a specific data source. \ +The data source is implicit in any queries to the bot, \ +and telling the bot to analyze a specific data source might confuse it given a \ +user query. """ @@ -85,49 +93,196 @@ def _resolve_llm(llm: str) -> LLM: GEN_SYS_PROMPT_TMPL = ChatPromptTemplate(gen_sys_prompt_messages) +class RAGParams(BaseModel): + """RAG parameters. + + Parameters used to configure a RAG pipeline. + + """ + + include_summarization: bool = Field( + default=False, + description=( + "Whether to include summarization in the RAG pipeline. (only for GPT-4)" + ), + ) + top_k: int = Field( + default=2, description="Number of documents to retrieve from vector store." + ) + chunk_size: int = Field(default=1024, description="Chunk size for vector store.") + embed_model: str = Field( + default="default", description="Embedding model to use (default is OpenAI)" + ) + llm: str = Field( + default="gpt-4-1106-preview", description="LLM to use for summarization." + ) + + +def load_data( + file_names: Optional[List[str]] = None, urls: Optional[List[str]] = None +) -> List[Document]: + """Load data.""" + file_names = file_names or [] + urls = urls or [] + if not file_names and not urls: + raise ValueError("Must specify either file_names or urls.") + elif file_names and urls: + raise ValueError("Must specify only one of file_names or urls.") + elif file_names: + reader = SimpleDirectoryReader(input_files=file_names) + docs = reader.load_data() + elif urls: + from llama_hub.web.simple_web.base import SimpleWebPageReader + + # use simple web page reader from llamahub + loader = SimpleWebPageReader() + docs = loader.load_data(urls=urls) + else: + raise ValueError("Must specify either file_names or urls.") + + return docs + + def load_agent( - tools: List, - llm: LLM, + tools: List, + llm: LLM, system_prompt: str, extra_kwargs: Optional[Dict] = None, - **kwargs: Any -) -> BaseAgent: + **kwargs: Any, +) -> BaseChatEngine: """Load agent.""" extra_kwargs = extra_kwargs or {} if isinstance(llm, OpenAI) and is_function_calling_model(llm.model): # get OpenAI Agent - agent = OpenAIAgent.from_tools( - tools=tools, - llm=llm, - system_prompt=system_prompt, - **kwargs + agent: BaseChatEngine = OpenAIAgent.from_tools( + tools=tools, llm=llm, system_prompt=system_prompt, **kwargs ) else: if "vector_index" not in extra_kwargs: - raise ValueError("Must pass in vector index for CondensePlusContextChatEngine.") + raise ValueError( + "Must pass in vector index for CondensePlusContextChatEngine." + ) vector_index = cast(VectorStoreIndex, extra_kwargs["vector_index"]) rag_params = cast(RAGParams, extra_kwargs["rag_params"]) # use condense + context chat engine agent = CondensePlusContextChatEngine.from_defaults( vector_index.as_retriever(similarity_top_k=rag_params.top_k), ) - + return agent -class RAGParams(BaseModel): - """RAG parameters. 
+def load_meta_agent( + tools: List, + llm: LLM, + system_prompt: str, + extra_kwargs: Optional[Dict] = None, + **kwargs: Any, +) -> BaseAgent: + """Load meta agent. + + TODO: consolidate with load_agent. + + The meta-agent *has* to perform tool-use. - Parameters used to configure a RAG pipeline. - """ - include_summarization: bool = Field(default=False, description="Whether to include summarization in the RAG pipeline. (only for GPT-4)") - top_k: int = Field(default=2, description="Number of documents to retrieve from vector store.") - chunk_size: int = Field(default=1024, description="Chunk size for vector store.") - embed_model: str = Field( - default="default", description="Embedding model to use (default is OpenAI)" + extra_kwargs = extra_kwargs or {} + if isinstance(llm, OpenAI) and is_function_calling_model(llm.model): + # get OpenAI Agent + agent: BaseAgent = OpenAIAgent.from_tools( + tools=tools, llm=llm, system_prompt=system_prompt, **kwargs + ) + else: + agent = ReActAgent.from_tools( + tools=tools, + llm=llm, + react_chat_formatter=ReActChatFormatter( + system_header=system_prompt + "\n" + REACT_CHAT_SYSTEM_HEADER, + ), + **kwargs, + ) + + return agent + + +def construct_agent( + system_prompt: str, + rag_params: RAGParams, + docs: List[Document], + vector_index: Optional[VectorStoreIndex] = None, + additional_tools: Optional[List] = None, +) -> Tuple[BaseChatEngine, Dict]: + """Construct agent from docs / parameters / indices.""" + extra_info = {} + additional_tools = additional_tools or [] + + # first resolve llm and embedding model + embed_model = resolve_embed_model(rag_params.embed_model) + # llm = resolve_llm(rag_params.llm) + # TODO: use OpenAI for now + # llm = OpenAI(model=rag_params.llm) + llm = _resolve_llm(rag_params.llm) + + # first let's index the data with the right parameters + service_context = ServiceContext.from_defaults( + chunk_size=rag_params.chunk_size, + llm=llm, + embed_model=embed_model, ) - llm: str = Field(default="gpt-4-1106-preview", description="LLM to use for summarization.") + + if vector_index is None: + vector_index = VectorStoreIndex.from_documents( + docs, service_context=service_context + ) + else: + pass + + extra_info["vector_index"] = vector_index + + vector_query_engine = vector_index.as_query_engine( + similarity_top_k=rag_params.top_k + ) + all_tools = [] + vector_tool = QueryEngineTool( + query_engine=vector_query_engine, + metadata=ToolMetadata( + name="vector_tool", + description=("Use this tool to answer any user question over any data."), + ), + ) + all_tools.append(vector_tool) + if rag_params.include_summarization: + summary_index = SummaryIndex.from_documents( + docs, service_context=service_context + ) + summary_query_engine = summary_index.as_query_engine() + summary_tool = QueryEngineTool( + query_engine=summary_query_engine, + metadata=ToolMetadata( + name="summary_tool", + description=( + "Use this tool for any user questions that ask " + "for a summarization of content" + ), + ), + ) + all_tools.append(summary_tool) + + # then we add tools + all_tools.extend(additional_tools) + + # build agent + if system_prompt is None: + return "System prompt not set yet. Please set system prompt first." 
+ + agent = load_agent( + all_tools, + llm=llm, + system_prompt=system_prompt, + verbose=True, + extra_kwargs={"vector_index": vector_index, "rag_params": rag_params}, + ) + return agent, extra_info class ParamCache(BaseModel): @@ -135,20 +290,163 @@ class ParamCache(BaseModel): Created a wrapper class around a dict in case we wanted to more explicitly type different items in the cache. - + """ # arbitrary types class Config: arbitrary_types_allowed = True - system_prompt: Optional[str] = Field(default=None, description="System prompt for RAG agent.") - file_paths: List[str] = Field(default_factory=list, description="File paths for RAG agent.") - docs: List[Document] = Field(default_factory=list, description="Documents for RAG agent.") - tools: List = Field(default_factory=list, description="Additional tools for RAG agent (e.g. web)") - rag_params: RAGParams = Field(default_factory=RAGParams, description="RAG parameters for RAG agent.") - agent: Optional[OpenAIAgent] = Field(default=None, description="RAG agent.") + # system prompt + system_prompt: Optional[str] = Field( + default=None, description="System prompt for RAG agent." + ) + # data + file_names: List[str] = Field( + default_factory=list, description="File names as data source (if specified)" + ) + urls: List[str] = Field( + default_factory=list, description="URLs as data source (if specified)" + ) + docs: List = Field(default_factory=list, description="Documents for RAG agent.") + # tools + tools: List = Field( + default_factory=list, description="Additional tools for RAG agent (e.g. web)" + ) + # RAG params + rag_params: RAGParams = Field( + default_factory=RAGParams, description="RAG parameters for RAG agent." + ) + + # agent params + vector_index: Optional[VectorStoreIndex] = Field( + default=None, description="Vector index for RAG agent." 
+ ) + agent_id: str = Field( + default_factory=lambda: f"Agent_{str(uuid.uuid4())}", + description="Agent ID for RAG agent.", + ) + agent: Optional[BaseChatEngine] = Field(default=None, description="RAG agent.") + + def save_to_disk(self, save_dir: str) -> None: + """Save cache to disk.""" + # NOTE: more complex than just calling dict() because we want to + # only store serializable fields and be space-efficient + + dict_to_serialize = { + "system_prompt": self.system_prompt, + "file_names": self.file_names, + "urls": self.urls, + # TODO: figure out tools + # "tools": [], + "rag_params": self.rag_params.dict(), + "agent_id": self.agent_id, + } + # store the vector store within the agent + if self.vector_index is None: + raise ValueError("Must specify vector index in order to save.") + self.vector_index.storage_context.persist(Path(save_dir) / "storage") + + # if save_path directories don't exist, create it + if not Path(save_dir).exists(): + Path(save_dir).mkdir(parents=True) + with open(Path(save_dir) / "cache.json", "w") as f: + json.dump(dict_to_serialize, f) + + @classmethod + def load_from_disk( + cls, + save_dir: str, + ) -> "ParamCache": + """Load cache from disk.""" + storage_context = StorageContext.from_defaults( + persist_dir=str(Path(save_dir) / "storage") + ) + vector_index = cast(VectorStoreIndex, load_index_from_storage(storage_context)) + + with open(Path(save_dir) / "cache.json", "r") as f: + cache_dict = json.load(f) + # replace rag params with RAGParams object + cache_dict["rag_params"] = RAGParams(**cache_dict["rag_params"]) + + # add in the missing fields + # load docs + cache_dict["docs"] = load_data( + file_names=cache_dict["file_names"], urls=cache_dict["urls"] + ) + # load agent from index + agent, _ = construct_agent( + cache_dict["system_prompt"], + cache_dict["rag_params"], + cache_dict["docs"], + vector_index=vector_index, + # TODO: figure out tools + ) + cache_dict["vector_index"] = vector_index + cache_dict["agent"] = agent + + return cls(**cache_dict) + + +def add_agent_id_to_directory(dir: str, agent_id: str) -> None: + """Save agent id to directory.""" + full_path = Path(dir) / "agent_ids.json" + if not full_path.exists(): + with open(full_path, "w") as f: + json.dump({"agent_ids": [agent_id]}, f) + else: + with open(full_path, "r") as f: + agent_ids = json.load(f)["agent_ids"] + if agent_id in agent_ids: + raise ValueError(f"Agent id {agent_id} already exists.") + agent_ids_set = set(agent_ids) + agent_ids_set.add(agent_id) + with open(full_path, "w") as f: + json.dump({"agent_ids": list(agent_ids_set)}, f) + + +def load_agent_ids_from_directory(dir: str) -> List[str]: + """Load agent ids file.""" + full_path = Path(dir) / "agent_ids.json" + if not full_path.exists(): + return [] + with open(full_path, "r") as f: + agent_ids = json.load(f)["agent_ids"] + + return agent_ids + + +def load_cache_from_directory( + dir: str, + agent_id: str, +) -> ParamCache: + """Load cache from directory.""" + full_path = Path(dir) / f"{agent_id}" + if not full_path.exists(): + raise ValueError(f"Cache for agent {agent_id} does not exist.") + cache = ParamCache.load_from_disk(str(full_path)) + return cache + + +def remove_agent_from_directory( + dir: str, + agent_id: str, +) -> None: + """Remove agent from directory.""" + + # modify / resave agent_ids + agent_ids = load_agent_ids_from_directory(dir) + new_agent_ids = [id for id in agent_ids if id != agent_id] + full_path = Path(dir) / "agent_ids.json" + with open(full_path, "w") as f: + json.dump({"agent_ids": 
new_agent_ids}, f) + + # remove agent cache + full_path = Path(dir) / f"{agent_id}" + if full_path.exists(): + # recursive delete + shutil.rmtree(full_path) class RAGAgentBuilder: @@ -161,11 +459,15 @@ class RAGAgentBuilder: - setting parameters (e.g. top-k) Must pass in a cache. This cache will be modified as the agent is built. - + """ - def __init__(self, cache: Optional[ParamCache] = None) -> None: + + def __init__( + self, cache: Optional[ParamCache] = None, cache_dir: Optional[str] = None + ) -> None: """Init params.""" self._cache = cache or ParamCache() + self._cache_dir = cache_dir or AGENT_CACHE_DIR @property def cache(self) -> ParamCache: @@ -181,11 +483,8 @@ def create_system_prompt(self, task: str) -> str: return f"System prompt created: {response.message.content}" - def load_data( - self, - file_names: Optional[List[str]] = None, - urls: Optional[List[str]] = None + self, file_names: Optional[List[str]] = None, urls: Optional[List[str]] = None ) -> str: """Load data for a given task. @@ -196,32 +495,18 @@ def load_data( Defaults to None. urls (Optional[List[str]]): List of urls to load. Defaults to None. - + """ - if file_names is None and urls is None: - raise ValueError("Must specify either file_names or urls.") - elif file_names is not None and urls is not None: - raise ValueError("Must specify only one of file_names or urls.") - elif file_names is not None: - reader = SimpleDirectoryReader(input_files=file_names) - docs = reader.load_data() - file_paths = file_names - elif urls is not None: - from llama_hub.web.simple_web.base import SimpleWebPageReader - # use simple web page reader from llamahub - loader = SimpleWebPageReader() - docs = loader.load_data(urls=urls) - file_paths = urls - else: - raise ValueError("Must specify either file_names or urls.") - + file_names = file_names or [] + urls = urls or [] + docs = load_data(file_names=file_names, urls=urls) self._cache.docs = docs - self._cache.file_paths = file_paths + self._cache.file_names = file_names + self._cache.urls = urls return "Data loaded successfully." - # NOTE: unused - def add_web_tool(self) -> None: + def add_web_tool(self) -> str: """Add a web tool to enable agent to solve a task.""" # TODO: make this not hardcoded to a web tool # Set up Metaphor tool @@ -241,21 +526,20 @@ def get_rag_params(self) -> Dict: Should be called before `set_rag_params` so that the agent is aware of the schema. - + """ rag_params = self._cache.rag_params return rag_params.dict() - - def set_rag_params(self, **rag_params: Dict): + def set_rag_params(self, **rag_params: Dict) -> str: """Set RAG parameters. These parameters will then be used to actually initialize the agent. Should call `get_rag_params` first to get the schema of the input dictionary. Args: - **rag_params (Dict): dictionary of RAG parameters. - + **rag_params (Dict): dictionary of RAG parameters. + """ new_dict = self._cache.rag_params.dict() new_dict.update(rag_params) @@ -263,69 +547,85 @@ def set_rag_params(self, **rag_params: Dict): self._cache.rag_params = rag_params_obj return "RAG parameters set successfully." - - def create_agent(self) -> None: + def create_agent(self, agent_id: Optional[str] = None) -> str: """Create an agent. There are no parameters for this function because all the functions should have already been called to set up the agent. 
- - """ - rag_params = cast(RAGParams, self._cache.rag_params) - docs = self._cache.docs - - # first resolve llm and embedding model - embed_model = resolve_embed_model(rag_params.embed_model) - # llm = resolve_llm(rag_params.llm) - # TODO: use OpenAI for now - # llm = OpenAI(model=rag_params.llm) - llm = _resolve_llm(rag_params.llm) - - # first let's index the data with the right parameters - service_context = ServiceContext.from_defaults( - chunk_size=rag_params.chunk_size, - llm=llm, - embed_model=embed_model, - ) - vector_index = VectorStoreIndex.from_documents(docs, service_context=service_context) - vector_query_engine = vector_index.as_query_engine(similarity_top_k=rag_params.top_k) - all_tools = [] - vector_tool = QueryEngineTool( - query_engine=vector_query_engine, - metadata=ToolMetadata( - name="vector_tool", - description=("Use this tool to answer any user question over any data."), - ), - ) - all_tools.append(vector_tool) - if rag_params.include_summarization: - summary_index = SummaryIndex.from_documents(docs, service_context=service_context) - summary_query_engine = summary_index.as_query_engine() - summary_tool = QueryEngineTool( - query_engine=summary_query_engine, - metadata=ToolMetadata( - name="summary_tool", - description=("Use this tool for any user questions that ask for a summarization of content"), - ), - ) - all_tools.append(summary_tool) - - - # then we add tools - all_tools.extend(self._cache.tools) - # build agent + """ if self._cache.system_prompt is None: - return "System prompt not set yet. Please set system prompt first." + raise ValueError("Must set system prompt before creating agent.") - agent = load_agent( - all_tools, llm=llm, system_prompt=self._cache.system_prompt, verbose=True, - extra_kwargs={"vector_index": vector_index, "rag_params": rag_params} + agent, extra_info = construct_agent( + cast(str, self._cache.system_prompt), + cast(RAGParams, self._cache.rag_params), + self._cache.docs, + additional_tools=self._cache.tools, ) + # if agent_id not specified, randomly generate one + agent_id = agent_id or self._cache.agent_id or f"Agent_{str(uuid.uuid4())}" + self._cache.vector_index = extra_info["vector_index"] + self._cache.agent_id = agent_id self._cache.agent = agent + + # save the cache to disk + agent_cache_path = f"{self._cache_dir}/{agent_id}" + self._cache.save_to_disk(agent_cache_path) + # save to agent ids + add_agent_id_to_directory(str(self._cache_dir), agent_id) + return "Agent created successfully." + def update_agent( + self, + agent_id: str, + system_prompt: Optional[str] = None, + include_summarization: Optional[bool] = None, + top_k: Optional[int] = None, + chunk_size: Optional[int] = None, + embed_model: Optional[str] = None, + llm: Optional[str] = None, + ) -> None: + """Update agent. + + Delete old agent by ID and create a new one. + Optionally update the system prompt and RAG parameters. + + NOTE: Currently is manually called, not meant for agent use. 
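+
+        Example (hypothetical id and values):
+            builder.update_agent("Agent_invoices", top_k=4, chunk_size=512)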
+ + """ + # remove saved agent from directory, since we'll be re-saving + remove_agent_from_directory(str(AGENT_CACHE_DIR), self.cache.agent_id) + + # set agent id + self.cache.agent_id = agent_id + + # set system prompt + if system_prompt is not None: + self.cache.system_prompt = system_prompt + # get agent_builder + # We call set_rag_params and create_agent, which will + # update the cache + # TODO: decouple functions from tool functions exposed to the agent + + rag_params_dict: Dict[str, Any] = {} + if include_summarization is not None: + rag_params_dict["include_summarization"] = include_summarization + if top_k is not None: + rag_params_dict["top_k"] = top_k + if chunk_size is not None: + rag_params_dict["chunk_size"] = chunk_size + if embed_model is not None: + rag_params_dict["embed_model"] = embed_model + if llm is not None: + rag_params_dict["llm"] = llm + + self.set_rag_params(**rag_params_dict) + # this will update the agent in the cache + self.create_agent() + #################### #### META Agent #### @@ -339,6 +639,7 @@ def create_agent(self) -> None: 2) Load in user-specified data (based on file paths they specify). 3) Decide whether or not to add additional tools. 4) Set parameters for the RAG pipeline. +5) Build the agent This will be a back and forth conversation with the user. You should continue asking users if there's anything else they want to do until @@ -355,25 +656,26 @@ def create_agent(self) -> None: # define agent -@st.cache_resource -def load_meta_agent_and_tools() -> Tuple[OpenAIAgent, RAGAgentBuilder]: +# @st.cache_resource +def load_meta_agent_and_tools( + cache: Optional[ParamCache] = None, +) -> Tuple[BaseAgent, RAGAgentBuilder]: # think of this as tools for the agent to use - agent_builder = RAGAgentBuilder() + agent_builder = RAGAgentBuilder(cache) - fns = [ - agent_builder.create_system_prompt, - agent_builder.load_data, + fns: List[Callable] = [ + agent_builder.create_system_prompt, + agent_builder.load_data, # add_web_tool, agent_builder.get_rag_params, agent_builder.set_rag_params, - agent_builder.create_agent + agent_builder.create_agent, ] fn_tools = [FunctionTool.from_defaults(fn=fn) for fn in fns] - builder_agent = load_agent( + builder_agent = load_meta_agent( fn_tools, llm=BUILDER_LLM, system_prompt=RAG_BUILDER_SYS_STR, verbose=True ) return builder_agent, agent_builder - \ No newline at end of file diff --git a/builder_config.py b/builder_config.py index e383890..07bd987 100644 --- a/builder_config.py +++ b/builder_config.py @@ -7,6 +7,7 @@ ## OpenAI from llama_index.llms import OpenAI + # set OpenAI Key - use Streamlit secrets os.environ["OPENAI_API_KEY"] = st.secrets.openai_key # load LLM @@ -16,4 +17,4 @@ # from llama_index.llms import Anthropic # # set Anthropic key # os.environ["ANTHROPIC_API_KEY"] = st.secrets.anthropic_key -# BUILDER_LLM = Anthropic() \ No newline at end of file +# BUILDER_LLM = Anthropic() diff --git a/constants.py b/constants.py new file mode 100644 index 0000000..a6d4599 --- /dev/null +++ b/constants.py @@ -0,0 +1,4 @@ +from pathlib import Path + +AGENT_CACHE_DIR = Path(__file__).parent / "cache" / "agents" +MESSAGES_CACHE_DIR = Path(__file__).parent / "cache" / "messages" diff --git "a/pages/2_\342\232\231\357\270\217_RAG_Config.py" "b/pages/2_\342\232\231\357\270\217_RAG_Config.py" index 9550c26..e2e8d70 100644 --- "a/pages/2_\342\232\231\357\270\217_RAG_Config.py" +++ "b/pages/2_\342\232\231\357\270\217_RAG_Config.py" @@ -1,13 +1,16 @@ """Streamlit page showing builder config.""" import streamlit as st -import 
openai -from streamlit_pills import pills -from typing import cast +from typing import cast, Optional from agent_utils import ( RAGParams, RAGAgentBuilder, + ParamCache, + remove_agent_from_directory, ) +from st_utils import update_selected_agent_with_id +from constants import AGENT_CACHE_DIR +from st_utils import add_sidebar #################### @@ -15,53 +18,116 @@ #################### +def update_agent() -> None: + """Update agent.""" + if ( + "config_agent_builder" in st.session_state.keys() + and st.session_state.config_agent_builder is not None + ): + agent_builder = cast(RAGAgentBuilder, st.session_state.config_agent_builder) + ### Update the agent + agent_builder.update_agent( + st.session_state.agent_id_st, + system_prompt=st.session_state.sys_prompt_st, + include_summarization=st.session_state.include_summarization_st, + top_k=st.session_state.top_k_st, + chunk_size=st.session_state.chunk_size_st, + embed_model=st.session_state.embed_model_st, + llm=st.session_state.llm_st, + ) -st.set_page_config(page_title="RAG Pipeline Config", page_icon="🦙", layout="centered", initial_sidebar_state="auto", menu_items=None) -st.title("RAG Pipeline Config") -st.info( - "This is generated by the builder in the above section.", icon="ℹ️" + # Update Radio Buttons: update selected agent to the new id + update_selected_agent_with_id(agent_builder.cache.agent_id) + else: + raise ValueError("Agent builder is None. Cannot update agent.") + + +def delete_agent() -> None: + """Delete agent.""" + if ( + "config_agent_builder" in st.session_state.keys() + and st.session_state.config_agent_builder is not None + ): + agent_builder = cast(RAGAgentBuilder, st.session_state.config_agent_builder) + ### Delete agent + # remove saved agent from directory + remove_agent_from_directory(str(AGENT_CACHE_DIR), agent_builder.cache.agent_id) + # Update Radio Buttons: update selected agent to the new id + update_selected_agent_with_id(None) + else: + raise ValueError("Agent builder is None. 
Cannot delete agent.") + + +st.set_page_config( + page_title="RAG Pipeline Config", + page_icon="🦙", + layout="centered", + initial_sidebar_state="auto", + menu_items=None, ) +st.title("RAG Pipeline Config") +add_sidebar() -if "agent_builder" in st.session_state.keys(): +# first, pick the cache: this is preloaded from an existing agent, +# or is part of the current one being created +if ( + "selected_cache" in st.session_state.keys() + and st.session_state.selected_cache is not None +): + cache = cast(ParamCache, st.session_state.selected_cache) + agent_builder: Optional[RAGAgentBuilder] = RAGAgentBuilder(cache) +elif "agent_builder" in st.session_state.keys(): agent_builder = cast(RAGAgentBuilder, st.session_state.agent_builder) +else: + agent_builder = None + +# set as session state +st.session_state.config_agent_builder = agent_builder + +if agent_builder is not None: + + st.info(f"Viewing config for agent: {agent_builder.cache.agent_id}", icon="ℹ️") + + agent_id_st = st.text_input( + "Agent ID", value=agent_builder.cache.agent_id, key="agent_id_st" + ) + if agent_builder.cache.system_prompt is None: system_prompt = "" else: system_prompt = agent_builder.cache.system_prompt - sys_prompt_st = st.text_area("System Prompt", value=system_prompt) + sys_prompt_st = st.text_area( + "System Prompt", value=system_prompt, key="sys_prompt_st" + ) rag_params = cast(RAGParams, agent_builder.cache.rag_params) - file_paths = st.text_input( - "File/URL paths (not editable)", - value=",".join(agent_builder.cache.file_paths), - disabled=True + file_names = st.text_input( + "File names (not editable)", + value=",".join(agent_builder.cache.file_names), + disabled=True, ) - include_summarization_st = st.checkbox("Include Summarization (only works for GPT-4)", value=rag_params.include_summarization) - top_k_st = st.number_input("Top K", value=rag_params.top_k) - chunk_size_st = st.number_input("Chunk Size", value=rag_params.chunk_size) - embed_model_st = st.text_input("Embed Model", value=rag_params.embed_model) - llm_st = st.text_input("LLM", value=rag_params.llm) + urls = st.text_input( + "URLs (not editable)", value=",".join(agent_builder.cache.urls), disabled=True + ) + include_summarization_st = st.checkbox( + "Include Summarization (only works for GPT-4)", + value=rag_params.include_summarization, + key="include_summarization_st", + ) + top_k_st = st.number_input("Top K", value=rag_params.top_k, key="top_k_st") + chunk_size_st = st.number_input( + "Chunk Size", value=rag_params.chunk_size, key="chunk_size_st" + ) + embed_model_st = st.text_input( + "Embed Model", value=rag_params.embed_model, key="embed_model_st" + ) + llm_st = st.text_input("LLM", value=rag_params.llm, key="llm_st") if agent_builder.cache.agent is not None: - if st.button("Update Agent"): - # update the agent - agent_builder.cache.system_prompt = sys_prompt_st - # get agent_builder - # We call set_rag_params and create_agent, which will - # update the cache - agent_builder = cast(RAGAgentBuilder, st.session_state.agent_builder) - - # TODO: decouple functions from tool functions exposed to the agent - agent_builder.set_rag_params( - include_summarization=include_summarization_st, - top_k=top_k_st, - chunk_size=chunk_size_st, - embed_model=embed_model_st, - llm=llm_st, - ) - # this will update the agent in the cache - agent_builder.create_agent() + st.button("Update Agent", on_click=update_agent) + st.button(":red[Delete Agent]", on_click=delete_agent) else: # show text saying "agent not created" st.info("Agent not created. 
Please create an agent in the above section.") + else: - st.info("agent builder not created yet. Please describe your task in the above section.") + st.info("No agent builder found. Please create an agent in the above section.") diff --git "a/pages/3_\360\237\244\226_Generated_RAG_Agent.py" "b/pages/3_\360\237\244\226_Generated_RAG_Agent.py" index 7e98e99..1afbdbd 100644 --- "a/pages/3_\360\237\244\226_Generated_RAG_Agent.py" +++ "b/pages/3_\360\237\244\226_Generated_RAG_Agent.py" @@ -1,11 +1,8 @@ """Streamlit page showing builder config.""" import streamlit as st -from typing import cast -from agent_utils import ( - RAGAgentBuilder, -) - -from streamlit_pills import pills +from typing import cast, Optional +from agent_utils import RAGAgentBuilder, ParamCache +from st_utils import add_sidebar #################### @@ -13,46 +10,67 @@ #################### - -st.set_page_config(page_title="Generated RAG Agent", page_icon="🦙", layout="centered", initial_sidebar_state="auto", menu_items=None) -st.title("Generated RAG Agent") -st.info( - "This is generated by the builder in the above section.", icon="ℹ️" +st.set_page_config( + page_title="Generated RAG Agent", + page_icon="🦙", + layout="centered", + initial_sidebar_state="auto", + menu_items=None, ) +st.title("Generated RAG Agent") + +add_sidebar() -if "agent_messages" not in st.session_state.keys(): # Initialize the chat messages history +if ( + "agent_messages" not in st.session_state.keys() +): # Initialize the chat messages history st.session_state.agent_messages = [ {"role": "assistant", "content": "Ask me a question!"} ] -def add_to_message_history(role, content): + +def add_to_message_history(role: str, content: str) -> None: message = {"role": role, "content": str(content)} - st.session_state.agent_messages.append(message) # Add response to message history + st.session_state.agent_messages.append(message) # Add response to message history +# first, pick the cache: this is preloaded from an existing agent, +# or is part of the current one being created agent = None -if "agent_builder" in st.session_state.keys(): +if ( + "selected_cache" in st.session_state.keys() + and st.session_state.selected_cache is not None +): + cache: Optional[ParamCache] = cast(ParamCache, st.session_state.selected_cache) +elif "agent_builder" in st.session_state.keys(): agent_builder = cast(RAGAgentBuilder, st.session_state.agent_builder) - if agent_builder.cache.agent is not None: - agent = agent_builder.cache.agent - for message in st.session_state.agent_messages: # Display the prior chat messages - with st.chat_message(message["role"]): - st.write(message["content"]) - - # don't process selected for now - if prompt := st.chat_input("Your question"): # Prompt for user input and save to chat history - add_to_message_history("user", prompt) - with st.chat_message("user"): - st.write(prompt) - - # If last message is not from assistant, generate a new response - if st.session_state.agent_messages[-1]["role"] != "assistant": - with st.chat_message("assistant"): - with st.spinner("Thinking..."): - response = agent.chat(prompt) - st.write(str(response)) - add_to_message_history("assistant", response) - else: - st.info("Agent not created. Please create an agent in the above section.") + cache = agent_builder.cache +else: + cache = None + st.info("Agent not created. 
Please create an agent in the above section.") + +# if agent is created, then we can chat with it +if cache is not None and cache.agent is not None: + st.info(f"Viewing config for agent: {cache.agent_id}", icon="ℹ️") + agent = cache.agent + for message in st.session_state.agent_messages: # Display the prior chat messages + with st.chat_message(message["role"]): + st.write(message["content"]) + + # don't process selected for now + if prompt := st.chat_input( + "Your question" + ): # Prompt for user input and save to chat history + add_to_message_history("user", prompt) + with st.chat_message("user"): + st.write(prompt) + + # If last message is not from assistant, generate a new response + if st.session_state.agent_messages[-1]["role"] != "assistant": + with st.chat_message("assistant"): + with st.spinner("Thinking..."): + response = agent.chat(str(prompt)) + st.write(str(response)) + add_to_message_history("assistant", str(response)) else: st.info("Agent not created. Please create an agent in the above section.") diff --git a/pyproject.toml b/pyproject.toml new file mode 100644 index 0000000..6870398 --- /dev/null +++ b/pyproject.toml @@ -0,0 +1,65 @@ +[tool.poetry] +name = "rags" +version = "0.0.2" +description = "Build RAG with natural language." +authors = ["Jerry Liu"] +# New attributes +license = "MIT" +readme = "README.md" +homepage = "https://docs.llamaindex.ai/en/latest/" +repository = "https://github.com/run-llama/rags" +keywords = ["llama-index", "rags"] +include = [ + "LICENSE", +] + +[tool.poetry.dependencies] +python = ">=3.8.1,<3.12,!=3.9.7" +streamlit = "1.28.0" +streamlit-pills = "0.3.0" +llama-index = "0.9.7" +llama-hub = "0.0.44" +# NOTE: this is due to a trivial dependency in the web tool, will refactor +langchain = "0.0.305" +pypdf = "3.17.1" + +[tool.poetry.dev-dependencies] +# pytest = "7.2.1" +# pytest-dotenv = "0.5.2" +# pytest_httpserver = "1.0.8" +# pytest-mock = "3.11.1" +typing-inspect = "0.8.0" +typing_extensions = "^4.5.0" +types-requests = "2.28.11.8" +black = "22.12.0" +isort = "5.11.4" +pytest-asyncio = "^0.21.1" +ruff = "0.0.285" +mypy = "0.991" + +[build-system] +requires = ["poetry>=0.12", "poetry-core>=1.0.0"] +build-backend = "poetry.masonry.api" + +[tool.mypy] +disallow_untyped_defs = true +ignore_missing_imports = true +exclude = ["notebooks", "build", "examples"] + +[tool.ruff] +# Allow lines to be as long as 80 characters. +# TODO: it should be removed, but we need to fix the entire code first. 
+line-length = 88 +exclude = [ + ".venv", + "__pycache__", + ".ipynb_checkpoints", + ".mypy_cache", + ".ruff_cache", + "examples", + "notebooks", + ".git" +] + +[tool.ruff.per-file-ignores] +"base.py" = ["E402", "F811", "E501"] diff --git a/st_utils.py b/st_utils.py new file mode 100644 index 0000000..e7f4cd7 --- /dev/null +++ b/st_utils.py @@ -0,0 +1,58 @@ +"""Streamlit utils.""" +from agent_utils import ( + load_agent_ids_from_directory, + load_cache_from_directory, +) +from constants import ( + AGENT_CACHE_DIR, +) +from typing import Optional + +import streamlit as st + + +def update_selected_agent_with_id(selected_id: Optional[str] = None) -> None: + """Update selected agent with id.""" + # set session state + st.session_state.selected_id = ( + selected_id if selected_id != "Create a new agent" else None + ) + if st.session_state.selected_id is None: + st.session_state.selected_cache = None + else: + # load agent from directory + agent_cache = load_cache_from_directory( + str(AGENT_CACHE_DIR), st.session_state.selected_id + ) + st.session_state.selected_cache = agent_cache + + +## handler for sidebar specifically +def update_selected_agent() -> None: + """Update selected agent.""" + selected_id = st.session_state.agent_selector + + update_selected_agent_with_id(selected_id) + + +def add_sidebar() -> None: + """Add sidebar.""" + with st.sidebar: + st.session_state.cur_agent_ids = load_agent_ids_from_directory( + str(AGENT_CACHE_DIR) + ) + choices = ["Create a new agent"] + st.session_state.cur_agent_ids + + # by default, set index to 0. if value is in selected_id, set index to that + index = 0 + if "selected_id" in st.session_state.keys(): + if st.session_state.selected_id is not None: + index = choices.index(st.session_state.selected_id) + # display buttons + st.radio( + "Agents", + choices, + index=index, + on_change=update_selected_agent, + key="agent_selector", + ) diff --git a/tests/__init__.py b/tests/__init__.py new file mode 100644 index 0000000..e69de29
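
As a sketch of how the pieces introduced in this patch fit together, the new cache helpers in `agent_utils.py` and `constants.py` can be exercised outside Streamlit roughly as follows. This is only a sketch: the invoices path is the example used by the Home page pills, the agent id is hypothetical, and an OpenAI key in `.streamlit/secrets.toml` is assumed, as the app itself requires.

```python
# Minimal sketch of the agent-cache round trip added in this patch.
# Assumes .streamlit/secrets.toml provides openai_key, as the app requires.
from agent_utils import (
    RAGAgentBuilder,
    load_agent_ids_from_directory,
    load_cache_from_directory,
)
from constants import AGENT_CACHE_DIR

# Build an agent through the same tool functions the builder agent exposes.
builder = RAGAgentBuilder()
builder.create_system_prompt("answer questions over an invoices PDF")
builder.load_data(file_names=["data/invoices.pdf"])  # example path from the Home page
builder.set_rag_params(top_k=4, chunk_size=512)
builder.create_agent(agent_id="Agent_invoices")  # persists under cache/agents/

# Later (or in a fresh session), reload the agent by id and chat with it.
print(load_agent_ids_from_directory(str(AGENT_CACHE_DIR)))
cache = load_cache_from_directory(str(AGENT_CACHE_DIR), "Agent_invoices")
if cache.agent is not None:
    print(str(cache.agent.chat("What is the total amount across all invoices?")))
```

The multi-page app drives this same flow: `add_sidebar` in `st_utils.py` lists the ids returned by `load_agent_ids_from_directory`, and selecting one loads its `ParamCache` into `st.session_state.selected_cache` for the RAG Config and Generated RAG Agent pages.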