hotfix non-openai LLMs (run-llama#20)
jerryjliu authored Nov 23, 2023
1 parent 0221f79 commit cbf4ad9
Showing 2 changed files with 15 additions and 10 deletions.
23 changes: 14 additions & 9 deletions agent_utils.py
```diff
@@ -20,6 +20,7 @@
 from llama_index.agent.types import BaseAgent
 from llama_index.agent.react.formatter import ReActChatFormatter
 from llama_index.llms.openai_utils import is_function_calling_model
+from llama_index.chat_engine import CondensePlusContextChatEngine
 from builder_config import BUILDER_LLM
 from typing import Dict, Tuple, Any
 import streamlit as st
```
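
The newly imported CondensePlusContextChatEngine condenses the chat history plus the latest user message into a standalone question, retrieves context for it, and answers from that context. A minimal standalone sketch of using it directly, mirroring the `from_defaults` call this commit adds (the index construction here is assumed setup, not part of the commit):

```python
from llama_index import SimpleDirectoryReader, VectorStoreIndex
from llama_index.chat_engine import CondensePlusContextChatEngine

# Assumed setup: build a vector index over local documents.
index = VectorStoreIndex.from_documents(
    SimpleDirectoryReader("data").load_data()
)

# Same construction pattern as the hotfix: pass a retriever to from_defaults.
chat_engine = CondensePlusContextChatEngine.from_defaults(
    index.as_retriever(similarity_top_k=2),
)
print(chat_engine.chat("What are these documents about?"))
```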
```diff
@@ -88,9 +89,11 @@ def load_agent(
     tools: List,
     llm: LLM,
     system_prompt: str,
+    extra_kwargs: Optional[Dict] = None,
     **kwargs: Any
 ) -> BaseAgent:
     """Load agent."""
+    extra_kwargs = extra_kwargs or {}
     if isinstance(llm, OpenAI) and is_function_calling_model(llm.model):
         # get OpenAI Agent
         agent = OpenAIAgent.from_tools(
```
```diff
@@ -100,14 +103,15 @@ def load_agent(
             **kwargs
         )
     else:
-        agent = ReActAgent.from_tools(
-            tools=tools,
-            llm=llm,
-            react_chat_formatter=ReActChatFormatter(
-                system_header=system_prompt + "\n" + REACT_CHAT_SYSTEM_HEADER,
-            ),
-            **kwargs
+        if "vector_index" not in extra_kwargs:
+            raise ValueError("Must pass in vector index for CondensePlusContextChatEngine.")
+        vector_index = cast(VectorStoreIndex, extra_kwargs["vector_index"])
+        rag_params = cast(RAGParams, extra_kwargs["rag_params"])
+        # use condense + context chat engine
+        agent = CondensePlusContextChatEngine.from_defaults(
+            vector_index.as_retriever(similarity_top_k=rag_params.top_k),
         )

     return agent
```
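
With this change, non-OpenAI LLMs (and OpenAI models without function calling) no longer route through ReActAgent: they get a CondensePlusContextChatEngine built from the vector index instead. A sketch of how a caller might exercise both branches (the Anthropic model and the empty tool/document lists are illustrative assumptions, not code from this commit):

```python
from llama_index import VectorStoreIndex
from llama_index.llms import Anthropic, OpenAI

from agent_utils import RAGParams, load_agent

# Assumed setup: an already-populated vector index.
vector_index = VectorStoreIndex.from_documents([])
extra = {"vector_index": vector_index, "rag_params": RAGParams()}

# Function-calling OpenAI model -> OpenAIAgent branch.
gpt4_agent = load_agent(
    [], llm=OpenAI(model="gpt-4"), system_prompt="You are helpful.",
    extra_kwargs=extra,
)

# Any other LLM -> CondensePlusContextChatEngine branch (the hotfix).
claude_agent = load_agent(
    [], llm=Anthropic(model="claude-2"), system_prompt="You are helpful.",
    extra_kwargs=extra,
)
```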


```diff
@@ -117,7 +121,7 @@ class RAGParams(BaseModel):
     Parameters used to configure a RAG pipeline.
     """
-    include_summarization: bool = Field(default=False, description="Whether to include summarization in the RAG pipeline.")
+    include_summarization: bool = Field(default=False, description="Whether to include summarization in the RAG pipeline. (only for GPT-4)")
     top_k: int = Field(default=2, description="Number of documents to retrieve from vector store.")
     chunk_size: int = Field(default=1024, description="Chunk size for vector store.")
     embed_model: str = Field(
```
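
The description tweak is documentation-only; RAGParams stays a plain pydantic model. A quick sketch of constructing it with the defaults shown above (any fields beyond those visible in the diff are assumed to have defaults):

```python
from agent_utils import RAGParams

# Defaults per the fields above: summarization off, top_k=2, chunk_size=1024.
params = RAGParams()
assert params.top_k == 2 and params.chunk_size == 1024

# Summarization stays opt-in and, per the new description, is only
# expected to work with GPT-4.
gpt4_params = RAGParams(include_summarization=True, top_k=4)
```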
```diff
@@ -315,7 +319,8 @@ def create_agent(self) -> None:
         return "System prompt not set yet. Please set system prompt first."

     agent = load_agent(
-        all_tools, llm=llm, system_prompt=self._cache.system_prompt, verbose=True
+        all_tools, llm=llm, system_prompt=self._cache.system_prompt, verbose=True,
+        extra_kwargs={"vector_index": vector_index, "rag_params": rag_params}
     )

     self._cache.agent = agent
```
2 changes: 1 addition & 1 deletion pages/2_⚙️_RAG_Config.py
```diff
@@ -36,7 +36,7 @@
     value=",".join(agent_builder.cache.file_paths),
     disabled=True
 )
-include_summarization_st = st.checkbox("Include Summarization", value=rag_params.include_summarization)
+include_summarization_st = st.checkbox("Include Summarization (only works for GPT-4)", value=rag_params.include_summarization)
 top_k_st = st.number_input("Top K", value=rag_params.top_k)
 chunk_size_st = st.number_input("Chunk Size", value=rag_params.chunk_size)
 embed_model_st = st.text_input("Embed Model", value=rag_params.embed_model)
```
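
These widgets mirror the RAGParams fields one-to-one, so the values presumably flow back into a RAGParams instance when the config is saved. A hypothetical sketch of that glue (`collect_rag_params` is an illustration, not code from this page):

```python
from agent_utils import RAGParams

def collect_rag_params() -> RAGParams:
    # Hypothetical glue: gather the Streamlit widget values defined above
    # back into the pydantic config object.
    return RAGParams(
        include_summarization=include_summarization_st,
        top_k=int(top_k_st),
        chunk_size=int(chunk_size_st),
        embed_model=embed_model_st,
    )
```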
