Skip to content

Commit

Permalink
Add test for context window truncation in agent controller (All-Hands…
Browse files Browse the repository at this point in the history
…-AI#6477)

Co-authored-by: Calvin Smith <[email protected]>
  • Loading branch information
csmith49 and Calvin Smith authored Jan 27, 2025
1 parent 5b53dbd commit 23348af
Showing 1 changed file with 47 additions and 0 deletions.
47 changes: 47 additions & 0 deletions tests/unit/test_agent_controller.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from uuid import uuid4

import pytest
from litellm import ContextWindowExceededError

from openhands.controller.agent import Agent
from openhands.controller.agent_controller import AgentController
Expand Down Expand Up @@ -552,3 +553,49 @@ def on_event(event: Event):
assert (
state.metrics.accumulated_cost == 10.0 * 3
), f'Expected accumulated cost to be 30.0, but got {state.metrics.accumulated_cost}'


@pytest.mark.asyncio
async def test_context_window_exceeded_error_handling(mock_agent, mock_event_stream):
"""Test that context window exceeded errors are handled correctly by truncating history."""

class StepState:
def __init__(self):
self.has_errored = False

def step(self, state: State):
# Append a few messages to the history -- these will be truncated when we throw the error
state.history = [
MessageAction(content='Test message 0'),
MessageAction(content='Test message 1'),
]

error = ContextWindowExceededError(
message='prompt is too long: 233885 tokens > 200000 maximum',
model='',
llm_provider='',
)
self.has_errored = True
raise error

state = StepState()
mock_agent.step = state.step

controller = AgentController(
agent=mock_agent,
event_stream=mock_event_stream,
max_iterations=10,
sid='test',
confirmation_mode=False,
headless_mode=True,
)

# Set the agent running and take a step in the controller -- this is similar
# to taking a single step using `run_controller`, but much easier to control
# termination for testing purposes
controller.state.agent_state = AgentState.RUNNING
await controller._step()

# Check that the error was thrown and the history has been truncated
assert state.has_errored
assert controller.state.history == [MessageAction(content='Test message 1')]

0 comments on commit 23348af

Please sign in to comment.