Skip to content

Commit

Permalink
Parser held by config (AbanteAI#240)
Browse files Browse the repository at this point in the history
  • Loading branch information
jakethekoenig authored Nov 3, 2023
1 parent 45158e9 commit 30feb64
Show file tree
Hide file tree
Showing 15 changed files with 27 additions and 29 deletions.
2 changes: 1 addition & 1 deletion mentat/code_feature.py
Original file line number Diff line number Diff line change
Expand Up @@ -150,7 +150,7 @@ def _get_code_message(self) -> list[str]:
session_context = SESSION_CONTEXT.get()
code_file_manager = session_context.code_file_manager
git_root = session_context.git_root
parser = session_context.parser
parser = session_context.config.parser

code_message: list[str] = []

Expand Down
7 changes: 4 additions & 3 deletions mentat/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from attr import validators

from mentat.git_handler import get_shared_git_root_for_paths
from mentat.parsers.parser import Parser
from mentat.parsers.parser_map import parser_map
from mentat.session_context import SESSION_CONTEXT
from mentat.utils import mentat_dir_path
Expand Down Expand Up @@ -53,16 +54,16 @@ class Config:
converter=int_or_none,
validator=validators.optional(validators.ge(0)),
)
format: str = attr.field(
parser: Parser = attr.field( # pyright: ignore
default="block",
metadata={
"description": (
"The format for the LLM to write code in. You probably don't want to"
" mess with this setting."
),
"no_midsession_change": True,
},
validator=validators.optional(validators.in_(parser_map)),
converter=parser_map.get, # pyright: ignore
validator=validators.instance_of(Parser), # pyright: ignore
)

# Context specific settings
Expand Down
4 changes: 2 additions & 2 deletions mentat/conversation.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ async def display_token_count(self):
stream = session_context.stream
config = session_context.config
code_context = session_context.code_context
parser = session_context.parser
parser = config.parser

if not is_model_available(config.model):
raise MentatError(
Expand Down Expand Up @@ -158,7 +158,7 @@ async def get_model_response(self) -> list[FileEdit]:
stream = session_context.stream
config = session_context.config
code_context = session_context.code_context
parser = session_context.parser
parser = config.parser
cost_tracker = session_context.cost_tracker

messages_snapshot = self.messages.copy()
Expand Down
6 changes: 1 addition & 5 deletions mentat/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@
from mentat.git_handler import get_shared_git_root_for_paths
from mentat.llm_api import CostTracker, setup_api_key
from mentat.logging_config import setup_logging
from mentat.parsers.parser_map import parser_map
from mentat.session_context import SESSION_CONTEXT, SessionContext
from mentat.session_input import collect_user_input
from mentat.session_stream import SessionStream
Expand Down Expand Up @@ -45,20 +44,17 @@ def __init__(

cost_tracker = CostTracker()

parser = parser_map[config.format]

code_context = CodeContext(stream, git_root, diff, pr_diff)

code_file_manager = CodeFileManager()

conversation = Conversation(parser)
conversation = Conversation(config.parser)

session_context = SessionContext(
stream,
cost_tracker,
git_root,
config,
parser,
code_context,
code_file_manager,
conversation,
Expand Down
2 changes: 0 additions & 2 deletions mentat/session_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@
from mentat.config import Config
from mentat.conversation import Conversation
from mentat.llm_api import CostTracker
from mentat.parsers.parser import Parser
from mentat.session_stream import SessionStream

SESSION_CONTEXT: ContextVar[SessionContext] = ContextVar("mentat:session_context")
Expand All @@ -24,7 +23,6 @@ class SessionContext:
cost_tracker: CostTracker = attr.field()
git_root: Path = attr.field()
config: Config = attr.field()
parser: Parser = attr.field()
code_context: CodeContext = attr.field()
code_file_manager: CodeFileManager = attr.field()
conversation: Conversation = attr.field()
7 changes: 4 additions & 3 deletions tests/config_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@

import mentat.config
from mentat.config import Config, config_file_name
from mentat.parsers.replacement_parser import ReplacementParser


@pytest.mark.asyncio
Expand All @@ -31,7 +32,7 @@ async def test_config_creation():
assert args.model == "model"
assert args.temperature == 0.2
assert args.maximum_context == "1"
assert args.format is None
assert args.parser is None
assert args.use_embeddings
assert args.auto_tokens == "2"

Expand All @@ -46,7 +47,7 @@ async def test_config_creation():
user_config_file.write(dedent("""\
{
"model": "test",
"format": "replacement",
"parser": "replacement",
"input_style": [[ "user", "yes" ]]
}"""))

Expand All @@ -55,7 +56,7 @@ async def test_config_creation():
assert config.model == "model"
assert config.temperature == 0.2
assert config.maximum_context == 1
assert config.format == "replacement"
assert type(config.parser) == ReplacementParser
assert config.use_embeddings
assert not config.no_code_map
assert config.auto_tokens == 2
Expand Down
6 changes: 1 addition & 5 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@
from mentat.config import Config, config_file_name
from mentat.conversation import Conversation
from mentat.llm_api import CostTracker
from mentat.session import parser_map
from mentat.session_context import SESSION_CONTEXT, SessionContext
from mentat.session_stream import SessionStream, StreamMessage, StreamMessageSource
from mentat.streaming_printer import StreamingPrinter
Expand Down Expand Up @@ -197,20 +196,17 @@ async def _mock_session_context(temp_testbed):

config = Config()

parser = parser_map[config.format]

code_context = CodeContext(stream, git_root)

code_file_manager = CodeFileManager()

conversation = Conversation(parser)
conversation = Conversation(config.parser)

session_context = SessionContext(
stream,
cost_tracker,
git_root,
config,
parser,
code_context,
code_file_manager,
conversation,
Expand Down
3 changes: 2 additions & 1 deletion tests/parser_tests/block_format_error_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,13 @@
import pytest

from mentat.config import Config
from mentat.parsers.block_parser import BlockParser
from mentat.session import Session


@pytest.fixture(autouse=True)
def block_parser(mocker):
mocker.patch.object(Config, "format", new="block")
mocker.patch.object(Config, "parser", new=BlockParser())


temp_file_name = "temp.py"
Expand Down
2 changes: 1 addition & 1 deletion tests/parser_tests/block_format_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@

@pytest.fixture
def block_parser(mocker):
mocker.patch.object(Config, "format", new="block")
mocker.patch.object(Config, "parser", new=BlockParser())


@pytest.mark.asyncio
Expand Down
3 changes: 2 additions & 1 deletion tests/parser_tests/replacement_format_error_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,13 @@
import pytest

from mentat.config import Config
from mentat.parsers.replacement_parser import ReplacementParser
from mentat.session import Session


@pytest.fixture(autouse=True)
def replacement_parser(mocker):
mocker.patch.object(Config, "format", new="replacement")
mocker.patch.object(Config, "parser", new=ReplacementParser())


@pytest.mark.asyncio
Expand Down
2 changes: 1 addition & 1 deletion tests/parser_tests/replacement_format_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@

@pytest.fixture
def replacement_parser(mocker):
mocker.patch.object(Config, "format", new="replacement")
mocker.patch.object(Config, "parser", new=ReplacementParser())


@pytest.mark.asyncio
Expand Down
3 changes: 2 additions & 1 deletion tests/parser_tests/split_diff_format_error_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,13 @@
import pytest

from mentat.config import Config
from mentat.parsers.split_diff_parser import SplitDiffParser
from mentat.session import Session


@pytest.fixture(autouse=True)
def split_diff_parser(mocker):
mocker.patch.object(Config, "format", new="split-diff")
mocker.patch.object(Config, "parser", new=SplitDiffParser())


@pytest.mark.asyncio
Expand Down
3 changes: 2 additions & 1 deletion tests/parser_tests/split_diff_format_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,13 @@
import pytest

from mentat.config import Config
from mentat.parsers.split_diff_parser import SplitDiffParser
from mentat.session import Session


@pytest.fixture(autouse=True)
def split_diff_parser(mocker):
mocker.patch.object(Config, "format", new="split-diff")
mocker.patch.object(Config, "parser", new=SplitDiffParser())


@pytest.mark.asyncio
Expand Down
3 changes: 2 additions & 1 deletion tests/parser_tests/unified_diff_format_error_test.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,14 @@
from pathlib import Path
from textwrap import dedent

from mentat.parsers.unified_diff_parser import UnifiedDiffParser
from mentat.session import Session
from tests.conftest import Config, pytest


@pytest.fixture(autouse=True)
def unified_diff_parser(mocker):
mocker.patch.object(Config, "format", new="unified-diff")
mocker.patch.object(Config, "parser", new=UnifiedDiffParser())


@pytest.mark.asyncio
Expand Down
3 changes: 2 additions & 1 deletion tests/parser_tests/unified_diff_format_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,13 @@
import pytest

from mentat.config import Config
from mentat.parsers.unified_diff_parser import UnifiedDiffParser
from mentat.session import Session


@pytest.fixture(autouse=True)
def unified_diff_parser(mocker):
mocker.patch.object(Config, "format", new="unified-diff")
mocker.patch.object(Config, "parser", new=UnifiedDiffParser())


@pytest.mark.asyncio
Expand Down

0 comments on commit 30feb64

Please sign in to comment.