Skip to content

Commit

Permalink
feat: update web services
Browse files Browse the repository at this point in the history
  • Loading branch information
jameszyao authored and SimsonW committed Mar 11, 2024
1 parent 843e977 commit 0eeaca7
Show file tree
Hide file tree
Showing 36 changed files with 2,990 additions and 13 deletions.
2 changes: 1 addition & 1 deletion backend/Dockerfile
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ COPY ./requirements.txt ./
RUN pip3 install --no-cache-dir -r requirements.txt

# Copy the rest of the application
COPY . ./
COPY ./app ./app

# Expose port 8000
EXPOSE 8000
Expand Down
6 changes: 3 additions & 3 deletions backend/app/schemas/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,11 @@
from .utils import validate_list_cursors


class BaseSuccessEmptyResponse(BaseModel):
class BaseEmptyResponse(BaseModel):
status: str = Field("success")


class BaseSuccessDataResponse(BaseModel):
class BaseDataResponse(BaseModel):
status: str = Field("success")
data: Optional[Any] = None

Expand All @@ -18,7 +18,7 @@ class BaseErrorResponse(BaseModel):
error: Dict[str, Any]


class BaseSuccessListResponse(BaseModel):
class BaseListResponse(BaseModel):
status: str = Field("success")
data: Any
fetched_count: int
Expand Down
54 changes: 47 additions & 7 deletions backend/app/schemas/model/model.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
from pydantic import BaseModel, Field
from typing import Dict, Optional
from pydantic import BaseModel, Field, model_validator
from typing import Dict, Optional, Any
from app.models import ModelType
from ..utils import check_update_keys

