Skip to content

Commit

Permalink
Add git patch info to guess_success prompt (All-Hands-AI#5950)
Browse files Browse the repository at this point in the history
Co-authored-by: openhands <[email protected]>
  • Loading branch information
neubig and openhands-agent authored Jan 4, 2025
1 parent 510c164 commit 5bdebac
Show file tree
Hide file tree
Showing 6 changed files with 217 additions and 15 deletions.
24 changes: 16 additions & 8 deletions openhands/resolver/issue_definitions.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,9 +37,9 @@ def get_instruction(

@abstractmethod
def guess_success(
    self, issue: GithubIssue, history: list[Event], git_patch: str | None = None
) -> tuple[bool, list[bool] | None, str]:
    """Guess if the issue has been resolved based on the agent's output and git patch.

    Args:
        issue: The issue being resolved.
        history: The agent's event history; the last event's message is
            typically used as the agent's final answer.
        git_patch: Optional git patch showing the changes made.

    Returns:
        A tuple of (overall success, per-comment success flags or ``None``,
        explanation text).
    """


Expand Down Expand Up @@ -249,13 +249,14 @@ def get_instruction(
)

def guess_success(
self, issue: GithubIssue, history: list[Event]
self, issue: GithubIssue, history: list[Event], git_patch: str | None = None
) -> tuple[bool, None | list[bool], str]:
"""Guess if the issue is fixed based on the history and the issue description.
Args:
issue: The issue to check
history: The agent's history
git_patch: Optional git patch showing the changes made
"""
last_message = history[-1].message

Expand Down Expand Up @@ -665,6 +666,7 @@ def _check_review_thread(
review_thread: ReviewThread,
issues_context: str,
last_message: str,
git_patch: str | None = None,
) -> tuple[bool, str]:
"""Check if a review thread's feedback has been addressed."""
files_context = json.dumps(review_thread.files, indent=4)
Expand All @@ -683,6 +685,7 @@ def _check_review_thread(
feedback=review_thread.comment,
files_context=files_context,
last_message=last_message,
git_patch=git_patch or 'No changes made yet',
)

return self._check_feedback_with_llm(prompt)
Expand All @@ -692,6 +695,7 @@ def _check_thread_comments(
thread_comments: list[str],
issues_context: str,
last_message: str,
git_patch: str | None = None,
) -> tuple[bool, str]:
"""Check if thread comments feedback has been addressed."""
thread_context = '\n---\n'.join(thread_comments)
Expand All @@ -708,6 +712,7 @@ def _check_thread_comments(
issue_context=issues_context,
thread_context=thread_context,
last_message=last_message,
git_patch=git_patch or 'No changes made yet',
)

return self._check_feedback_with_llm(prompt)
Expand All @@ -717,6 +722,7 @@ def _check_review_comments(
review_comments: list[str],
issues_context: str,
last_message: str,
git_patch: str | None = None,
) -> tuple[bool, str]:
"""Check if review comments feedback has been addressed."""
review_context = '\n---\n'.join(review_comments)
Expand All @@ -733,15 +739,17 @@ def _check_review_comments(
issue_context=issues_context,
review_context=review_context,
last_message=last_message,
git_patch=git_patch or 'No changes made yet',
)

return self._check_feedback_with_llm(prompt)

def guess_success(
self, issue: GithubIssue, history: list[Event]
self, issue: GithubIssue, history: list[Event], git_patch: str | None = None
) -> tuple[bool, None | list[bool], str]:
"""Guess if the issue is fixed based on the history and the issue description."""
"""Guess if the issue is fixed based on the history, issue description and git patch."""
last_message = history[-1].message

issues_context = json.dumps(issue.closing_issues, indent=4)
success_list = []
explanation_list = []
Expand All @@ -751,7 +759,7 @@ def guess_success(
for review_thread in issue.review_threads:
if issues_context and last_message:
success, explanation = self._check_review_thread(
review_thread, issues_context, last_message
review_thread, issues_context, last_message, git_patch
)
else:
success, explanation = False, 'Missing context or message'
Expand All @@ -761,7 +769,7 @@ def guess_success(
elif issue.thread_comments:
if issue.thread_comments and issues_context and last_message:
success, explanation = self._check_thread_comments(
issue.thread_comments, issues_context, last_message
issue.thread_comments, issues_context, last_message, git_patch
)
else:
success, explanation = (
Expand All @@ -774,7 +782,7 @@ def guess_success(
# Handle PRs with only review comments (no file-specific review comments or thread comments)
if issue.review_comments and issues_context and last_message:
success, explanation = self._check_review_comments(
issue.review_comments, issues_context, last_message
issue.review_comments, issues_context, last_message, git_patch
)
else:
success, explanation = (
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,9 @@ Feedback:
Files locations:
{{ files_context }}

Changes made (git patch):
{{ git_patch }}

Last message from AI agent:
{{ last_message }}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,9 @@ Issue descriptions:
PR Review Comments:
{{ review_context }}

Changes made (git patch):
{{ git_patch }}

Last message from AI agent:
{{ last_message }}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,9 @@ Issue descriptions:
PR Thread Comments:
{{ thread_context }}

Changes made (git patch):
{{ git_patch }}

Last message from AI agent:
{{ last_message }}

Expand Down
4 changes: 2 additions & 2 deletions openhands/resolver/resolve_issue.py
Original file line number Diff line number Diff line change
Expand Up @@ -244,9 +244,9 @@ def on_event(evt):
else:
histories = [dataclasses.asdict(event) for event in state.history]
metrics = state.metrics.get() if state.metrics else None
# determine success based on the history and the issue description
# determine success based on the history, issue description and git patch
success, comment_success, result_explanation = issue_handler.guess_success(
issue, state.history
issue, state.history, git_patch
)

if issue_handler.issue_type == 'pr' and comment_success:
Expand Down
195 changes: 190 additions & 5 deletions tests/unit/resolver/test_pr_handler_guess_success.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,36 @@
import json
from unittest.mock import MagicMock, patch

import pytest

from openhands.core.config import LLMConfig
from openhands.events.action.message import MessageAction
from openhands.llm.llm import LLM
from openhands.resolver.github_issue import GithubIssue, ReviewThread
from openhands.resolver.issue_definitions import PRHandler


def mock_llm_response(content):
    """Build a MagicMock shaped like an LLM completion response.

    The returned mock exposes ``choices[0].message.content`` set to
    *content*, mirroring a litellm/OpenAI-style completion object.
    """
    message = MagicMock(content=content)
    return MagicMock(choices=[MagicMock(message=message)])
@pytest.fixture
def pr_handler():
    """Provide a PRHandler wired to throwaway repo credentials."""
    config = LLMConfig(model='test-model')
    handler = PRHandler('test-owner', 'test-repo', 'test-token', config)
    return handler


@pytest.fixture
def mock_llm_success_response():
    """Canned LLM completion whose content marks the check as successful."""
    reply_text = """--- success
true
--- explanation
The changes look good"""
    message = MagicMock(content=reply_text)
    return MagicMock(choices=[MagicMock(message=message)])


def test_guess_success_review_threads_litellm_call():
Expand Down Expand Up @@ -233,6 +251,63 @@ def test_check_feedback_with_llm():
assert (success, explanation) == case['expected']


def test_check_review_thread_with_git_patch():
    """Test that git patch from complete_runtime is included in the prompt."""
    # Handler under test, built with throwaway credentials.
    handler = PRHandler(
        'test-owner',
        'test-repo',
        'test-token',
        LLMConfig(model='test', api_key='test'),
    )

    # Inputs mirroring a real review-thread check.
    thread = ReviewThread(
        comment='Please fix the formatting\n---\nlatest feedback:\nAdd docstrings',
        files=['/src/file1.py', '/src/file2.py'],
    )
    context = json.dumps(['Issue 1 description', 'Issue 2 description'], indent=4)
    agent_message = 'I have fixed the formatting and added docstrings'
    patch_text = 'diff --git a/src/file1.py b/src/file1.py\n+"""Added docstring."""\n'

    # Canned LLM reply marking the feedback as addressed.
    llm_reply = MagicMock()
    llm_reply.choices = [
        MagicMock(
            message=MagicMock(
                content="""--- success
true
--- explanation
Changes look good"""
            )
        )
    ]

    # Exercise _check_review_thread with the LLM call stubbed out.
    with patch.object(LLM, 'completion', return_value=llm_reply) as mock_completion:
        success, explanation = handler._check_review_thread(
            thread, context, agent_message, patch_text
        )

    # Exactly one LLM call whose prompt embeds every piece of context.
    mock_completion.assert_called_once()
    prompt = mock_completion.call_args[1]['messages'][0]['content']
    assert 'Issue descriptions:\n' + context in prompt
    assert 'Feedback:\n' + thread.comment in prompt
    assert 'Files locations:\n' + json.dumps(thread.files, indent=4) in prompt
    assert 'Last message from AI agent:\n' + agent_message in prompt
    assert 'Changes made (git patch):\n' + patch_text in prompt

    # The stubbed reply must be parsed into (success, explanation).
    assert success is True
    assert explanation == 'Changes look good'


def test_check_review_thread():
"""Test the _check_review_thread helper function."""
# Create a PR handler instance
Expand Down Expand Up @@ -288,6 +363,61 @@ def test_check_review_thread():
assert explanation == 'Changes look good'


def test_check_thread_comments_with_git_patch():
    """Test that git patch from complete_runtime is included in the prompt."""
    # Handler under test, built with throwaway credentials.
    handler = PRHandler(
        'test-owner',
        'test-repo',
        'test-token',
        LLMConfig(model='test', api_key='test'),
    )

    # Inputs mirroring a real thread-comments check.
    comments = [
        'Please improve error handling',
        'Add input validation',
        'latest feedback:\nHandle edge cases',
    ]
    context = json.dumps(['Issue 1 description', 'Issue 2 description'], indent=4)
    agent_message = 'I have added error handling and input validation'
    patch_text = 'diff --git a/src/file1.py b/src/file1.py\n+try:\n+    validate_input()\n+except ValueError:\n+    handle_error()\n'

    # Canned LLM reply marking the feedback as addressed.
    llm_reply = MagicMock()
    llm_reply.choices = [
        MagicMock(
            message=MagicMock(
                content="""--- success
true
--- explanation
Changes look good"""
            )
        )
    ]

    # Exercise _check_thread_comments with the LLM call stubbed out.
    with patch.object(LLM, 'completion', return_value=llm_reply) as mock_completion:
        success, explanation = handler._check_thread_comments(
            comments, context, agent_message, patch_text
        )

    # Exactly one LLM call whose prompt embeds every piece of context.
    mock_completion.assert_called_once()
    prompt = mock_completion.call_args[1]['messages'][0]['content']
    assert 'Issue descriptions:\n' + context in prompt
    assert 'PR Thread Comments:\n' + '\n---\n'.join(comments) in prompt
    assert 'Last message from AI agent:\n' + agent_message in prompt
    assert 'Changes made (git patch):\n' + patch_text in prompt

    # The stubbed reply must be parsed into (success, explanation).
    assert success is True
    assert explanation == 'Changes look good'


def test_check_thread_comments():
"""Test the _check_thread_comments helper function."""
# Create a PR handler instance
Expand Down Expand Up @@ -341,6 +471,61 @@ def test_check_thread_comments():
assert explanation == 'Changes look good'


def test_check_review_comments_with_git_patch():
    """Test that git patch from complete_runtime is included in the prompt."""
    # Handler under test, built with throwaway credentials.
    handler = PRHandler(
        'test-owner',
        'test-repo',
        'test-token',
        LLMConfig(model='test', api_key='test'),
    )

    # Inputs mirroring a real review-comments check.
    comments = [
        'Please fix the code style',
        'Add more test cases',
        'latest feedback:\nImprove documentation',
    ]
    context = json.dumps(['Issue 1 description', 'Issue 2 description'], indent=4)
    agent_message = 'I have fixed the code style and added tests'
    patch_text = 'diff --git a/src/file1.py b/src/file1.py\n+"""This module does X."""\n+def func():\n+    """Do Y."""\n'

    # Canned LLM reply marking the feedback as addressed.
    llm_reply = MagicMock()
    llm_reply.choices = [
        MagicMock(
            message=MagicMock(
                content="""--- success
true
--- explanation
Changes look good"""
            )
        )
    ]

    # Exercise _check_review_comments with the LLM call stubbed out.
    with patch.object(LLM, 'completion', return_value=llm_reply) as mock_completion:
        success, explanation = handler._check_review_comments(
            comments, context, agent_message, patch_text
        )

    # Exactly one LLM call whose prompt embeds every piece of context.
    mock_completion.assert_called_once()
    prompt = mock_completion.call_args[1]['messages'][0]['content']
    assert 'Issue descriptions:\n' + context in prompt
    assert 'PR Review Comments:\n' + '\n---\n'.join(comments) in prompt
    assert 'Last message from AI agent:\n' + agent_message in prompt
    assert 'Changes made (git patch):\n' + patch_text in prompt

    # The stubbed reply must be parsed into (success, explanation).
    assert success is True
    assert explanation == 'Changes look good'


def test_check_review_comments():
"""Test the _check_review_comments helper function."""
# Create a PR handler instance
Expand Down

0 comments on commit 5bdebac

Please sign in to comment.