Skip to content

Commit

Permalink
bugfix: concat messages for function calling
Browse files Browse the repository at this point in the history
  • Loading branch information
JianxinMa committed Feb 8, 2024
1 parent 7ce8016 commit c9dfe77
Show file tree
Hide file tree
Showing 4 changed files with 29 additions and 9 deletions.
9 changes: 7 additions & 2 deletions examples/function_calling.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
# Reference: https://platform.openai.com/docs/guides/function-calling
import json
import os

from qwen_agent.llm import get_chat_model

Expand Down Expand Up @@ -35,8 +36,12 @@ def run_conversation():
# Use the model service provided by DashScope:
'model': 'qwen-max',
'model_server': 'dashscope',
# 'api_key': 'YOUR_DASHSCOPE_API_KEY',
# It will use the `DASHSCOPE_API_KEY' environment variable if 'api_key' is not set.
'api_key': os.getenv('DASHSCOPE_API_KEY'),

# Use the model service provided by Together.AI:
# 'model': 'Qwen/Qwen1.5-14B-Chat',
# 'model_server': 'https://api.together.xyz', # api_base
# 'api_key': os.getenv('TOGETHER_API_KEY'),

# Use your own model service compatible with OpenAI API:
# 'model': 'Qwen/Qwen1.5-72B-Chat',
Expand Down
15 changes: 12 additions & 3 deletions qwen_agent/llm/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,10 +111,19 @@ def _chat_with_functions(
messages = self._prepend_fn_call_system(messages, functions)
messages = self._preprocess_messages(messages)

if messages and messages[-1][ROLE] == ASSISTANT:
if messages and messages[-1].role == ASSISTANT:
# Change the text completion to chat mode
assert len(messages) > 1 and messages[-2][ROLE] == USER
messages[-2][CONTENT] += '\n\n' + messages[-1][CONTENT]
assert len(messages) > 1 and messages[-2].role == USER
assert messages[-1].function_call is None
usr = messages[-2].content
bot = messages[-1].content
if isinstance(usr, str) and isinstance(bot, str):
usr = usr + '\n\n' + bot
elif isinstance(usr, list) and isinstance(bot, list):
usr = usr + [ContentItem(text='\n\n')] + bot
else:
raise NotImplementedError
messages[-2].content = usr
messages.pop()

logger.debug('==== Using chat format for function call===')
Expand Down
4 changes: 1 addition & 3 deletions qwen_agent/llm/oai.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,16 +38,14 @@ def __init__(self, cfg: Optional[Dict] = None):
openai.api_key = api_key
self._chat_complete_create = openai.ChatCompletion.create
else:
from openai import OpenAI

api_kwargs = {}
if api_base:
api_kwargs['base_url'] = api_base
if api_key:
api_kwargs['api_key'] = api_key

def _chat_complete_create(*args, **kwargs):
client = OpenAI(**api_kwargs)
client = openai.OpenAI(**api_kwargs)
return client.chat.completions.create(*args, **kwargs)

self._chat_complete_create = _chat_complete_create
Expand Down
10 changes: 9 additions & 1 deletion qwen_agent/log.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,15 @@
import logging
import os


def setup_logger(level=logging.INFO):
def setup_logger(level=None):

if level is None:
if int(os.getenv('QWEN_AGENT_DEBUG', '0').strip()):
level = logging.DEBUG
else:
level = logging.INFO

logger = logging.getLogger('qwen_agent_logger')
logger.setLevel(level)
handler = logging.StreamHandler()
Expand Down

0 comments on commit c9dfe77

Please sign in to comment.