Skip to content

Commit

Permalink
Fix intervention handler none check (microsoft#4351)
Browse files Browse the repository at this point in the history
* Fix none check

* fix func

* fmt, lint

---------

Co-authored-by: Eric Zhu <[email protected]>
  • Loading branch information
jackgerrits and ekzhu authored Nov 25, 2024
1 parent 9b967fc commit 8347881
Show file tree
Hide file tree
Showing 3 changed files with 29 additions and 23 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -147,6 +147,23 @@ def _stop_when_idle(self) -> bool:
return self._run_state == RunContext.RunState.UNTIL_IDLE and self._runtime.idle


def _warn_if_none(value: Any, handler_name: str) -> None:
"""
Utility function to check if the intervention handler returned None and issue a warning.
Args:
value: The return value to check
handler_name: Name of the intervention handler method for the warning message
"""
if value is None:
warnings.warn(
f"Intervention handler {handler_name} returned None. This might be unintentional. "
"Consider returning the original message or DropMessage explicitly.",
RuntimeWarning,
stacklevel=2,
)


class SingleThreadedAgentRuntime(AgentRuntime):
def __init__(
self,
Expand Down Expand Up @@ -433,6 +450,7 @@ async def process_next(self) -> None:
):
try:
temp_message = await handler.on_send(message, sender=sender, recipient=recipient)
_warn_if_none(temp_message, "on_send")
except BaseException as e:
future.set_exception(e)
return
Expand All @@ -456,6 +474,7 @@ async def process_next(self) -> None:
):
try:
temp_message = await handler.on_publish(message, sender=sender)
_warn_if_none(temp_message, "on_publish")
except BaseException as e:
# TODO: we should raise the intervention exception to the publisher.
logger.error(f"Exception raised in in intervention handler: {e}", exc_info=True)
Expand All @@ -474,6 +493,7 @@ async def process_next(self) -> None:
for handler in self._intervention_handlers:
try:
temp_message = await handler.on_response(message, sender=sender, recipient=recipient)
_warn_if_none(temp_message, "on_response")
except BaseException as e:
# TODO: should we raise the exception to sender of the response instead?
future.set_exception(e)
Expand Down
30 changes: 9 additions & 21 deletions python/packages/autogen-core/src/autogen_core/base/intervention.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
import warnings
from typing import Any, Awaitable, Callable, Protocol, final

from autogen_core.base import AgentId
Expand All @@ -15,27 +14,15 @@
class DropMessage: ...


def _warn_if_none(value: Any, handler_name: str) -> None:
"""
Utility function to check if the intervention handler returned None and issue a warning.
Args:
value: The return value to check
handler_name: Name of the intervention handler method for the warning message
"""
if value is None:
warnings.warn(
f"Intervention handler {handler_name} returned None. This might be unintentional. "
"Consider returning the original message or DropMessage explicitly.",
RuntimeWarning,
stacklevel=2,
)


InterventionFunction = Callable[[Any], Any | Awaitable[type[DropMessage]]]


class InterventionHandler(Protocol):
"""An intervention handler is a class that can be used to modify, log or drop messages that are being processed by the :class:`autogen_core.base.AgentRuntime`.
Note: Returning None from any of the intervention handler methods will result in a warning being issued and treated as "no change". If you intend to drop a message, you should return :class:`DropMessage` explicitly.
"""

async def on_send(self, message: Any, *, sender: AgentId | None, recipient: AgentId) -> Any | type[DropMessage]: ...
async def on_publish(self, message: Any, *, sender: AgentId | None) -> Any | type[DropMessage]: ...
async def on_response(
Expand All @@ -44,14 +31,15 @@ async def on_response(


class DefaultInterventionHandler(InterventionHandler):
"""Simple class that provides a default implementation for all intervention
handler methods, that simply returns the message unchanged. Allows for easy
subclassing to override only the desired methods."""

async def on_send(self, message: Any, *, sender: AgentId | None, recipient: AgentId) -> Any | type[DropMessage]:
_warn_if_none(message, "on_send")
return message

async def on_publish(self, message: Any, *, sender: AgentId | None) -> Any | type[DropMessage]:
_warn_if_none(message, "on_publish")
return message

async def on_response(self, message: Any, *, sender: AgentId, recipient: AgentId | None) -> Any | type[DropMessage]:
_warn_if_none(message, "on_response")
return message
2 changes: 0 additions & 2 deletions python/packages/autogen-core/tests/test_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,10 @@
from typing import Annotated, List

import pytest

from autogen_core.base import CancellationToken
from autogen_core.components._function_utils import get_typed_signature
from autogen_core.components.tools import BaseTool, FunctionTool
from autogen_core.components.tools._base import ToolSchema

from pydantic import BaseModel, Field, model_serializer
from pydantic_core import PydanticUndefined

Expand Down

0 comments on commit 8347881

Please sign in to comment.