add max_retries for llm
JianxinMa committed Mar 11, 2024
1 parent c1ea87a commit b02331d
Showing 10 changed files with 175 additions and 54 deletions.
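For context, a minimal usage sketch of the setting this commit adds (the model name is a placeholder and the API key is assumed to come from the environment; only generate_cfg.max_retries is introduced here). max_retries is read from generate_cfg in BaseChatModel.__init__ and controls how many times a failed model-service call is retried before ModelServiceError propagates:

from qwen_agent.llm import ModelServiceError, get_chat_model
from qwen_agent.llm.schema import Message

llm = get_chat_model({
    'model': 'qwen-turbo',  # placeholder model name
    'generate_cfg': {
        'max_retries': 3,  # popped out in BaseChatModel.__init__, never forwarded to the service
    },
})

try:
    response = llm.chat(messages=[Message('user', 'hello')], stream=False)
    print(response)
except ModelServiceError:
    pass  # raised only after retries are exhausted (or immediately for content-inspection failures)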
10 changes: 7 additions & 3 deletions qwen_agent/llm/__init__.py
@@ -2,7 +2,7 @@

from qwen_agent.llm.base import LLM_REGISTRY

from .base import BaseChatModel
from .base import BaseChatModel, ModelServiceError
from .oai import TextChatAtOAI
from .qwen_dashscope import QwenChatAtDS
from .qwenvl_dashscope import QwenVLChatAtDS
@@ -59,6 +59,10 @@ def get_chat_model(cfg: Optional[Dict] = None) -> BaseChatModel:


__all__ = [
'BaseChatModel', 'QwenChatAtDS', 'TextChatAtOAI', 'QwenVLChatAtDS',
'get_chat_model'
'BaseChatModel',
'QwenChatAtDS',
'TextChatAtOAI',
'QwenVLChatAtDS',
'get_chat_model',
'ModelServiceError',
]
125 changes: 111 additions & 14 deletions qwen_agent/llm/base.py
@@ -1,10 +1,12 @@
import copy
import random
import time
from abc import ABC, abstractmethod
from typing import Dict, Iterator, List, Optional, Union

from qwen_agent.utils.tokenization_qwen import tokenizer
from qwen_agent.utils.utils import (get_basename_from_url, has_chinese_chars,
is_image)
is_image, print_traceback)

from .schema import (ASSISTANT, DEFAULT_SYSTEM_MESSAGE, FUNCTION, SYSTEM, USER,
ContentItem, Message)
@@ -22,7 +24,18 @@ def decorator(cls):


class ModelServiceError(Exception):
pass

def __init__(self,
exception: Optional[Exception] = None,
code: Optional[str] = None,
message: Optional[str] = None):
if exception is not None:
super().__init__(exception)
else:
super().__init__(f'\nError code: {code}. Error message: {message}')
self.exception = exception
self.code = code
self.message = message


class BaseChatModel(ABC):
@@ -31,7 +44,9 @@ class BaseChatModel(ABC):
def __init__(self, cfg: Optional[Dict] = None):
cfg = cfg or {}
self.model = cfg.get('model', '')
self.generate_cfg = cfg.get('generate_cfg', {})
generate_cfg = copy.deepcopy(cfg.get('generate_cfg', {}))
self.max_retries = generate_cfg.pop('max_retries', 0)
self.generate_cfg = generate_cfg

def chat(
self,
@@ -71,21 +86,36 @@ def chat(
] + messages

messages = self._preprocess_messages(messages)

if functions:
fncall_mode = True
output = self._chat_with_functions(
messages=messages,
functions=functions,
stream=stream,
delta_stream=delta_stream,
)
else:
fncall_mode = False
output = self._chat(
messages,
stream=stream,
delta_stream=delta_stream,
)

def _call_model_service():
if fncall_mode:
return self._chat_with_functions(
messages=messages,
functions=functions,
stream=stream,
delta_stream=delta_stream,
)
else:
return self._chat(
messages,
stream=stream,
delta_stream=delta_stream,
)

if stream and delta_stream:
# No retry for delta streaming
output = _call_model_service()
elif stream and (not delta_stream):
output = retry_model_service_iterator(_call_model_service,
max_retries=self.max_retries)
else:
output = retry_model_service(_call_model_service,
max_retries=self.max_retries)

if isinstance(output, list):
output = self._postprocess_messages(output,
@@ -281,3 +311,70 @@ def _truncate_at_stop_word(text: str, stop: List[str]):
truncated = True
text = text[:k]
return truncated, text


def retry_model_service(
fn,
max_retries: int = 10,
exponential_base: float = 1.0,
):
"""Retry a function with exponential backoff"""

num_retries = 0
delay = 1.0
while True:
try:
return fn()

except ModelServiceError as e:
if max_retries <= 0: # no retry
raise e

# If harmful input or output detected, let it fail
if e.code == 'DataInspectionFailed':
raise e

print_traceback(is_error=False)

if num_retries >= max_retries:
raise ModelServiceError(exception=Exception(
f'Maximum number of retries ({max_retries}) exceeded.'))

num_retries += 1
delay *= exponential_base * (1.0 + random.random())
time.sleep(delay)


def retry_model_service_iterator(
it_fn,
max_retries: int = 10,
exponential_base: float = 1.0,
):
"""Retry an iterator with exponential backoff"""

num_retries = 0
delay = 1.0

while True:
try:
for rsp in it_fn():
yield rsp
break

except ModelServiceError as e:
if max_retries <= 0: # no retry
raise e

# If harmful input or output detected, let it fail
if e.code == 'DataInspectionFailed':
raise e

print_traceback(is_error=False)

if num_retries >= max_retries:
raise ModelServiceError(exception=Exception(
f'Maximum number of retries ({max_retries}) exceeded.'))

num_retries += 1
delay *= exponential_base * (1.0 + random.random())
time.sleep(delay)
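As a rough illustration of the retry helper added above (the flaky_call function and the 'Throttling' error code below are invented for this example), retry_model_service simply re-invokes the callable with a randomized exponential delay until it succeeds, hits a DataInspectionFailed error, or exhausts max_retries:

from qwen_agent.llm.base import ModelServiceError, retry_model_service

attempts = 0

def flaky_call():  # hypothetical stand-in for a single model-service request
    global attempts
    attempts += 1
    if attempts < 3:
        raise ModelServiceError(code='Throttling', message='please retry later')
    return 'ok'

print(retry_model_service(flaky_call, max_retries=5))  # prints 'ok' after two retried failures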
6 changes: 3 additions & 3 deletions qwen_agent/llm/oai.py
@@ -71,7 +71,7 @@ def _chat_stream(
Message(ASSISTANT, chunk.choices[0].delta.content)
]
except Exception as ex:
raise ModelServiceError(ex)
raise ModelServiceError(exception=ex)
else:
full_response = ''
for chunk in response:
@@ -80,7 +80,7 @@
try:
full_response += chunk.choices[0].delta.content
except Exception as ex:
raise ModelServiceError(ex)
raise ModelServiceError(exception=ex)
yield [Message(ASSISTANT, full_response)]

def _chat_no_stream(self, messages: List[Message]) -> List[Message]:
@@ -93,4 +93,4 @@ def _chat_no_stream(self, messages: List[Message]) -> List[Message]:
try:
return [Message(ASSISTANT, response.choices[0].message.content)]
except Exception as ex:
raise ModelServiceError(ex)
raise ModelServiceError(exception=ex)
22 changes: 6 additions & 16 deletions qwen_agent/llm/qwen_dashscope.py
@@ -61,11 +61,8 @@ def _chat_no_stream(
Message(ASSISTANT, response.output.choices[0].message.content)
]
else:
err = '\nError code: %s, error message: %s' % (
response.code,
response.message,
)
raise ModelServiceError(err)
raise ModelServiceError(code=response.code,
message=response.message)

def _chat_with_functions(
self,
@@ -102,11 +99,8 @@ def _text_completion_no_stream(
Message(ASSISTANT, response.output.choices[0].message.content)
]
else:
err = '\nError code: %s, error message: %s' % (
response.code,
response.message,
)
raise ModelServiceError(err)
raise ModelServiceError(code=response.code,
message=response.message)

def _text_completion_stream(
self,
@@ -169,9 +163,7 @@ def _delta_stream_output(response) -> Iterator[List[Message]]:
yield [Message(ASSISTANT, now_rsp)]
last_len = len(real_text)
else:
err = '\nError code: %s. Error message: %s' % (trunk.code,
trunk.message)
raise ModelServiceError(err)
raise ModelServiceError(code=trunk.code, message=trunk.message)
if text and (in_delay or (last_len != len(text))):
yield [Message(ASSISTANT, text[last_len:])]

@@ -183,6 +175,4 @@ def _full_stream_output(response) -> Iterator[List[Message]]:
Message(ASSISTANT, trunk.output.choices[0].message.content)
]
else:
err = '\nError code: %s. Error message: %s' % (trunk.code,
trunk.message)
raise ModelServiceError(err)
raise ModelServiceError(code=trunk.code, message=trunk.message)
13 changes: 3 additions & 10 deletions qwen_agent/llm/qwenvl_dashscope.py
@@ -48,11 +48,7 @@ def _chat_stream(
if trunk.status_code == HTTPStatus.OK:
yield _extract_vl_response(trunk)
else:
err = '\nError code: %s. Error message: %s' % (
trunk.code,
trunk.message,
)
raise ModelServiceError(err)
raise ModelServiceError(code=trunk.code, message=trunk.message)

def _chat_no_stream(
self,
@@ -69,11 +65,8 @@
if response.status_code == HTTPStatus.OK:
return _extract_vl_response(response=response)
else:
err = '\nError code: %s, error message: %s' % (
response.code,
response.message,
)
raise ModelServiceError(err)
raise ModelServiceError(code=response.code,
message=response.message)

def _postprocess_messages(self, messages: List[Message],
fncall_mode: bool) -> List[Message]:
15 changes: 11 additions & 4 deletions qwen_agent/tools/base.py
@@ -3,16 +3,23 @@

import json5

from qwen_agent.utils.utils import logger

TOOL_REGISTRY = {}


def register_tool(name):
def register_tool(name, allow_overwrite=False):

def decorator(cls):
if name in TOOL_REGISTRY:
raise ValueError(
f'tool {name} has a duplicate name! Please ensure that the tool name is unique.'
)
if allow_overwrite:
logger.warning(
f'Tool `{name}` already exists! Overwriting with class {cls}.'
)
else:
raise ValueError(
f'Tool `{name}` already exists! Please ensure that the tool name is unique.'
)
cls.name = name
TOOL_REGISTRY[name] = cls

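A small sketch of the new allow_overwrite flag (the tool class below is made up for illustration and assumes the usual BaseTool interface from qwen_agent.tools.base):

from qwen_agent.tools.base import BaseTool, register_tool

@register_tool('my_search')
class MySearch(BaseTool):
    description = 'demo search tool'
    parameters = []

    def call(self, params: str, **kwargs) -> str:
        return 'results'

# Without allow_overwrite=True, registering the same name again raises ValueError;
# with it, the registry entry is replaced and a warning is logged instead.
@register_tool('my_search', allow_overwrite=True)
class MySearchV2(MySearch):
    description = 'demo search tool, v2'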
7 changes: 5 additions & 2 deletions qwen_agent/utils/utils.py
@@ -38,8 +38,11 @@ def hash_sha256(key):
return key


def print_traceback():
logger.error(''.join(traceback.format_exception(*sys.exc_info())))
def print_traceback(is_error=True):
if is_error:
logger.error(''.join(traceback.format_exception(*sys.exc_info())))
else:
logger.warning(''.join(traceback.format_exception(*sys.exc_info())))


def has_chinese_chars(data) -> bool:
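For completeness, a tiny illustration of the new is_error switch (purely illustrative):

from qwen_agent.utils.utils import print_traceback

try:
    1 / 0
except ZeroDivisionError:
    print_traceback(is_error=False)  # logged at warning level, e.g. for failures that will be retried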
5 changes: 4 additions & 1 deletion tests/agents/test_assistant.py
@@ -26,7 +26,10 @@ def test_assistant_files():
messages = [
Message('user', [
ContentItem(text='总结一个文章标题'),
ContentItem(file='https://github.com/QwenLM/Qwen-Agent')
ContentItem(
file=
'https://help.aliyun.com/zh/dashscope/developer-reference/api-details?disableWebsiteRedirect=true'
)
])
]

25 changes: 24 additions & 1 deletion tests/llm/test_dashscope.py
@@ -1,6 +1,6 @@
import pytest

from qwen_agent.llm import get_chat_model
from qwen_agent.llm import ModelServiceError, get_chat_model
from qwen_agent.llm.schema import Message

functions = [{
@@ -80,3 +80,26 @@ def test_llm_dashscope(functions, stream, delta_stream):
assert response[-1].function_call.name == 'image_gen'
else:
assert response[-1].function_call is None


@pytest.mark.parametrize('stream', [True, False])
@pytest.mark.parametrize('delta_stream', [True, False])
def test_llm_retry_failure(stream, delta_stream):
llm_cfg = {
'model': 'qwen-turbo',
'api_key': 'invalid',
'generate_cfg': {
'max_retries': 3
}
}

llm = get_chat_model(llm_cfg)
assert llm.max_retries == 3

messages = [Message('user', 'hello')]
with pytest.raises(ModelServiceError):
response = llm.chat(messages=messages,
stream=stream,
delta_stream=delta_stream)
if stream:
list(response)
1 change: 1 addition & 0 deletions tests/llm/test_oai.py
@@ -40,6 +40,7 @@ def test_llm_oai(functions, stream, delta_stream):
}

llm = get_chat_model(llm_cfg)
assert llm.max_retries == 0

messages = [Message('user', 'draw a cute cat')]
response = llm.chat(messages=messages,
