Skip to content

Commit

Permalink
fix: return error message in SSE for stream mode
Browse files — browse the repository at this point in the history
  • Loading branch information
jameszyao authored and SimsonW committed Mar 11, 2024
1 parent f75df98 commit 83d7227
Show file tree
Hide file tree
Showing 6 changed files with 31 additions and 38 deletions.
19 changes: 5 additions & 14 deletions backend/app/routes/assistant/generation.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
from typing import Dict

from tkhelper.schemas.base import BaseDataResponse
from tkhelper.error import ErrorCode, raise_http_error

from app.services.assistant import get_assistant_and_chat, NormalSession, StreamSession
from app.schemas.assistant.generate import MessageGenerateRequest
Expand Down Expand Up @@ -34,32 +33,24 @@ async def api_chat_generate(
auth_info: Dict = Depends(auth_info_required),
):
system_prompt_variables = payload.system_prompt_variables
stream = payload.stream

assistant, chat = await get_assistant_and_chat(assistant_id, chat_id)

if await chat.is_chat_locked():
raise_http_error(
ErrorCode.OBJECT_LOCKED,
message="Chat is locked by another generation process.",
)

if payload.stream or payload.debug:
session = StreamSession(
assistant=assistant,
chat=chat,
stream=payload.stream,
debug=payload.debug,
)
await session.prepare(stream, system_prompt_variables, retrival_log=payload.debug)
await chat.lock()
return StreamingResponse(session.stream_generate(), media_type="text/event-stream")
return StreamingResponse(
session.stream_generate(system_prompt_variables),
media_type="text/event-stream",
)

else:
session = NormalSession(
assistant=assistant,
chat=chat,
)
await session.prepare(stream, system_prompt_variables, retrival_log=False)
await chat.lock()
return await session.generate()
return await session.generate(system_prompt_variables)
12 changes: 8 additions & 4 deletions backend/app/services/assistant/generation/normal_session.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from fastapi import HTTPException
from typing import Dict
from tkhelper.error import raise_http_error, ErrorCode
from tkhelper.schemas import BaseDataResponse
import logging
Expand All @@ -15,10 +16,13 @@ class NormalSession(Session):
def __init__(self, assistant: Assistant, chat: Chat):
super().__init__(assistant, chat)

async def generate(self):
function_calls_round_index = 0

async def generate(self, system_prompt_variables: Dict):
try:
await self.prepare(False, system_prompt_variables, retrieval_log=False)
await self.chat.lock()

function_calls_round_index = 0

while True:
try:
chat_completion_assistant_message, chat_completion_function_calls_dict_list = await self.inference()
Expand Down Expand Up @@ -53,7 +57,7 @@ async def generate(self):

except MessageGenerationException as e:
logger.error(f"NormalSession.generate: MessageGenerationException error = {e}")
raise_http_error(ErrorCode.INTERNAL_SERVER_ERROR, message=str(e))
raise_http_error(ErrorCode.GENERATION_ERROR, message=str(e))

except Exception as e:
logger.error(f"NormalSession.generate: Exception error = {e}")
Expand Down
6 changes: 3 additions & 3 deletions backend/app/services/assistant/generation/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ async def create_assistant_message(self, content_text: str):
metadata={},
)

async def prepare(self, stream: bool, system_prompt_variables: Dict, retrival_log: bool = False):
async def prepare(self, stream: bool, system_prompt_variables: Dict, retrieval_log: bool = False):
# 1. Get model
self.model = await get_model(self.assistant.model_id)

