diff --git a/examples/realtime/app/server.py b/examples/realtime/app/server.py index 26c544dd2..6d813ef30 100644 --- a/examples/realtime/app/server.py +++ b/examples/realtime/app/server.py @@ -43,7 +43,13 @@ async def connect(self, websocket: WebSocket, session_id: str): agent = get_starting_agent() runner = RealtimeRunner(agent) - session_context = await runner.run() + session_context = await runner.run( + model_config={ + "initial_model_settings": { + "turn_detection": {"type": "server_vad", "idle_timeout_ms": 5000} + } + } + ) session = await session_context.__aenter__() self.active_sessions[session_id] = session self.session_contexts[session_id] = session_context diff --git a/pyproject.toml b/pyproject.toml index fa253a2eb..25d950b34 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "openai-agents" -version = "0.2.9" +version = "0.2.10" description = "OpenAI Agents SDK" readme = "README.md" requires-python = ">=3.9" diff --git a/src/agents/__init__.py b/src/agents/__init__.py index 02830bb29..f7f745690 100644 --- a/src/agents/__init__.py +++ b/src/agents/__init__.py @@ -33,6 +33,15 @@ output_guardrail, ) from .handoffs import Handoff, HandoffInputData, HandoffInputFilter, handoff +from .tool_guardrails import ( + ToolGuardrailFunctionOutput, + ToolInputGuardrail, + ToolInputGuardrailData, + ToolOutputGuardrail, + ToolOutputGuardrailData, + tool_input_guardrail, + tool_output_guardrail, +) from .items import ( HandoffCallItem, HandoffOutputItem, @@ -204,6 +213,13 @@ def enable_verbose_stdout_logging(): "GuardrailFunctionOutput", "input_guardrail", "output_guardrail", + "ToolInputGuardrail", + "ToolOutputGuardrail", + "ToolGuardrailFunctionOutput", + "ToolInputGuardrailData", + "ToolOutputGuardrailData", + "tool_input_guardrail", + "tool_output_guardrail", "handoff", "Handoff", "HandoffInputData", diff --git a/src/agents/_run_impl.py b/src/agents/_run_impl.py index 56784004c..fd700021e 100644 --- a/src/agents/_run_impl.py +++ b/src/agents/_run_impl.py @@ -80,6 +80,10 @@ Tool, ) from .tool_context import ToolContext +from .tool_guardrails import ( + ToolInputGuardrailData, + ToolOutputGuardrailData, +) from .tracing import ( SpanError, Trace, @@ -572,24 +576,64 @@ async def run_single_tool( if config.trace_include_sensitive_data: span_fn.span_data.input = tool_call.arguments try: - _, _, result = await asyncio.gather( - hooks.on_tool_start(tool_context, agent, func_tool), - ( - agent.hooks.on_tool_start(tool_context, agent, func_tool) - if agent.hooks - else _coro.noop_coroutine() - ), - func_tool.on_invoke_tool(tool_context, tool_call.arguments), - ) + # 1) Run input tool guardrails, if any + final_result: Any | None = None + if func_tool.tool_input_guardrails: + for guardrail in func_tool.tool_input_guardrails: + gr_out = await guardrail.run( + ToolInputGuardrailData( + context=tool_context, + agent=agent, + tool_call=tool_call, + ) + ) + if gr_out.tripwire_triggered: + # Use the provided model message as the tool output + final_result = str(gr_out.model_message or "") + break + + if final_result is None: + # 2) Actually run the tool + await asyncio.gather( + hooks.on_tool_start(tool_context, agent, func_tool), + ( + agent.hooks.on_tool_start(tool_context, agent, func_tool) + if agent.hooks + else _coro.noop_coroutine() + ), + ) + real_result = await func_tool.on_invoke_tool( + tool_context, tool_call.arguments + ) - await asyncio.gather( - hooks.on_tool_end(tool_context, agent, func_tool, result), - ( - agent.hooks.on_tool_end(tool_context, agent, func_tool, result) - if agent.hooks - else _coro.noop_coroutine() - ), - ) + # 3) Run output tool guardrails, if any + final_result = real_result + if func_tool.tool_output_guardrails: + for guardrail in func_tool.tool_output_guardrails: + gr_out = await guardrail.run( + ToolOutputGuardrailData( + context=tool_context, + agent=agent, + tool_call=tool_call, + output=real_result, + ) + ) + if gr_out.tripwire_triggered: + final_result = str(gr_out.model_message or "") + break + + # 4) Tool end hooks (with final result, which may have been overridden) + await asyncio.gather( + hooks.on_tool_end(tool_context, agent, func_tool, final_result), + ( + agent.hooks.on_tool_end( + tool_context, agent, func_tool, final_result + ) + if agent.hooks + else _coro.noop_coroutine() + ), + ) + result = final_result except Exception as e: _error_tracing.attach_error_to_current_span( SpanError( diff --git a/src/agents/realtime/openai_realtime.py b/src/agents/realtime/openai_realtime.py index 766c49f8d..48ff74049 100644 --- a/src/agents/realtime/openai_realtime.py +++ b/src/agents/realtime/openai_realtime.py @@ -297,7 +297,11 @@ async def _send_raw_message(self, event: OpenAIRealtimeClientEvent) -> None: """Send a raw message to the model.""" assert self._websocket is not None, "Not connected" - await self._websocket.send(event.model_dump_json(exclude_none=True, exclude_unset=True)) + json_str = event.model_dump_json(exclude_none=True, exclude_unset=True) + + logger.debug(f"ZZZZZ Sending raw message of type {event.type}. Length: {len(json_str)}") + + await self._websocket.send(json_str) async def _send_user_input(self, event: RealtimeModelSendUserInput) -> None: converted = _ConversionHelper.convert_user_input_to_item_create(event) @@ -306,6 +310,13 @@ async def _send_user_input(self, event: RealtimeModelSendUserInput) -> None: async def _send_audio(self, event: RealtimeModelSendAudio) -> None: converted = _ConversionHelper.convert_audio_to_input_audio_buffer_append(event) + input_audio_len = len(event.audio) + b64_audio_len = len(base64.b64encode(event.audio).decode("utf-8")) + logger.debug( + f"ZZZZZ Sending audio of length {input_audio_len}. " + f"Base64 encoded length: {b64_audio_len}" + ) + await self._send_raw_message(converted) if event.commit: await self._send_raw_message( diff --git a/src/agents/tool.py b/src/agents/tool.py index 4624fbb52..bcbbce98d 100644 --- a/src/agents/tool.py +++ b/src/agents/tool.py @@ -94,6 +94,13 @@ class FunctionTool: and returns whether the tool is enabled. You can use this to dynamically enable/disable a tool based on your context/state.""" + # Tool-specific guardrails + tool_input_guardrails: list["ToolInputGuardrail[Any]"] | None = None + """Optional list of input guardrails to run before invoking this tool.""" + + tool_output_guardrails: list["ToolOutputGuardrail[Any]"] | None = None + """Optional list of output guardrails to run after invoking this tool.""" + def __post_init__(self): if self.strict_json_schema: self.params_json_schema = ensure_strict_json_schema(self.params_json_schema) diff --git a/src/agents/tool_guardrails.py b/src/agents/tool_guardrails.py new file mode 100644 index 000000000..ed932065b --- /dev/null +++ b/src/agents/tool_guardrails.py @@ -0,0 +1,163 @@ +from __future__ import annotations + +import inspect +from collections.abc import Awaitable +from dataclasses import dataclass +from typing import Any, Callable, Generic, Optional, overload + +from typing_extensions import TypeVar + +from .agent import Agent +from .tool_context import ToolContext +from .util._types import MaybeAwaitable +from openai.types.responses import ResponseFunctionToolCall + + +@dataclass +class ToolGuardrailFunctionOutput: + """The output of a tool guardrail function. + + - `output_info`: Optional data about checks performed. + - `tripwire_triggered`: Whether the guardrail was tripped. + - `model_message`: Message to send back to the model as the tool output if tripped. + """ + + output_info: Any + tripwire_triggered: bool + model_message: Optional[str] = None + + +@dataclass +class ToolInputGuardrailData: + """Input data passed to a tool input guardrail function.""" + + context: ToolContext[Any] + agent: Agent[Any] + tool_call: ResponseFunctionToolCall + + +@dataclass +class ToolOutputGuardrailData(ToolInputGuardrailData): + """Input data passed to a tool output guardrail function. + + Extends input data with the tool's output. + """ + + output: Any + + +TContext_co = TypeVar("TContext_co", bound=Any, covariant=True) + + +@dataclass +class ToolInputGuardrail(Generic[TContext_co]): + """A guardrail that runs before a function tool is invoked.""" + + guardrail_function: Callable[[ToolInputGuardrailData], MaybeAwaitable[ToolGuardrailFunctionOutput]] + name: str | None = None + + def get_name(self) -> str: + return self.name or self.guardrail_function.__name__ + + async def run( + self, data: ToolInputGuardrailData + ) -> ToolGuardrailFunctionOutput: + result = self.guardrail_function(data) + if inspect.isawaitable(result): + return await result # type: ignore[return-value] + return result # type: ignore[return-value] + + +@dataclass +class ToolOutputGuardrail(Generic[TContext_co]): + """A guardrail that runs after a function tool is invoked.""" + + guardrail_function: Callable[[ToolOutputGuardrailData], MaybeAwaitable[ToolGuardrailFunctionOutput]] + name: str | None = None + + def get_name(self) -> str: + return self.name or self.guardrail_function.__name__ + + async def run( + self, data: ToolOutputGuardrailData + ) -> ToolGuardrailFunctionOutput: + result = self.guardrail_function(data) + if inspect.isawaitable(result): + return await result # type: ignore[return-value] + return result # type: ignore[return-value] + + +# Decorators +_ToolInputFuncSync = Callable[[ToolInputGuardrailData], ToolGuardrailFunctionOutput] +_ToolInputFuncAsync = Callable[[ToolInputGuardrailData], Awaitable[ToolGuardrailFunctionOutput]] + + +@overload +def tool_input_guardrail(func: _ToolInputFuncSync): # type: ignore[overload-overlap] + ... + + +@overload +def tool_input_guardrail(func: _ToolInputFuncAsync): # type: ignore[overload-overlap] + ... + + +@overload +def tool_input_guardrail(*, name: str | None = None) -> Callable[[ + _ToolInputFuncSync | _ToolInputFuncAsync +], ToolInputGuardrail[Any]]: ... + + +def tool_input_guardrail( + func: _ToolInputFuncSync | _ToolInputFuncAsync | None = None, + *, + name: str | None = None, +) -> ToolInputGuardrail[Any] | Callable[[ + _ToolInputFuncSync | _ToolInputFuncAsync +], ToolInputGuardrail[Any]]: + """Decorator to create a ToolInputGuardrail from a function.""" + + def decorator(f: _ToolInputFuncSync | _ToolInputFuncAsync) -> ToolInputGuardrail[Any]: + return ToolInputGuardrail(guardrail_function=f, name=name or f.__name__) + + if func is not None: + return decorator(func) + return decorator + + +_ToolOutputFuncSync = Callable[[ToolOutputGuardrailData], ToolGuardrailFunctionOutput] +_ToolOutputFuncAsync = Callable[[ToolOutputGuardrailData], Awaitable[ToolGuardrailFunctionOutput]] + + +@overload +def tool_output_guardrail(func: _ToolOutputFuncSync): # type: ignore[overload-overlap] + ... + + +@overload +def tool_output_guardrail(func: _ToolOutputFuncAsync): # type: ignore[overload-overlap] + ... + + +@overload +def tool_output_guardrail(*, name: str | None = None) -> Callable[[ + _ToolOutputFuncSync | _ToolOutputFuncAsync +], ToolOutputGuardrail[Any]]: ... + + +def tool_output_guardrail( + func: _ToolOutputFuncSync | _ToolOutputFuncAsync | None = None, + *, + name: str | None = None, +) -> ToolOutputGuardrail[Any] | Callable[[ + _ToolOutputFuncSync | _ToolOutputFuncAsync +], ToolOutputGuardrail[Any]]: + """Decorator to create a ToolOutputGuardrail from a function.""" + + def decorator(f: _ToolOutputFuncSync | _ToolOutputFuncAsync) -> ToolOutputGuardrail[Any]: + return ToolOutputGuardrail(guardrail_function=f, name=name or f.__name__) + + if func is not None: + return decorator(func) + return decorator + diff --git a/tests/test_tool_guardrails.py b/tests/test_tool_guardrails.py new file mode 100644 index 000000000..fa33df270 --- /dev/null +++ b/tests/test_tool_guardrails.py @@ -0,0 +1,155 @@ +from __future__ import annotations + +import json +from typing import Any + +import pytest + +from agents import ( + Agent, + FunctionTool, + RunContextWrapper, + Runner, + ToolGuardrailFunctionOutput, + tool_input_guardrail, + tool_output_guardrail, +) +from agents.tool import function_tool +from agents.items import ToolCallOutputItem + +from .fake_model import FakeModel +from .test_responses import get_function_tool, get_function_tool_call, get_text_message + + +@pytest.mark.asyncio +async def test_tool_input_guardrail_blocks_and_uses_message(): + executed = {"called": False} + + @function_tool(name_override="guarded_tool") + def guarded_tool() -> str: + executed["called"] = True + return "real_result" + + @tool_input_guardrail + def input_gr(data) -> ToolGuardrailFunctionOutput: + # Always block with model message + return ToolGuardrailFunctionOutput( + output_info=None, + tripwire_triggered=True, + model_message="blocked_by_input_guardrail", + ) + + # Attach to FunctionTool + assert isinstance(guarded_tool, FunctionTool) + guarded_tool.tool_input_guardrails = [input_gr] + + model = FakeModel() + agent = Agent(name="test", model=model, tools=[guarded_tool]) + + model.add_multiple_turn_outputs( + [ + [get_function_tool_call("guarded_tool", json.dumps({}))], + [get_text_message("done")], + ] + ) + + result = await Runner.run(agent, input="start") + + # Tool should not run + assert executed["called"] is False + # Tool output item should contain the guardrail model message + tool_outputs = [it for it in result.new_items if isinstance(it, ToolCallOutputItem)] + assert len(tool_outputs) == 1 + # raw_item.output is the string sent to the model + assert tool_outputs[0].raw_item["output"] == "blocked_by_input_guardrail" + assert result.final_output == "done" + + +@pytest.mark.asyncio +async def test_tool_output_guardrail_replaces_result(): + executed = {"called": False} + + @function_tool(name_override="guarded_tool_out") + def guarded_tool_out() -> str: + executed["called"] = True + return "real_output" + + @tool_output_guardrail + def output_gr(data) -> ToolGuardrailFunctionOutput: + # Replace result + return ToolGuardrailFunctionOutput( + output_info=None, + tripwire_triggered=True, + model_message="overridden_by_output_guardrail", + ) + + assert isinstance(guarded_tool_out, FunctionTool) + guarded_tool_out.tool_output_guardrails = [output_gr] + + model = FakeModel() + agent = Agent(name="test", model=model, tools=[guarded_tool_out]) + + model.add_multiple_turn_outputs( + [ + [get_function_tool_call("guarded_tool_out", json.dumps({}))], + [get_text_message("done")], + ] + ) + + result = await Runner.run(agent, input="go") + + assert executed["called"] is True + tool_outputs = [it for it in result.new_items if isinstance(it, ToolCallOutputItem)] + assert len(tool_outputs) == 1 + assert tool_outputs[0].raw_item["output"] == "overridden_by_output_guardrail" + assert result.final_output == "done" + + +@pytest.mark.asyncio +async def test_input_guardrail_takes_precedence_over_output_guardrail(): + executed = {"called": False} + + @function_tool(name_override="both_guarded") + def both_guarded() -> str: + executed["called"] = True + return "should_not_matter" + + @tool_input_guardrail + def input_gr(data) -> ToolGuardrailFunctionOutput: + return ToolGuardrailFunctionOutput( + output_info=None, + tripwire_triggered=True, + model_message="input_wins", + ) + + @tool_output_guardrail + def output_gr(data) -> ToolGuardrailFunctionOutput: + return ToolGuardrailFunctionOutput( + output_info=None, + tripwire_triggered=True, + model_message="output_would_win_if_reached", + ) + + assert isinstance(both_guarded, FunctionTool) + both_guarded.tool_input_guardrails = [input_gr] + both_guarded.tool_output_guardrails = [output_gr] + + model = FakeModel() + agent = Agent(name="test", model=model, tools=[both_guarded]) + + model.add_multiple_turn_outputs( + [ + [get_function_tool_call("both_guarded", json.dumps({}))], + [get_text_message("done")], + ] + ) + + result = await Runner.run(agent, input="go") + + # Input guardrail should prevent tool from running + assert executed["called"] is False + tool_outputs = [it for it in result.new_items if isinstance(it, ToolCallOutputItem)] + assert len(tool_outputs) == 1 + assert tool_outputs[0].raw_item["output"] == "input_wins" + assert result.final_output == "done" + diff --git a/uv.lock b/uv.lock index 6827baac9..12a50a794 100644 --- a/uv.lock +++ b/uv.lock @@ -1,5 +1,5 @@ version = 1 -revision = 3 +revision = 2 requires-python = ">=3.9" resolution-markers = [ "python_full_version >= '3.11'", @@ -1816,7 +1816,7 @@ wheels = [ [[package]] name = "openai-agents" -version = "0.2.9" +version = "0.2.10" source = { editable = "." } dependencies = [ { name = "griffe" },