Skip to content

Commit

Permalink
fix: Enhance OpenAI client to handle additional stop reasons and impr…
Browse files Browse the repository at this point in the history
…ove tool call validation in tests to address empty tool_calls list. (microsoft#5223)

Resolves microsoft#5222
  • Loading branch information
ekzhu authored Jan 27, 2025
1 parent 8428462 commit b441d5b
Show file tree
Hide file tree
Showing 2 changed files with 29 additions and 1 deletion.
Original file line number Diff line number Diff line change
Expand Up @@ -336,6 +336,10 @@ def normalize_stop_reason(stop_reason: str | None) -> FinishReasons:
stop_reason = stop_reason.lower()

KNOWN_STOP_MAPPINGS: Dict[str, FinishReasons] = {
"stop": "stop",
"length": "length",
"content_filter": "content_filter",
"function_calls": "function_calls",
"end_turn": "stop",
"tool_calls": "function_calls",
}
Expand Down Expand Up @@ -552,7 +556,7 @@ async def create(
content: Union[str, List[FunctionCall]]
if choice.message.function_call is not None:
raise ValueError("function_call is deprecated and is not supported by this model client.")
elif choice.message.tool_calls is not None:
elif choice.message.tool_calls is not None and len(choice.message.tool_calls) > 0:
if choice.finish_reason != "tool_calls":
warnings.warn(
f"Finish reason mismatch: {choice.finish_reason} != tool_calls "
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -600,6 +600,25 @@ async def test_tool_calling(monkeypatch: pytest.MonkeyPatch) -> None:
object="chat.completion",
usage=CompletionUsage(prompt_tokens=10, completion_tokens=5, total_tokens=0),
),
# Should not be returning tool calls when the tool_calls are empty
ChatCompletion(
id="id5",
choices=[
Choice(
finish_reason="stop",
index=0,
message=ChatCompletionMessage(
content="I should make a tool call.",
tool_calls=[],
role="assistant",
),
)
],
created=0,
model=model,
object="chat.completion",
usage=CompletionUsage(prompt_tokens=10, completion_tokens=5, total_tokens=0),
),
]
mock = _MockChatCompletion(chat_completions)
monkeypatch.setattr(AsyncCompletions, "create", mock.mock_create)
Expand Down Expand Up @@ -652,6 +671,11 @@ async def test_tool_calling(monkeypatch: pytest.MonkeyPatch) -> None:
assert create_result.content == [FunctionCall(id="1", arguments=r'{"input": "task"}', name="_pass_function")]
assert create_result.finish_reason == "function_calls"

# Should not be returning tool calls when the tool_calls are empty
create_result = await model_client.create(messages=[UserMessage(content="Hello", source="user")], tools=[pass_tool])
assert create_result.content == "I should make a tool call."
assert create_result.finish_reason == "stop"


async def _test_model_client(model_client: OpenAIChatCompletionClient) -> None:
# Test basic completion
Expand Down

0 comments on commit b441d5b

Please sign in to comment.