Skip to content

Commit

Permalink
feat: stream router (QuivrHQ#353)
Browse files Browse the repository at this point in the history
* wip: stream router

* feat: chatai streaming

* chore: add comments

* feat: streaming for chains

* chore: comments
  • Loading branch information
mattzcarey authored Jun 20, 2023
1 parent 90bd495 commit 3e753f2
Show file tree
Hide file tree
Showing 2 changed files with 128 additions and 1 deletion.
6 changes: 5 additions & 1 deletion backend/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from routes.crawl_routes import crawl_router
from routes.explore_routes import explore_router
from routes.misc_routes import misc_router
from routes.stream_routes import stream_router
from routes.upload_routes import upload_router
from routes.user_routes import user_router

Expand All @@ -19,12 +20,14 @@

add_cors_middleware(app)
max_brain_size = os.getenv("MAX_BRAIN_SIZE")
max_brain_size_with_own_key = os.getenv("MAX_BRAIN_SIZE_WITH_KEY",209715200)
max_brain_size_with_own_key = os.getenv("MAX_BRAIN_SIZE_WITH_KEY", 209715200)


@app.on_event("startup")
async def startup_event():
    """Startup hook: call ``pypandoc.download_pandoc()`` when the app starts."""
    # NOTE(review): presumably this makes a pandoc binary available for later
    # document conversion — confirm against the upload/parsing code.
    pypandoc.download_pandoc()


app.include_router(brain_router)
app.include_router(chat_router)
app.include_router(crawl_router)
Expand All @@ -33,3 +36,4 @@ async def startup_event():
app.include_router(upload_router)
app.include_router(user_router)
app.include_router(api_key_router)
app.include_router(stream_router)
123 changes: 123 additions & 0 deletions backend/routes/stream_routes.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,123 @@
import asyncio
import os
from typing import AsyncIterable, Awaitable

from auth.auth_bearer import AuthBearer, get_current_user
from fastapi import APIRouter, Depends
from fastapi.responses import StreamingResponse
from langchain.callbacks import AsyncIteratorCallbackHandler
from langchain.chains import ConversationalRetrievalChain
from langchain.chains.llm import LLMChain
from langchain.chains.question_answering import load_qa_chain
from langchain.chat_models import ChatOpenAI
from langchain.embeddings import OpenAIEmbeddings
from llm.prompt.CONDENSE_PROMPT import CONDENSE_QUESTION_PROMPT
from logger import get_logger
from models.chats import ChatMessage
from models.settings import CommonsDep, common_dependencies
from models.users import User
from supabase import create_client
from utils.users import fetch_user_id_from_credentials
from vectorstore.supabase import CustomSupabaseVectorStore

logger = get_logger(__name__)

stream_router = APIRouter()

openai_api_key = os.getenv("OPENAI_API_KEY")
supabase_url = os.getenv("SUPABASE_URL")
supabase_service_key = os.getenv("SUPABASE_SERVICE_KEY")


async def send_message(
    chat_message: ChatMessage, chain, callback
) -> AsyncIterable[str]:
    """Run ``chain`` on ``chat_message`` and stream its tokens as server-sent events.

    The chain is started as a background task through ``chain.acall`` with the
    message's ``question`` and ``history``. Tokens are read from
    ``callback.aiter()`` as they arrive and framed for ``text/event-stream``.

    Args:
        chat_message: Incoming message; its ``question`` and ``history``
            attributes are forwarded to the chain.
        chain: Object exposing an awaitable ``acall(inputs)`` method.
        callback: Handler exposing ``aiter()`` and a ``done`` event that is
            set once the chain call has finished or failed.

    Yields:
        One SSE frame per token: a ``data:`` line for every line of the token,
        terminated by a blank line.
    """

    async def wrap_done(fn: Awaitable, event: asyncio.Event):
        """Await ``fn`` and always set ``event`` afterwards, even on failure."""
        try:
            resp = await fn
            logger.debug("Done: %s", resp)
        except Exception as e:
            # Deliberately swallowed so the stream ends cleanly for the client;
            # log with the traceback so the failure is still diagnosable.
            logger.exception("Caught exception: %s", e)
        finally:
            # Signal the aiter to stop.
            event.set()

    def as_sse_frame(token) -> str:
        """Frame ``token`` as one SSE event, one ``data:`` line per text line."""
        # A raw newline inside a single ``data:`` line would end the field
        # early and the remainder would be dropped by SSE clients, so every
        # line of the token gets its own ``data:`` prefix. Single-line tokens
        # produce exactly ``data: <token>\n\n``.
        text = str(token).replace("\r\n", "\n").replace("\r", "\n")
        return "".join(f"data: {line}\n" for line in text.split("\n")) + "\n"

    # Use the agenerate method for models.
    # Use the acall method for chains.
    task = asyncio.create_task(
        wrap_done(
            chain.acall(
                {
                    "question": chat_message.question,
                    "chat_history": chat_message.history,
                }
            ),
            callback.done,
        )
    )

    try:
        # Use the aiter method of the callback to stream the response with
        # server-sent events.
        async for token in callback.aiter():
            logger.info("Token: %s", token)
            yield as_sse_frame(token)

        await task
    finally:
        # If the consumer goes away early (client disconnect, generator
        # closed, request cancelled), stop the chain call instead of letting
        # it run to completion in the background.
        if not task.done():
            task.cancel()


def create_chain(commons: CommonsDep, current_user: User):
    """Build the conversational retrieval chain for ``current_user``.

    Args:
        commons: Shared dependencies, passed to
            ``fetch_user_id_from_credentials`` to resolve the user id.
        current_user: Authenticated user; only ``email`` is read here.

    Returns:
        A ``(chain, callback)`` tuple: the ``ConversationalRetrievalChain``
        and the ``AsyncIteratorCallbackHandler`` attached to its streaming
        answer model.
    """
    owner_id = fetch_user_id_from_credentials(commons, {"email": current_user.email})

    # Vector store over the "vectors" table, filtered by the resolved user id.
    embedder = OpenAIEmbeddings(openai_api_key=openai_api_key)
    client = create_client(supabase_url, supabase_service_key)
    store = CustomSupabaseVectorStore(
        client, embedder, table_name="vectors", user_id=owner_id
    )

    # Non-streaming model: used only by the question-condensing chain.
    condense_llm = ChatOpenAI(temperature=0)

    # Callback provides the on_llm_new_token method
    token_handler = AsyncIteratorCallbackHandler()

    # Streaming model: the one whose tokens reach the callback.
    answer_llm = ChatOpenAI(temperature=0, streaming=True, callbacks=[token_handler])

    condense_chain = LLMChain(llm=condense_llm, prompt=CONDENSE_QUESTION_PROMPT)
    answer_chain = load_qa_chain(llm=answer_llm, chain_type="stuff")

    retrieval_chain = ConversationalRetrievalChain(
        combine_docs_chain=answer_chain,
        question_generator=condense_chain,
        retriever=store.as_retriever(),
        verbose=True,
    )
    return retrieval_chain, token_handler


@stream_router.post("/stream", dependencies=[Depends(AuthBearer())], tags=["Stream"])
async def stream(
    chat_message: ChatMessage,
    current_user: User = Depends(get_current_user),
) -> StreamingResponse:
    """Answer ``chat_message`` as a ``text/event-stream`` response.

    Builds a chain and its token callback for ``current_user`` with
    ``create_chain``, then hands the ``send_message`` generator to a
    ``StreamingResponse``.
    """
    chain, token_handler = create_chain(common_dependencies(), current_user)
    event_source = send_message(chat_message, chain, token_handler)
    return StreamingResponse(event_source, media_type="text/event-stream")

0 comments on commit 3e753f2

Please sign in to comment.