Skip to content

Commit

Permalink
skip fncall postprocessing when not in fncall mode
Browse files Browse the repository at this point in the history
  • Loading branch information
JianxinMa committed Mar 4, 2024
1 parent 5080f1c commit 77179cb
Show file tree
Hide file tree
Showing 4 changed files with 31 additions and 14 deletions.
28 changes: 19 additions & 9 deletions qwen_agent/llm/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,25 +71,29 @@ def chat(

messages = self._preprocess_messages(messages)
if functions:
fncall_mode = True
output = self._chat_with_functions(
messages=messages,
functions=functions,
stream=stream,
delta_stream=delta_stream,
)
else:
fncall_mode = False
output = self._chat(
messages,
stream=stream,
delta_stream=delta_stream,
)

if isinstance(output, list):
output = self._postprocess_messages(output)
output = self._postprocess_messages(output,
fncall_mode=fncall_mode)
return self._convert_messages_to_target_type(
output, _return_message_type)
else:
output = self._postprocess_messages_iterator(output)
output = self._postprocess_messages_iterator(
output, fncall_mode=fncall_mode)
return self._convert_messages_iterator_to_target_type(
output, _return_message_type)

Expand Down Expand Up @@ -132,16 +136,19 @@ def _chat_no_stream(
def _preprocess_messages(self, messages: List[Message]) -> List[Message]:
return self._format_as_multimodal_messages(messages)

def _postprocess_messages(self, messages: List[Message]) -> List[Message]:
def _postprocess_messages(self, messages: List[Message],
fncall_mode: bool) -> List[Message]:
messages = self._format_as_multimodal_messages(messages)
messages = self._postprocess_stop_words(messages)
return messages

def _postprocess_messages_iterator(
self,
messages: Iterator[List[Message]]) -> Iterator[List[Message]]:
self,
messages: Iterator[List[Message]],
fncall_mode: bool,
) -> Iterator[List[Message]]:
for m in messages:
m = self._postprocess_messages(m)
m = self._postprocess_messages(m, fncall_mode=fncall_mode)
if m:
yield m

Expand Down Expand Up @@ -219,7 +226,8 @@ def _format_as_multimodal_messages(

return multimodal_messages

def _postprocess_stop_words(self, messages: List[Message]) -> List[Message]:
def _postprocess_stop_words(self,
messages: List[Message]) -> List[Message]:
messages = copy.deepcopy(messages)
stop = self.generate_cfg.get('stop', [])

Expand All @@ -231,7 +239,8 @@ def _postprocess_stop_words(self, messages: List[Message]) -> List[Message]:
for i, item in enumerate(msg.content):
item_type, item_text = item.get_type_and_value()
if item_type == 'text':
truncated, item.text = _truncate_at_stop_word(text=item_text, stop=stop)
truncated, item.text = _truncate_at_stop_word(
text=item_text, stop=stop)
trunc_content.append(item)
if truncated:
break
Expand Down Expand Up @@ -261,11 +270,12 @@ def _postprocess_stop_words(self, messages: List[Message]) -> List[Message]:

return messages


def _truncate_at_stop_word(text: str, stop: List[str]):
truncated = False
for s in stop:
k = text.find(s)
if k >= 0:
truncated = True
text = text[:k]
return truncated, text
return truncated, text
8 changes: 5 additions & 3 deletions qwen_agent/llm/function_calling.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,9 +52,11 @@ def _preprocess_messages(self, messages: List[Message]) -> List[Message]:
messages = self._preprocess_fncall_messages(messages)
return messages

def _postprocess_messages(self, messages: List[Message]) -> List[Message]:
messages = super()._postprocess_messages(messages)
messages = self._postprocess_fncall_messages(messages)
def _postprocess_messages(self, messages: List[Message],
fncall_mode: bool) -> List[Message]:
messages = super()._postprocess_messages(messages, fncall_mode=fncall_mode)
if fncall_mode:
messages = self._postprocess_fncall_messages(messages)
return messages

def _prepend_fncall_system(self, messages: List[Message],
Expand Down
4 changes: 4 additions & 0 deletions qwen_agent/llm/qwenvl_dashscope.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,13 @@
import os
from http import HTTPStatus
from pprint import pformat
from typing import Dict, Iterator, List, Optional

import dashscope

from qwen_agent.llm.base import ModelServiceError, register_llm
from qwen_agent.llm.function_calling import BaseFnCallModel
from qwen_agent.log import logger

from .schema import CONTENT, ROLE, ContentItem, Message

Expand Down Expand Up @@ -33,6 +35,7 @@ def _chat_stream(
raise NotImplementedError

messages = [msg.model_dump() for msg in messages]
logger.debug(f'*{pformat(messages, indent=2)}*')
response = dashscope.MultiModalConversation.call(
model=self.model,
messages=messages,
Expand All @@ -55,6 +58,7 @@ def _chat_no_stream(
messages: List[Message],
) -> List[Message]:
messages = [msg.model_dump() for msg in messages]
logger.debug(f'*{pformat(messages, indent=2)}*')
response = dashscope.MultiModalConversation.call(
model=self.model,
messages=messages,
Expand Down
5 changes: 3 additions & 2 deletions qwen_agent/llm/text_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,9 @@ def _preprocess_messages(self, messages: List[Message]) -> List[Message]:
messages = self._format_as_text_messages(messages)
return messages

def _postprocess_messages(self, messages: List[Message]) -> List[Message]:
messages = super()._postprocess_messages(messages)
def _postprocess_messages(self, messages: List[Message],
fncall_mode: bool) -> List[Message]:
messages = super()._postprocess_messages(messages, fncall_mode=fncall_mode)
messages = self._format_as_text_messages(messages)
return messages

Expand Down

0 comments on commit 77179cb

Please sign in to comment.