__all__ = [
"ModelCreateRequest",
Expand All @@ -9,12 +11,50 @@

# POST /projects/{project_id}/models/create
class ModelCreateRequest(BaseModel):
name: str = Field(..., min_length=1, max_length=255)
model_schema_id: str = Field(..., min_length=1, max_length=50)
credentials: Dict = Field({})
name: str = Field(..., min_length=1, max_length=256, description="The name of the model.")
model_schema_id: str = Field(..., min_length=1, max_length=127, description="The provider_model_id of the model.")
provider_model_id: Optional[str] = Field(
None, min_length=1, max_length=255, description="The provider_model_id of the model."
)
type: Optional[ModelType] = Field(None, description="The type of the model.", examples=["text_embedding"])
credentials: Dict = Field(..., description="The credentials of the model.")
properties: Optional[Dict] = Field(None, description="The custom model properties.")


# POST /projects/{project_id}/models/update
class ModelUpdateRequest(BaseModel):
name: Optional[str] = Field(default=None, min_length=1, max_length=255)
credentials: Optional[Dict] = Field(default=None)
name: Optional[str] = Field(default=None, min_length=1, max_length=255, description="The name of the model.")

model_schema_id: Optional[str] = Field(
None, min_length=1, max_length=127, description="The provider_model_id of the model."
)
provider_model_id: Optional[str] = Field(
None, min_length=1, max_length=255, description="The provider_model_id of the model."
)
type: Optional[ModelType] = Field(None, description="The type of the model.", examples=["text_embedding"])
credentials: Optional[Dict] = Field(None, description="The credentials of the model.")
properties: Optional[Dict] = Field(None, description="The custom model properties.")

@model_validator(mode="before")
def custom_validate(cls, data: Any):
check_update_keys(data, ["name", "credentials"])

model_schema_id_exist = data.get("model_schema_id") is not None
provider_model_id_exist = data.get("provider_model_id") is not None
model_type_exist = data.get("model_type") is not None
credentials_exist = data.get("credentials") is not None
properties_exist = data.get("properties") is not None

if model_schema_id_exist and not credentials_exist:
raise ValueError("model_schema_id and credentials must be updated together.")

if provider_model_id_exist and not credentials_exist:
raise ValueError("provider_model_id can only be updated with credentials.")

if model_type_exist and not credentials_exist:
raise ValueError("model_type can only be updated with credentials.")

if properties_exist and not credentials_exist:
raise ValueError("properties can only be updated with credentials.")

return data
Empty file.
3 changes: 3 additions & 0 deletions backend/app/services/assistant/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
from .chat import *
from .generation import *
from .message import *
65 changes: 65 additions & 0 deletions backend/app/services/assistant/chat.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
from app.models import Assistant, Chat, ChatMemory
from app.operators import chat_ops, assistant_ops

__all__ = [
"get_chat",
"update_chat_memory",
"get_assistant_and_chat",
]


async def update_chat_memory(
chat: Chat,
memory: ChatMemory,
) -> Chat:
"""
Update the chat memory
:param chat: the chat to update
:param memory: the chat memory to update
:return: the updated chat
"""

chat = await chat_ops.update(
assistant_id=chat.assistant_id,
chat_id=chat.chat_id,
update_dict={"memory": memory.model_dump()},
)

return chat


async def get_chat(
assistant_id: str,
chat_id: str,
) -> Chat:
"""
Get chat
:param assistant_id: the assistant id
:param chat_id: the chat id
:return: the chat
"""
chat: Chat = await chat_ops.get(
assistant_id=assistant_id,
chat_id=chat_id,
)
return chat


async def get_assistant_and_chat(
assistant_id: str,
chat_id: str,
) -> (Assistant, Chat):
"""
Get chat
:param assistant_id: the assistant id
:param chat_id: the chat id
:return: the chat
"""
assistant: Assistant = await assistant_ops.get(
assistant_id=assistant_id,
)
chat: Chat = await chat_ops.get(
assistant_id=assistant_id,
chat_id=chat_id,
)
return assistant, chat
2 changes: 2 additions & 0 deletions backend/app/services/assistant/generation/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
from .normal_session import NormalSession
from .stream_session import StreamSession
116 changes: 116 additions & 0 deletions backend/app/services/assistant/generation/log.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,116 @@
from typing import Dict, List
from tkhelper.utils import current_timestamp_int_milliseconds

from app.models import MessageGenerationLog, Model, RetrievalResult, ToolInput, ToolOutput


def build_retrieval_input_log_dict(
session_id: str,
event_id: str,
query_text: str,
top_k: int,
):
return MessageGenerationLog(
session_id=session_id,
event="retrieval",
event_id=event_id,
event_step="input",
timestamp=current_timestamp_int_milliseconds(),
content={
"query_text": query_text,
"top_k": top_k,
},
).model_dump()


def build_retrieval_output_log_dict(
session_id: str,
event_id: str,
retrieval_result: List[RetrievalResult],
):
return MessageGenerationLog(
session_id=session_id,
event="retrieval",
event_id=event_id,
event_step="output",
timestamp=current_timestamp_int_milliseconds(),
content={
"result": retrieval_result,
},
).model_dump()


def build_tool_input_log_dict(
session_id: str,
event_id: str,
tool_input: ToolInput,
):
return MessageGenerationLog(
session_id=session_id,
event="tool",
event_id=event_id,
event_step="input",
timestamp=current_timestamp_int_milliseconds(),
content=tool_input.model_dump(),
).model_dump()


def build_tool_output_log_dict(
session_id: str,
event_id: str,
# tool_name: str,
tool_output: ToolOutput,
):
return MessageGenerationLog(
session_id=session_id,
event="tool",
event_id=event_id,
event_step="output",
timestamp=current_timestamp_int_milliseconds(),
content=tool_output.model_dump(),
).model_dump()


def build_chat_completion_input_log_dict(
session_id: str,
event_id: str,
model: Model,
messages: List[Dict],
functions: List[Dict],
):
log = MessageGenerationLog(
session_id=session_id,
event="chat_completion",
event_id=event_id,
event_step="input",
timestamp=current_timestamp_int_milliseconds(),
content={
"model_id": model.model_id,
"model_schema_id": model.model_schema_id,
"provider_model_id": model.provider_model_id,
"messages": messages,
"functions": functions,
},
)
return log.model_dump()


def build_chat_completion_output_log_dict(
session_id: str,
event_id: str,
model: Model,
message: Dict,
):
return MessageGenerationLog(
session_id=session_id,
event="chat_completion",
event_id=event_id,
event_step="output",
timestamp=current_timestamp_int_milliseconds(),
content={
"model_id": model.model_id,
"model_schema_id": model.model_schema_id,
"provider_model_id": model.provider_model_id,
"message": message,
},
).model_dump()
65 changes: 65 additions & 0 deletions backend/app/services/assistant/generation/normal_session.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
from fastapi import HTTPException
from tkhelper.error import raise_http_error, ErrorCode
from tkhelper.schemas import BaseDataResponse
import logging

from app.models import Assistant, Chat

from .session import Session
from .utils import MessageGenerationException

logger = logging.getLogger(__name__)


class NormalSession(Session):
def __init__(self, assistant: Assistant, chat: Chat):
super().__init__(assistant, chat)

async def generate(self):
function_calls_round_index = 0

try:
while True:
try:
chat_completion_assistant_message, chat_completion_function_calls_dict_list = await self.inference()
except HTTPException as e:
raise MessageGenerationException(f"Error occurred in chat completion inference. {e.detail}")
except Exception as e:
raise MessageGenerationException(f"Error occurred in chat completion inference")

logger.debug(f"chat_completion_assistant_message = {chat_completion_assistant_message}")
logger.debug(f"chat_completion_function_calls_dict_list = {chat_completion_function_calls_dict_list}")

if chat_completion_function_calls_dict_list:
function_calls_round_index += 1
try:
await self.use_tool(
chat_completion_function_calls_dict_list, round_index=function_calls_round_index, log=False
)
async for _ in self.run_tools(chat_completion_function_calls_dict_list):
pass
except MessageGenerationException as e:
logger.error(f"MessageGenerationException occurred in using the tools: {e}")
raise e
except Exception as e:
logger.error(f"MessageGenerationException occurred in using the tools: {e}")
raise MessageGenerationException(f"Error occurred in using the tools")

else:
break

message = await self.create_assistant_message(chat_completion_assistant_message["content"])
return BaseDataResponse(data=message.to_response_dict())

except MessageGenerationException as e:
logger.error(f"NormalSession.generate: MessageGenerationException error = {e}")
raise_http_error(ErrorCode.INTERNAL_SERVER_ERROR, message=str(e))

except Exception as e:
logger.error(f"NormalSession.generate: Exception error = {e}")
raise_http_error(
ErrorCode.INTERNAL_SERVER_ERROR, message=str("Assistant message not generated due to an unknown error.")
)

finally:
await self.chat.unlock()
Loading

0 comments on commit 0eeaca7

Please sign in to comment.