Skip to content

Commit

Permalink
Revamp max iteration Logic (crewAIInc#111)
Browse files Browse the repository at this point in the history
This now will allow to add a max_inter option to agents while also making sure to force the agent to give it's best final answer before running out of it's max_inter.
  • Loading branch information
joaomdmoura authored Jan 11, 2024
1 parent 8cc51d5 commit ea7759b
Show file tree
Hide file tree
Showing 6 changed files with 1,035 additions and 17 deletions.
12 changes: 6 additions & 6 deletions poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

5 changes: 5 additions & 0 deletions src/crewai/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ class Agent(BaseModel):
goal: The objective of the agent.
backstory: The backstory of the agent.
llm: The language model that will run the agent.
max_iter: Maximum number of iterations for an agent to execute a task.
memory: Whether the agent should have memory or not.
verbose: Whether the agent execution should be in verbose mode.
allow_delegation: Whether the agent is allowed to delegate tasks to other agents.
Expand Down Expand Up @@ -72,6 +73,9 @@ class Agent(BaseModel):
tools: List[Any] = Field(
default_factory=list, description="Tools at agents disposal"
)
max_iter: Optional[int] = Field(
default=15, description="Maximum iterations for an agent to execute a task"
)
agent_executor: Optional[InstanceOf[CrewAgentExecutor]] = Field(
default=None, description="An instance of the CrewAgentExecutor class."
)
Expand Down Expand Up @@ -147,6 +151,7 @@ def __create_agent_executor(self) -> CrewAgentExecutor:
"tools": self.tools,
"verbose": self.verbose,
"handle_parsing_errors": True,
"max_iterations": self.max_iter,
}

if self.memory:
Expand Down
89 changes: 88 additions & 1 deletion src/crewai/agents/executor.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,92 @@
from typing import Dict, Iterator, List, Optional, Tuple, Union
import time
from textwrap import dedent
from typing import Any, Dict, Iterator, List, Optional, Tuple, Union

from langchain.agents import AgentExecutor
from langchain.agents.agent import ExceptionTool
from langchain.agents.tools import InvalidTool
from langchain.callbacks.manager import CallbackManagerForChainRun
from langchain_core.agents import AgentAction, AgentFinish, AgentStep
from langchain_core.exceptions import OutputParserException
from langchain_core.pydantic_v1 import root_validator
from langchain_core.tools import BaseTool
from langchain_core.utils.input import get_color_mapping

from ..tools.cache_tools import CacheTools
from .cache.cache_hit import CacheHit


class CrewAgentExecutor(AgentExecutor):
iterations: int = 0
max_iterations: Optional[int] = 15
force_answer_max_iterations: Optional[int] = None

@root_validator()
def set_force_answer_max_iterations(cls, values: Dict) -> Dict:
values["force_answer_max_iterations"] = values["max_iterations"] - 2
return values

def _should_force_answer(self) -> bool:
return True if self.iterations == self.force_answer_max_iterations else False

def _force_answer(self, output: AgentAction):
return AgentStep(
action=output,
observation=dedent(
"""\
I've used too many tools for this task.
I'm going to give you my absolute BEST Final answer now and
not use any more tools."""
),
)

def _call(
self,
inputs: Dict[str, str],
run_manager: Optional[CallbackManagerForChainRun] = None,
) -> Dict[str, Any]:
"""Run text through and get agent response."""
# Construct a mapping of tool name to tool for easy lookup
name_to_tool_map = {tool.name: tool for tool in self.tools}
# We construct a mapping from each tool to a color, used for logging.
color_mapping = get_color_mapping(
[tool.name for tool in self.tools], excluded_colors=["green", "red"]
)
intermediate_steps: List[Tuple[AgentAction, str]] = []
# Let's start tracking the number of iterations and time elapsed
self.iterations = 0
time_elapsed = 0.0
start_time = time.time()
# We now enter the agent loop (until it returns something).
while self._should_continue(self.iterations, time_elapsed):
next_step_output = self._take_next_step(
name_to_tool_map,
color_mapping,
inputs,
intermediate_steps,
run_manager=run_manager,
)
if isinstance(next_step_output, AgentFinish):
return self._return(
next_step_output, intermediate_steps, run_manager=run_manager
)