Expand Down Expand Up @@ -91,7 +91,7 @@ async def prepare(self, stream: bool, system_prompt_variables: Dict, retrival_lo
)
if retrieval_query_text:
retrieval_event_id = generate_random_event_id()
if retrival_log:
if retrieval_log:
retrieval_log_input = build_retrieval_input_log_dict(
session_id=self.session_id,
event_id=retrieval_event_id,
Expand All @@ -105,7 +105,7 @@ async def prepare(self, stream: bool, system_prompt_variables: Dict, retrival_lo
query_text=retrieval_query_text,
)

if retrival_log:
if retrieval_log:
retrieval_log_output = build_retrieval_output_log_dict(
session_id=self.session_id,
event_id=retrieval_event_id,
Expand Down
17 changes: 10 additions & 7 deletions backend/app/services/assistant/generation/stream_session.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
def error_message(message: str):
return {
"object": "Error",
"code": ErrorCode.INTERNAL_SERVER_ERROR,
"code": "GENERATION_ERROR",
"message": message,
}

Expand Down Expand Up @@ -52,13 +52,16 @@ async def stream_inference(self):
temp_data.update({"object": "MessageChunk"})
yield MESSAGE_CHUNK, temp_data

async def stream_generate(self):
if self.prepare_logs:
for log_dict in self.prepare_logs:
yield f"data: {json.dumps(log_dict)}\n\n"
await asyncio.sleep(0.1)

async def stream_generate(self, system_prompt_variables: Dict):
try:
await self.prepare(True, system_prompt_variables, retrieval_log=self.debug)
await self.chat.lock()

if self.prepare_logs:
for log_dict in self.prepare_logs:
yield f"data: {json.dumps(log_dict)}\n\n"
await asyncio.sleep(0.1)

function_calls_round_index = 0
while True:
chat_completion_function_calls_dict_list = None
Expand Down
13 changes: 3 additions & 10 deletions backend/app/services/assistant/generation/utils.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,9 @@
import re
from typing import Dict, List, Optional, Tuple
from tkhelper.utils import generate_random_id
from tkhelper.error import raise_http_error, ErrorCode

from app.models import Assistant, RetrievalMethod, Chat, RetrievalResult
from app.services.retrieval import query_retrievals
from app.services.retrieval.retrieval import query_retrievals


class MessageGenerationException(Exception):
Expand Down Expand Up @@ -87,18 +86,12 @@ async def get_chat_memory_messages(chat: Chat):
if chat_memory_messages:
last_message = chat_memory_messages[-1]
if last_message.role == "assistant":
raise_http_error(
ErrorCode.REQUEST_VALIDATION_ERROR,
message="Cannot generate another assistant message after an assistant message.",
)
raise MessageGenerationException("Cannot generate another assistant message after an assistant message.")

# Ensure there is at least one user message in the chat memory
user_message_count = sum(1 for message in chat_memory_messages if message.role == "user")
if user_message_count == 0:
raise_http_error(
ErrorCode.REQUEST_VALIDATION_ERROR,
message="There is no user message in the chat context.",
)
raise MessageGenerationException("There is no user message in the chat context.")

message_dicts = [message.model_dump() for message in chat_memory_messages]
return message_dicts
Expand Down
2 changes: 2 additions & 0 deletions backend/tkhelper/error/error_code.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ class ErrorCode(str, Enum):
ACTION_API_REQUEST_ERROR = "ACTION_API_REQUEST_ERROR"
OBJECT_LOCKED = "OBJECT_LOCKED"
INVALID_REQUEST = "INVALID_REQUEST"
GENERATION_ERROR = "GENERATION_ERROR"


error_messages = {
Expand All @@ -55,6 +56,7 @@ class ErrorCode(str, Enum):
ErrorCode.ACTION_API_REQUEST_ERROR: {"status_code": 400, "message": "Action API request error."},
ErrorCode.OBJECT_LOCKED: {"status_code": 423, "message": "Object locked."},
ErrorCode.INVALID_REQUEST: {"status_code": 400, "message": "Invalid request."},
ErrorCode.GENERATION_ERROR: {"status_code": 500, "message": "Generation error occurred."},
}

assert len(error_messages) == len(ErrorCode)
Expand Down

0 comments on commit 83d7227

Please sign in to comment.