Skip to content

Commit

Permalink
feat: add token count for messages
Browse files (browse the repository at this point in the history)
  • Loading branch information
jameszyao committed Jan 22, 2024
1 parent e45ca9e commit b433321
Show file tree
Hide file tree
Showing 5 changed files with 16 additions and 6 deletions.
7 changes: 5 additions & 2 deletions backend/common/database_ops/assistant/message/create.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ async def create_message(
chat: Chat,
role: MessageRole,
content: MessageContent,
num_tokens: int,
metadata: Dict[str, str],
updated_chat_memory: ChatMemory,
) -> Message:
Expand All @@ -19,6 +20,7 @@ async def create_message(
:param chat: the chat where the message belongs to
:param role: the message role, user or assistant
:param content: the message content
:param num_tokens: the number of content tokens
:param metadata: the message metadata
:param updated_chat_memory: the chat memory to update
:return: the created message
Expand All @@ -32,14 +34,15 @@ async def create_message(
# 1. insert message into database
await conn.execute(
"""
INSERT INTO message (message_id, chat_id, assistant_id, role, content, metadata)
VALUES ($1, $2, $3, $4, $5, $6)
INSERT INTO message (message_id, chat_id, assistant_id, role, content, num_tokens, metadata)
VALUES ($1, $2, $3, $4, $5, $6, $7)
""",
new_message_id,
chat.chat_id,
chat.assistant_id,
role.value,
content.model_dump_json(),
num_tokens,
json.dumps(metadata),
)

Expand Down
3 changes: 3 additions & 0 deletions backend/common/models/assistant/message.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@ class Message(BaseModel):
assistant_id: str
role: MessageRole
content: MessageContent
num_tokens: int
metadata: Dict
updated_timestamp: int
created_timestamp: int
Expand All @@ -70,6 +71,7 @@ def build(cls, row):
assistant_id=row["assistant_id"],
role=MessageRole(row["role"]),
content=MessageContent(**load_json_attr(row, "content", {})),
num_tokens=row["num_tokens"],
metadata=load_json_attr(row, "metadata", {}),
updated_timestamp=row["updated_timestamp"],
created_timestamp=row["created_timestamp"],
Expand All @@ -83,6 +85,7 @@ def to_dict(self, purpose: SerializePurpose):
"assistant_id": self.assistant_id,
"role": self.role.value,
"content": self.content.model_dump(),
"num_tokens": self.num_tokens,
"metadata": self.metadata,
"updated_timestamp": self.updated_timestamp,
"created_timestamp": self.created_timestamp,
Expand Down
3 changes: 1 addition & 2 deletions backend/common/models/retrieval/tokenizer/_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,5 +50,4 @@ def get_tokenizer(type: TokenizerType):
return tokenizer


def default_tokenizer():
return get_tokenizer(TokenizerType.TIKTOKEN)
default_tokenizer = get_tokenizer(TokenizerType.TIKTOKEN)
5 changes: 5 additions & 0 deletions backend/common/services/assistant/message.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
ListResult,
MessageRole,
ChatMemory,
default_tokenizer,
)
from common.database_ops.assistant import message as db_message
from common.error import ErrorCode, raise_http_error
Expand Down Expand Up @@ -89,11 +90,15 @@ async def create_message(
chat: Chat = await get_chat(assistant_id=assistant_id, chat_id=chat_id)
updated_chat_memory: ChatMemory = await chat.memory.update_memory(new_message_text=content.text, role=role.value)

# count tokens
num_tokens = default_tokenizer.count_tokens(content.text)

# create message
message = await db_message.create_message(
chat=chat,
role=role,
content=content,
num_tokens=num_tokens,
metadata=metadata,
updated_chat_memory=updated_chat_memory,
)
Expand Down
4 changes: 2 additions & 2 deletions backend/common/services/retrieval/chunk.py
Original file line number Diff line number Diff line change
Expand Up @@ -186,7 +186,7 @@ async def create_chunk(
embedding = embeddings[0]

# create record
num_tokens = default_tokenizer().count_tokens(content)
num_tokens = default_tokenizer.count_tokens(content)
record = await db_chunk.create_chunk(
collection=collection,
content=content,
Expand Down Expand Up @@ -274,7 +274,7 @@ async def update_chunk(
embedding = embeddings[0]

# update chunk
num_tokens = default_tokenizer().count_tokens(content)
num_tokens = default_tokenizer.count_tokens(content)

record = await db_chunk.update_chunk(
collection=collection,
Expand Down

0 comments on commit b433321

Please sign in to comment.