Skip to content

Commit

Permalink
mini refactor
Browse files Browse the repository at this point in the history
  • Loading branch information
rashadphz committed May 19, 2024
1 parent 7f090de commit addf683
Show file tree
Hide file tree
Showing 3 changed files with 41 additions and 41 deletions.
25 changes: 11 additions & 14 deletions src/backend/chat.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
FinalResponseStream,
Message,
RelatedQueriesStream,
SearchResult,
SearchResultStream,
StreamEndStream,
StreamEvent,
Expand Down Expand Up @@ -68,6 +69,12 @@ def get_llm(model: ChatModel) -> LLM:
raise ValueError(f"Unknown model: {model}")


def format_context(search_results: List[SearchResult]) -> str:
    """Render search results as a numbered, double-newline-separated citation list.

    Each result is stringified and prefixed with "Citation <n>." starting at 1,
    producing the context string consumed by the chat prompt.
    """
    citations = (
        f"Citation {number}. {str(result)}"
        for number, result in enumerate(search_results, start=1)
    )
    return "\n\n".join(citations)


async def stream_qa_objects(request: ChatRequest) -> AsyncIterator[ChatResponseEvent]:

try:
Expand All @@ -90,13 +97,11 @@ async def stream_qa_objects(request: ChatRequest) -> AsyncIterator[ChatResponseE
images = search_response.images

# Only create the task first if the model is not local
related_queries_task = (
asyncio.create_task(
related_queries_task = None
if not is_local_model(request.model):
related_queries_task = asyncio.create_task(
generate_related_queries(query, search_results, request.model)
)
if not is_local_model(request.model)
else None
)

yield ChatResponseEvent(
event=StreamEvent.SEARCH_RESULTS,
Expand All @@ -106,15 +111,8 @@ async def stream_qa_objects(request: ChatRequest) -> AsyncIterator[ChatResponseE
),
)

context_str = "\n\n".join(
[
f"Citation {i+1}. {str(result)}"
for i, result in enumerate(search_results)
]
)

fmt_qa_prompt = CHAT_PROMPT.format(
my_context=context_str,
my_context=format_context(search_results),
my_query=query,
)

Expand All @@ -127,7 +125,6 @@ async def stream_qa_objects(request: ChatRequest) -> AsyncIterator[ChatResponseE
data=TextChunkStream(text=completion.delta or ""),
)

# For local models, generate the answer before the related queries
related_queries = await (
related_queries_task
if related_queries_task
Expand Down
29 changes: 2 additions & 27 deletions src/backend/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,14 @@
import json
import os
from typing import Generator
from backend.utils import is_local_model, strtobool
from backend.utils import strtobool

from dotenv import load_dotenv
from fastapi import FastAPI, Request
from fastapi.encoders import jsonable_encoder
from fastapi.middleware.cors import CORSMiddleware
from backend.chat import stream_qa_objects
from backend.validators import validate_model
from sse_starlette.sse import EventSourceResponse, ServerSentEvent

import logfire
Expand All @@ -23,7 +24,6 @@
from slowapi.util import get_ipaddr
from slowapi.errors import RateLimitExceeded

from backend.constants import ChatModel


load_dotenv()
Expand Down Expand Up @@ -79,31 +79,6 @@ def create_error_event(detail: str):
)


def validate_model(model: ChatModel):
    """Check that *model* can be served with the current environment configuration.

    Raises ValueError when a required API key is missing, the model is disabled
    by a feature flag, or the model is unknown; returns True otherwise.
    """
    if model in {ChatModel.GPT_3_5_TURBO, ChatModel.GPT_4o}:
        # OpenAI-backed models require an API key; GPT-4o is additionally flag-gated.
        if not os.getenv("OPENAI_API_KEY"):
            raise ValueError("OPENAI_API_KEY environment variable not found")
        if model == ChatModel.GPT_4o and not strtobool(
            os.getenv("GPT4_ENABLED", True)
        ):
            raise ValueError(
                "GPT4-o has been disabled. Please try a different model or self-host the app by following the instructions here: https://github.com/rashadphz/farfalle"
            )
        return True
    if model == ChatModel.LLAMA_3_70B:
        if not os.getenv("GROQ_API_KEY"):
            raise ValueError("GROQ_API_KEY environment variable not found")
        return True
    if is_local_model(model):
        # Local models are opt-in via the ENABLE_LOCAL_MODELS flag.
        if not strtobool(os.getenv("ENABLE_LOCAL_MODELS", False)):
            raise ValueError("Local models are not enabled")
        return True
    raise ValueError("Invalid model")


@app.post("/chat")
@app.state.limiter.limit("10/hour")
async def chat(
Expand Down
28 changes: 28 additions & 0 deletions src/backend/validators.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
from backend.constants import ChatModel
from backend.utils import is_local_model, strtobool
import os


def validate_model(model: ChatModel):
    """Validate that the requested chat model is usable in this deployment.

    Verifies the provider API key (OpenAI / Groq) or the local-models feature
    flag, depending on *model*. Returns True on success; raises ValueError with
    a user-facing message when the model cannot be used.
    """
    if model in {ChatModel.GPT_3_5_TURBO, ChatModel.GPT_4o}:
        openai_key = os.getenv("OPENAI_API_KEY")
        if not openai_key:
            raise ValueError("OPENAI_API_KEY environment variable not found")
        if model == ChatModel.GPT_4o:
            # GPT-4o can be switched off independently of the API key.
            gpt4_enabled = strtobool(os.getenv("GPT4_ENABLED", True))
            if not gpt4_enabled:
                raise ValueError(
                    "GPT4-o has been disabled. Please try a different model or self-host the app by following the instructions here: https://github.com/rashadphz/farfalle"
                )
    elif model == ChatModel.LLAMA_3_70B:
        groq_key = os.getenv("GROQ_API_KEY")
        if not groq_key:
            raise ValueError("GROQ_API_KEY environment variable not found")
    elif is_local_model(model):
        local_models_enabled = strtobool(os.getenv("ENABLE_LOCAL_MODELS", False))
        if not local_models_enabled:
            raise ValueError("Local models are not enabled")
    else:
        raise ValueError("Invalid model")
    return True

0 comments on commit addf683

Please sign in to comment.