diff --git a/mentat/parsers/block_parser.py b/mentat/parsers/block_parser.py index c54679ae2..2b3f7df36 100644 --- a/mentat/parsers/block_parser.py +++ b/mentat/parsers/block_parser.py @@ -206,7 +206,7 @@ def file_edits_to_llm_message(self, parsedLLMResponse: ParsedLLMResponse) -> str Inverse of stream_and_parse_llm_response """ git_root = GIT_ROOT.get() - ans = parsedLLMResponse.conversation + ans = parsedLLMResponse.conversation.strip() + "\n\n" for file_edit in parsedLLMResponse.file_edits: tmp = {} tmp[_BlockParserJsonKeys.File.value] = file_edit.file_path.relative_to( diff --git a/mentat/parsers/git_parser.py b/mentat/parsers/git_parser.py new file mode 100644 index 000000000..493288256 --- /dev/null +++ b/mentat/parsers/git_parser.py @@ -0,0 +1,104 @@ +from pathlib import Path +from textwrap import dedent +from typing import Any, AsyncGenerator, List + +from mentat.git_handler import GIT_ROOT +from mentat.llm_api import chunk_to_lines +from mentat.parsers.file_edit import FileEdit, Replacement +from mentat.parsers.parser import ParsedLLMResponse + +# The git diff format should not be used with an LLM because it contains information like SHAs +# which it would not know about. It also involves certain arithmetic that it couldn't reliably do +# and typically prints out extra lines which aid human readability but would cost tokens. +# It is implemented to create training data from git diffs. Therefore there's not need for it to +# work asynchronously or stream partial results. In theory one could work this into the existing +# Parser class but it is simpler to assume one has the whole string up front and make heavy use of +# split. + + +class GitParser: + # This doesn't actually "stream and parse" but it is named this way to match the interface of + # the production parsers for use in the translation script. + async def stream_and_parse_llm_response( + self, + response: AsyncGenerator[Any, None], + ) -> ParsedLLMResponse: + string = "" + async for chunk in response: + for content in chunk_to_lines(chunk): + string += content + return self.parse_string(string) + + def parse_string(self, git_diff: str) -> ParsedLLMResponse: + git_root = GIT_ROOT.get() + # This is safe because actual code is prepended with ' ', + or -. + split_on_diff = git_diff.split("\ndiff --git ") + + # Use commit message for conversation + commit_message = dedent(split_on_diff[0].split("\n\n")[1].strip()) + + file_edits: List[FileEdit] = [] + for diff in split_on_diff[1:]: + is_creation = "new file mode" in diff + is_deletion = "deleted file mode" in diff + first_line = diff.split("\n")[0] + start_file_name = Path(first_line.split()[0][2:]).resolve() + end_file_name = Path(first_line.split()[1][2:]).resolve() + if start_file_name != end_file_name: + new_name = end_file_name + else: + new_name = None + + file_edit = FileEdit( + git_root / start_file_name, + [], + is_creation=is_creation, + is_deletion=is_deletion, + rename_file_path=new_name, + ) + diff_split = diff.split("\n@@") + if not is_deletion: + for change in diff_split[1:]: + line_info = change.split("@@")[0] + start_line = int(line_info.split()[0].split(",")[0][1:]) - 1 + end_line = start_line + int(line_info.split()[0].split(",")[1]) + line_changes = change.split("@@")[1] + code_lines = line_changes.split("\n") + # This check is necessary because new code sometimes starts on the same line + # as @@ sometimes on the next line. + if code_lines[0] == "": + code_lines = code_lines[1:] + if code_lines[-1] == "": + code_lines = code_lines[:-1] + + # Git diff gives context for human readability we don't want to train the llm + # to produce. + starting_repetition = 0 + for line in code_lines: + if line.startswith(" "): + starting_repetition += 1 + else: + break + ending_repetition = 0 + for line in reversed(code_lines): + if line.startswith(" "): + ending_repetition += 1 + else: + break + start_line += starting_repetition + end_line -= ending_repetition + + lines: List[str] = [] + for line in code_lines[ + starting_repetition : len(code_lines) - ending_repetition + ]: + if not line.startswith("-"): + lines.append(line[1:]) + + file_edit.replacements.append( + Replacement(start_line, end_line, lines) + ) + + file_edits.append(file_edit) + + return ParsedLLMResponse(git_diff, commit_message, file_edits) diff --git a/mentat/parsers/parser.py b/mentat/parsers/parser.py index 0dae137db..2c2ce27df 100644 --- a/mentat/parsers/parser.py +++ b/mentat/parsers/parser.py @@ -235,6 +235,7 @@ async def stream_and_parse_llm_response( cur_file_edit.rename_file_path = ( file_edit.rename_file_path ) + cur_file_edit.replacements.extend(file_edit.replacements) file_edit = cur_file_edit # Print file header diff --git a/mentat/parsers/replacement_parser.py b/mentat/parsers/replacement_parser.py index 7f93fa069..a69fbd0a1 100644 --- a/mentat/parsers/replacement_parser.py +++ b/mentat/parsers/replacement_parser.py @@ -126,7 +126,7 @@ def file_edits_to_llm_message(self, parsedLLMResponse: ParsedLLMResponse) -> str Inverse of stream_and_parse_llm_response """ git_root = GIT_ROOT.get() - ans = parsedLLMResponse.conversation + ans = parsedLLMResponse.conversation.strip() + "\n\n" for file_edit in parsedLLMResponse.file_edits: action_indicator = "" if file_edit.is_creation: diff --git a/scripts/translate_transcript.py b/scripts/translate_transcript.py index d8fda9063..616edef90 100755 --- a/scripts/translate_transcript.py +++ b/scripts/translate_transcript.py @@ -2,11 +2,13 @@ import argparse import asyncio import json +from pathlib import Path from unittest.mock import AsyncMock from mentat.code_file_manager import CODE_FILE_MANAGER, CodeFileManager from mentat.git_handler import GIT_ROOT from mentat.parsers.block_parser import BlockParser +from mentat.parsers.git_parser import GitParser from mentat.parsers.parser import Parser from mentat.parsers.replacement_parser import ReplacementParser from mentat.parsers.split_diff_parser import SplitDiffParser @@ -14,53 +16,61 @@ from mentat.session_stream import SESSION_STREAM from mentat.utils import convert_string_to_asyncgen -CODE_FILE_MANAGER.set(CodeFileManager()) -SESSION_STREAM.set(AsyncMock()) - - parser_map: dict[str, Parser] = { "block": BlockParser(), "replacement": ReplacementParser(), "split-diff": SplitDiffParser(), "unified-diff": UnifiedDiffParser(), + "git": GitParser(), } -parser = argparse.ArgumentParser( - description="Translate transcript between parsing formats" -) -parser.add_argument( - "--transcript", type=str, default=None, help="Transcript to translate" -) -# TODO: infer from config or something -parser.add_argument( - "--starting-format", - type=str, - default="block", - help="Format of the transcript to translate", -) -parser.add_argument( - "--ending-format", type=str, default="block", help="Format to translate to" -) -parser.add_argument("--git-root", type=str, default=".", help="Git root directory") -args = parser.parse_args() -GIT_ROOT.set(args.git_root) -starting_parser = parser_map[args.starting_format] -ending_parser = parser_map[args.ending_format] +def translate_message(message: str, starting_parser, ending_parser) -> str: + parsedLLMResponse = asyncio.run( + starting_parser.stream_and_parse_llm_response( + convert_string_to_asyncgen(message, 100) + ) + ) + return ending_parser.file_edits_to_llm_message(parsedLLMResponse) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser( + description="Translate transcript between parsing formats" + ) + parser.add_argument( + "--transcript", type=str, default=None, help="Transcript to translate" + ) + parser.add_argument( + "--starting-format", + type=str, + default="block", + help="Format of the transcript to translate", + ) + parser.add_argument( + "--ending-format", type=str, default="block", help="Format to translate to" + ) + parser.add_argument("--git-root", type=str, default=".", help="Git root directory") + args = parser.parse_args() + + GIT_ROOT.set(Path(args.git_root)) + CODE_FILE_MANAGER.set(CodeFileManager()) + SESSION_STREAM.set(AsyncMock()) -with open(args.transcript, "r") as f: - for line in f.readlines(): - transcript = json.loads(line) - messages = transcript["messages"] - for message in messages: - # Note we don't change the system prompts. In training they are stripped off anyway. - if message["role"] == "assistant": - content = message["content"] - file_edits = asyncio.run( - starting_parser.stream_and_parse_llm_response( - convert_string_to_asyncgen(content, 100) - ) - ) - message["content"] = ending_parser.file_edits_to_llm_message(file_edits) + starting_parser = parser_map[args.starting_format] + ending_parser = parser_map[args.ending_format] + with open(args.transcript, "r") as f: + if ".json" in args.transcript: + for line in f.readlines(): + transcript = json.loads(line) + messages = transcript["messages"] + for message in messages: + # Note we don't change the system prompts. In training they are stripped off anyway. + if message["role"] == "assistant": + message["content"] = translate_message( + message["content"], starting_parser, ending_parser + ) - print(json.dumps(transcript)) + print(json.dumps(transcript)) + else: + print(translate_message(f.read(), starting_parser, ending_parser)) diff --git a/testbed/format_examples/block.txt b/testbed/format_examples/block.txt new file mode 100644 index 000000000..3a78cfac6 --- /dev/null +++ b/testbed/format_examples/block.txt @@ -0,0 +1,40 @@ +Fifth commit + +@@start +{ + "file": "a", + "action": "replace", + "start-line": 2, + "end-line": 11 +} +@@code +3 +4 +5 +6 +0 +10 +11 +12 +@@end +@@start +{ + "file": "d", + "action": "delete-file" +} +@@end +@@start +{ + "file": "c", + "action": "rename-file", + "name": "e" +} +@@end +@@start +{ + "file": "c", + "action": "delete", + "start-line": 8, + "end-line": 8 +} +@@end diff --git a/testbed/format_examples/git_diff.txt b/testbed/format_examples/git_diff.txt new file mode 100644 index 000000000..6fad98ad5 --- /dev/null +++ b/testbed/format_examples/git_diff.txt @@ -0,0 +1,57 @@ +commit 07aa8bae2d4170767b427a56e044f393e86b462f +Author: Jake Koenig +Date: Thu Oct 19 07:28:08 2023 -0700 + + Fifth commit + +diff --git a/a b/a +index ec588a0..248c392 100644 +--- a/a ++++ b/a +@@ -1,11 +1,9 @@ + 1 +-2 + 3 + 4 + 5 + 6 +-7 +-8 +-9 + 0 +- ++10 ++11 ++12 +diff --git a/d b/d +deleted file mode 100644 +index ec588a0..0000000 +--- a/d ++++ /dev/null +@@ -1,11 +0,0 @@ +-1 +-2 +-3 +-4 +-5 +-6 +-7 +-8 +-9 +-0 +- +diff --git a/c b/e +similarity index 90% +rename from c +rename to e +index ec588a0..720a0a7 100644 +--- a/c ++++ b/e +@@ -5,7 +5,6 @@ + 5 + 6 + 7 +-8 + 9 + 0 + diff --git a/testbed/format_examples/replacement.txt b/testbed/format_examples/replacement.txt new file mode 100644 index 000000000..69217008b --- /dev/null +++ b/testbed/format_examples/replacement.txt @@ -0,0 +1,16 @@ +Fifth commit + +@ a starting_line=2 ending_line=11 +3 +4 +5 +6 +0 +10 +11 +12 +@ +@ d - +@ c e +@ c starting_line=8 ending_line=8 +@ diff --git a/tests/code_context_test.py b/tests/code_context_test.py index 4b75bedd3..3eaf2e6a4 100644 --- a/tests/code_context_test.py +++ b/tests/code_context_test.py @@ -279,8 +279,8 @@ async def _count_auto_tokens_where(limit: int) -> int: # Github Actions doesn't have ctags, so we need this if not code_context.settings.no_code_map: # If max_tokens is None, include the full auto-context - assert await _count_auto_tokens_where(None) == 236 # Cmap w/ signatures - assert await _count_auto_tokens_where(230) == 184 # Cmap - assert await _count_auto_tokens_where(170) == 134 # fnames + assert await _count_auto_tokens_where(None) == 253 # Cmap w/ signatures + assert await _count_auto_tokens_where(230) == 201 # Cmap + assert await _count_auto_tokens_where(170) == 151 # fnames # Always return include_files, regardless of max assert await _count_auto_tokens_where(0) == 102 # Include_files only diff --git a/tests/parser_tests/translate_scripts_test.py b/tests/parser_tests/translate_scripts_test.py new file mode 100644 index 000000000..313c0bc56 --- /dev/null +++ b/tests/parser_tests/translate_scripts_test.py @@ -0,0 +1,63 @@ +import pytest + +from mentat.parsers.block_parser import BlockParser +from mentat.parsers.git_parser import GitParser +from mentat.parsers.replacement_parser import ReplacementParser +from scripts.translate_transcript import translate_message + + +@pytest.fixture +def block_format(): + with open("format_examples/block.txt") as f: + return f.read() + + +@pytest.fixture +def replacement_format(): + with open("format_examples/replacement.txt") as f: + return f.read() + + +@pytest.fixture +def git_diff_format(): + with open("format_examples/git_diff.txt") as f: + return f.read() + + +def test_block_to_replacement( + block_format, replacement_format, mock_stream, mock_code_file_manager, mock_git_root +): + assert ( + translate_message(block_format, BlockParser(), ReplacementParser()) + == replacement_format + ) + + +def test_replacement_to_block( + block_format, replacement_format, mock_stream, mock_code_file_manager, mock_git_root +): + assert ( + translate_message(replacement_format, ReplacementParser(), BlockParser()) + == block_format + ) + + +def test_git_diff_to_replacement( + git_diff_format, + replacement_format, + mock_stream, + mock_code_file_manager, + mock_git_root, +): + assert ( + translate_message(git_diff_format, GitParser(), ReplacementParser()) + == replacement_format + ) + + +def test_git_diff_to_block( + git_diff_format, block_format, mock_stream, mock_code_file_manager, mock_git_root +): + assert ( + translate_message(git_diff_format, GitParser(), BlockParser()) == block_format + )