forked from TaskingAI/TaskingAI
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
36 changed files
with
2,990 additions
and
13 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,3 @@ | ||
from .chat import * | ||
from .generation import * | ||
from .message import * |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,65 @@ | ||
from app.models import Assistant, Chat, ChatMemory | ||
from app.operators import chat_ops, assistant_ops | ||
|
||
__all__ = [ | ||
"get_chat", | ||
"update_chat_memory", | ||
"get_assistant_and_chat", | ||
] | ||
|
||
|
||
async def update_chat_memory( | ||
chat: Chat, | ||
memory: ChatMemory, | ||
) -> Chat: | ||
""" | ||
Update the chat memory | ||
:param chat: the chat to update | ||
:param memory: the chat memory to update | ||
:return: the updated chat | ||
""" | ||
|
||
chat = await chat_ops.update( | ||
assistant_id=chat.assistant_id, | ||
chat_id=chat.chat_id, | ||
update_dict={"memory": memory.model_dump()}, | ||
) | ||
|
||
return chat | ||
|
||
|
||
async def get_chat( | ||
assistant_id: str, | ||
chat_id: str, | ||
) -> Chat: | ||
""" | ||
Get chat | ||
:param assistant_id: the assistant id | ||
:param chat_id: the chat id | ||
:return: the chat | ||
""" | ||
chat: Chat = await chat_ops.get( | ||
assistant_id=assistant_id, | ||
chat_id=chat_id, | ||
) | ||
return chat | ||
|
||
|
||
async def get_assistant_and_chat( | ||
assistant_id: str, | ||
chat_id: str, | ||
) -> (Assistant, Chat): | ||
""" | ||
Get chat | ||
:param assistant_id: the assistant id | ||
:param chat_id: the chat id | ||
:return: the chat | ||
""" | ||
assistant: Assistant = await assistant_ops.get( | ||
assistant_id=assistant_id, | ||
) | ||
chat: Chat = await chat_ops.get( | ||
assistant_id=assistant_id, | ||
chat_id=chat_id, | ||
) | ||
return assistant, chat |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,2 @@ | ||
from .normal_session import NormalSession | ||
from .stream_session import StreamSession |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,116 @@ | ||
from typing import Dict, List | ||
from tkhelper.utils import current_timestamp_int_milliseconds | ||
|
||
from app.models import MessageGenerationLog, Model, RetrievalResult, ToolInput, ToolOutput | ||
|
||
|
||
def build_retrieval_input_log_dict( | ||
session_id: str, | ||
event_id: str, | ||
query_text: str, | ||
top_k: int, | ||
): | ||
return MessageGenerationLog( | ||
session_id=session_id, | ||
event="retrieval", | ||
event_id=event_id, | ||
event_step="input", | ||
timestamp=current_timestamp_int_milliseconds(), | ||
content={ | ||
"query_text": query_text, | ||
"top_k": top_k, | ||
}, | ||
).model_dump() | ||
|
||
|
||
def build_retrieval_output_log_dict( | ||
session_id: str, | ||
event_id: str, | ||
retrieval_result: List[RetrievalResult], | ||
): | ||
return MessageGenerationLog( | ||
session_id=session_id, | ||
event="retrieval", | ||
event_id=event_id, | ||
event_step="output", | ||
timestamp=current_timestamp_int_milliseconds(), | ||
content={ | ||
"result": retrieval_result, | ||
}, | ||
).model_dump() | ||
|
||
|
||
def build_tool_input_log_dict( | ||
session_id: str, | ||
event_id: str, | ||
tool_input: ToolInput, | ||
): | ||
return MessageGenerationLog( | ||
session_id=session_id, | ||
event="tool", | ||
event_id=event_id, | ||
event_step="input", | ||
timestamp=current_timestamp_int_milliseconds(), | ||
content=tool_input.model_dump(), | ||
).model_dump() | ||
|
||
|
||
def build_tool_output_log_dict( | ||
session_id: str, | ||
event_id: str, | ||
# tool_name: str, | ||
tool_output: ToolOutput, | ||
): | ||
return MessageGenerationLog( | ||
session_id=session_id, | ||
event="tool", | ||
event_id=event_id, | ||
event_step="output", | ||
timestamp=current_timestamp_int_milliseconds(), | ||
content=tool_output.model_dump(), | ||
).model_dump() | ||
|
||
|
||
def build_chat_completion_input_log_dict( | ||
session_id: str, | ||
event_id: str, | ||
model: Model, | ||
messages: List[Dict], | ||
functions: List[Dict], | ||
): | ||
log = MessageGenerationLog( | ||
session_id=session_id, | ||
event="chat_completion", | ||
event_id=event_id, | ||
event_step="input", | ||
timestamp=current_timestamp_int_milliseconds(), | ||
content={ | ||
"model_id": model.model_id, | ||
"model_schema_id": model.model_schema_id, | ||
"provider_model_id": model.provider_model_id, | ||
"messages": messages, | ||
"functions": functions, | ||
}, | ||
) | ||
return log.model_dump() | ||
|
||
|
||
def build_chat_completion_output_log_dict( | ||
session_id: str, | ||
event_id: str, | ||
model: Model, | ||
message: Dict, | ||
): | ||
return MessageGenerationLog( | ||
session_id=session_id, | ||
event="chat_completion", | ||
event_id=event_id, | ||
event_step="output", | ||
timestamp=current_timestamp_int_milliseconds(), | ||
content={ | ||
"model_id": model.model_id, | ||
"model_schema_id": model.model_schema_id, | ||
"provider_model_id": model.provider_model_id, | ||
"message": message, | ||
}, | ||
).model_dump() |
65 changes: 65 additions & 0 deletions
65
backend/app/services/assistant/generation/normal_session.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,65 @@ | ||
from fastapi import HTTPException | ||
from tkhelper.error import raise_http_error, ErrorCode | ||
from tkhelper.schemas import BaseDataResponse | ||
import logging | ||
|
||
from app.models import Assistant, Chat | ||
|
||
from .session import Session | ||
from .utils import MessageGenerationException | ||
|
||
logger = logging.getLogger(__name__) | ||
|
||
|
||
class NormalSession(Session): | ||
def __init__(self, assistant: Assistant, chat: Chat): | ||
super().__init__(assistant, chat) | ||
|
||
async def generate(self): | ||
function_calls_round_index = 0 | ||
|
||
try: | ||
while True: | ||
try: | ||
chat_completion_assistant_message, chat_completion_function_calls_dict_list = await self.inference() | ||
except HTTPException as e: | ||
raise MessageGenerationException(f"Error occurred in chat completion inference. {e.detail}") | ||
except Exception as e: | ||
raise MessageGenerationException(f"Error occurred in chat completion inference") | ||
|
||
logger.debug(f"chat_completion_assistant_message = {chat_completion_assistant_message}") | ||
logger.debug(f"chat_completion_function_calls_dict_list = {chat_completion_function_calls_dict_list}") | ||
|
||
if chat_completion_function_calls_dict_list: | ||
function_calls_round_index += 1 | ||
try: | ||
await self.use_tool( | ||
chat_completion_function_calls_dict_list, round_index=function_calls_round_index, log=False | ||
) | ||
async for _ in self.run_tools(chat_completion_function_calls_dict_list): | ||
pass | ||
except MessageGenerationException as e: | ||
logger.error(f"MessageGenerationException occurred in using the tools: {e}") | ||
raise e | ||
except Exception as e: | ||
logger.error(f"MessageGenerationException occurred in using the tools: {e}") | ||
raise MessageGenerationException(f"Error occurred in using the tools") | ||
|
||
else: | ||
break | ||
|
||
message = await self.create_assistant_message(chat_completion_assistant_message["content"]) | ||
return BaseDataResponse(data=message.to_response_dict()) | ||
|
||
except MessageGenerationException as e: | ||
logger.error(f"NormalSession.generate: MessageGenerationException error = {e}") | ||
raise_http_error(ErrorCode.INTERNAL_SERVER_ERROR, message=str(e)) | ||
|
||
except Exception as e: | ||
logger.error(f"NormalSession.generate: Exception error = {e}") | ||
raise_http_error( | ||
ErrorCode.INTERNAL_SERVER_ERROR, message=str("Assistant message not generated due to an unknown error.") | ||
) | ||
|
||
finally: | ||
await self.chat.unlock() |
Oops, something went wrong.