intermediate_steps.extend(next_step_output)
if len(next_step_output) == 1:
next_step_action = next_step_output[0]
# See if tool should return directly
tool_return = self._get_tool_return(next_step_action)
if tool_return is not None:
return self._return(
tool_return, intermediate_steps, run_manager=run_manager
)
self.iterations += 1
time_elapsed = time.time() - start_time
output = self.agent.return_stopped_response(
self.early_stopping_method, intermediate_steps, **inputs
)
return self._return(output, intermediate_steps, run_manager=run_manager)

def _iter_next_step(
self,
name_to_tool_map: Dict[str, BaseTool],
Expand All @@ -34,6 +108,14 @@ def _iter_next_step(
callbacks=run_manager.get_child() if run_manager else None,
**inputs,
)
if self._should_force_answer():
if isinstance(output, AgentAction):
output = output
else:
output = output.action
yield self._force_answer(output)
return

except OutputParserException as e:
if isinstance(self.handle_parsing_errors, bool):
raise_error = not self.handle_parsing_errors
Expand Down Expand Up @@ -70,6 +152,11 @@ def _iter_next_step(
callbacks=run_manager.get_child() if run_manager else None,
**tool_run_kwargs,
)

if self._should_force_answer():
yield self._force_answer(output)
return

yield AgentStep(action=output, observation=observation)
return

Expand Down
65 changes: 55 additions & 10 deletions tests/agent_test.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,14 @@
"""Test Agent creation and execution basic functionality."""

from unittest.mock import patch

import pytest
from langchain.tools import tool
from langchain_openai import ChatOpenAI as OpenAI

from crewai.agent import Agent
from crewai.agents.cache import CacheHandler
from crewai.agents.executor import CrewAgentExecutor


def test_agent_creation():
Expand Down Expand Up @@ -79,8 +83,6 @@ def test_agent_execution():

@pytest.mark.vcr(filter_headers=["authorization"])
def test_agent_execution_with_tools():
from langchain.tools import tool

@tool
def multiplier(numbers) -> float:
"""Useful for when you need to multiply two numbers together.
Expand All @@ -104,8 +106,6 @@ def multiplier(numbers) -> float:

@pytest.mark.vcr(filter_headers=["authorization"])
def test_logging_tool_usage():
from langchain.tools import tool

@tool
def multiplier(numbers) -> float:
"""Useful for when you need to multiply two numbers together.
Expand Down Expand Up @@ -137,10 +137,6 @@ def multiplier(numbers) -> float:

@pytest.mark.vcr(filter_headers=["authorization"])
def test_cache_hitting():
from unittest.mock import patch

from langchain.tools import tool

@tool
def multiplier(numbers) -> float:
"""Useful for when you need to multiply two numbers together.
Expand Down Expand Up @@ -182,8 +178,6 @@ def multiplier(numbers) -> float:

@pytest.mark.vcr(filter_headers=["authorization"])
def test_agent_execution_with_specific_tools():
from langchain.tools import tool

@tool
def multiplier(numbers) -> float:
"""Useful for when you need to multiply two numbers together.
Expand All @@ -202,3 +196,54 @@ def multiplier(numbers) -> float:

output = agent.execute_task(task="What is 3 times 4", tools=[multiplier])
assert output == "3 times 4 is 12."


@pytest.mark.vcr(filter_headers=["authorization"])
def test_agent_custom_max_iterations():
@tool
def get_final_answer(numbers) -> float:
"""Get the final answer but don't give it yet, just re-use this
tool non-stop."""
return 42

agent = Agent(
role="test role",
goal="test goal",
backstory="test backstory",
max_iter=1,
allow_delegation=False,
)

with patch.object(
CrewAgentExecutor, "_iter_next_step", wraps=agent.agent_executor._iter_next_step
) as private_mock:
agent.execute_task(
task="The final answer is 42. But don't give it yet, instead keep using the `get_final_answer` tool.",
tools=[get_final_answer],
)
private_mock.assert_called_once()


@pytest.mark.vcr(filter_headers=["authorization"])
def test_agent_moved_on_after_max_iterations():
@tool
def get_final_answer(numbers) -> float:
"""Get the final answer but don't give it yet, just re-use this
tool non-stop."""
return 42

agent = Agent(
role="test role",
goal="test goal",
backstory="test backstory",
max_iter=3,
allow_delegation=False,
)

output = agent.execute_task(
task="The final answer is 42. But don't give it yet, instead keep using the `get_final_answer` tool.",
tools=[get_final_answer],
)
assert (
output == "I have used the tool multiple times and the final answer remains 42."
)
Loading

0 comments on commit ea7759b

Please sign in to comment.