Skip to content

Commit

Permalink
Config Command (AbanteAI#221)
Browse files Browse the repository at this point in the history
* Config Command

* You can set or get a config with `/config`
* I realized conversation was holding onto its own copy of the model,
  preventing it from being changed. It now gets it from the config as
  necessary.
  (Perhaps we should go all the way and have the llm_api be responsible
  for choosing the model? I could imagine wanting to keep track of which
  model generated a message for the future though).
* I updated the use_embeddings comment so users know they don't need to
  restart mentat.
* Note: attr takes care of type checking, so we can't end up with
  illegal config properties after a call to `/config`.
* Validators added

We check:
* Temperature in [0,1]
* Format exists
* contexts are positive

If a type or value error is encountered while changing a setting, it is
reported gracefully to the user — whether the setting was changed in the
config file, by command, or by flag.
  • Loading branch information
jakethekoenig authored Oct 26, 2023
1 parent c2bf733 commit 6f27ae6
Show file tree
Hide file tree
Showing 10 changed files with 140 additions and 56 deletions.
2 changes: 1 addition & 1 deletion mentat/code_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -399,7 +399,7 @@ async def search(

if not config.use_embeddings:
stream.send(
"Embeddings are disabled. To enable, restart with '--use-embeddings'.",
"Embeddings are disabled. Enable with `/config use_embeddings true`",
color="light_red",
)
return []
Expand Down
53 changes: 50 additions & 3 deletions mentat/commands.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@
from pathlib import Path
from typing import List

import attr

from mentat.conversation import MessageRole
from mentat.errors import MentatError, UserError
from mentat.git_handler import commit
Expand Down Expand Up @@ -147,7 +149,7 @@ async def apply(self, *args: str) -> None:
git_root = session_context.git_root

if len(args) == 0:
stream.send("No files specified\n", color="yellow")
stream.send("No files specified", color="yellow")
return
for file_path in args:
included_paths, invalid_paths = code_context.include_file(
Expand Down Expand Up @@ -176,7 +178,7 @@ async def apply(self, *args: str) -> None:
git_root = session_context.git_root

if len(args) == 0:
stream.send("No files specified\n", color="yellow")
stream.send("No files specified", color="yellow")
return
for file_path in args:
excluded_paths, invalid_paths = code_context.exclude_file(
Expand Down Expand Up @@ -272,7 +274,7 @@ async def apply(self, *args: str) -> None:
git_root = session_context.git_root

if len(args) == 0:
stream.send("No search query specified\n", color="yellow")
stream.send("No search query specified", color="yellow")
return
try:
query = " ".join(args)
Expand Down Expand Up @@ -338,3 +340,48 @@ def argument_names(cls) -> list[str]:
@classmethod
def help_message(cls) -> str:
return "Shows all files currently in Mentat's context"


class ConfigCommand(Command, command_name="config"):
    """Implements `/config`: inspect or change a config setting mid-session."""

    async def apply(self, *args: str) -> None:
        """Show a setting (one argument) or set it (two arguments).

        With one argument, prints the setting's current value and its
        `description` metadata, if any. With two arguments, assigns the value,
        unless the field is marked `no_midsession_change`. Invalid names,
        values, and argument counts are reported on the stream rather than
        raised.
        """
        session_context = SESSION_CONTEXT.get()
        stream = session_context.stream
        config = session_context.config

        if len(args) == 0:
            stream.send("No config option specified", color="yellow")
            return
        setting = args[0]
        if not hasattr(config, setting):
            stream.send(f"Unrecognized config option: {setting}", color="red")
            return

        if len(args) == 1:
            # Read-only display of the current value plus its description.
            value = getattr(config, setting)
            description = attr.fields_dict(type(config))[setting].metadata.get(
                "description"
            )
            stream.send(f"{setting}: {value}")
            if description:
                stream.send(f"Description: {description}")
        elif len(args) == 2:
            value = args[1]
            if attr.fields_dict(type(config))[setting].metadata.get(
                "no_midsession_change"
            ):
                stream.send(
                    f"Cannot change {setting} mid-session. Please restart"
                    " Mentat to change this setting.",
                    color="yellow",
                )
                return
            # attr converters/validators raise TypeError/ValueError on bad
            # input (e.g. temperature out of [0, 1]); report the failure
            # gracefully instead of letting the exception escape the command.
            try:
                setattr(config, setting, value)
            except (TypeError, ValueError) as e:
                stream.send(f"Illegal value for {setting}: {e}", color="red")
                return
            stream.send(f"{setting} set to {value}", color="green")
        else:
            stream.send("Too many arguments", color="yellow")

    @classmethod
    def argument_names(cls) -> list[str]:
        return ["setting", "value"]

    @classmethod
    def help_message(cls) -> str:
        return "Set a configuration option or omit value to see current value."
39 changes: 33 additions & 6 deletions mentat/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,10 @@
from pathlib import Path

import attr
from attr import validators

from mentat.git_handler import get_shared_git_root_for_paths
from mentat.parsers.parser_map import parser_map
from mentat.session_context import SESSION_CONTEXT
from mentat.utils import mentat_dir_path

Expand Down Expand Up @@ -35,7 +37,10 @@ class Config:

# Model specific settings
model: str = attr.field(default="gpt-4-0314")
temperature: float = attr.field(default=0.5, converter=float)
temperature: float = attr.field(
default=0.5, converter=float, validator=[validators.le(1), validators.ge(0)]
)

maximum_context: int | None = attr.field(
default=None,
metadata={
Expand All @@ -46,8 +51,19 @@ class Config:
)
},
converter=int_or_none,
validator=validators.optional(validators.ge(0)),
)
format: str = attr.field(
default="block",
metadata={
"description": (
"The format for the LLM to write code in. You probably don't want to"
" mess with this setting."
),
"no_midsession_change": True,
},
validator=validators.optional(validators.in_(parser_map)),
)
format: str = attr.field(default="block")

# Context specific settings
file_exclude_glob_list: list[str] = attr.field(
Expand Down Expand Up @@ -80,6 +96,7 @@ class Config:
"abbreviation": "a",
},
converter=int_or_none,
validator=validators.optional(validators.ge(0)),
)

# Only settable by config file
Expand All @@ -91,14 +108,15 @@ class Config:
],
metadata={
"description": "Styling information for the terminal.",
"only_config_file": True,
"no_flag": True,
"no_midsession_change": True,
},
)

@classmethod
def add_fields_to_argparse(cls, parser: ArgumentParser) -> None:
for field in attr.fields(Config):
if "only_config_file" in field.metadata:
if "no_flag" in field.metadata:
continue
name = [f"--{field.name.replace('_', '-')}"]
if "abbreviation" in field.metadata:
Expand Down Expand Up @@ -138,7 +156,10 @@ def create(cls, args: Namespace | None = None) -> Config:
def load_namespace(self, args: Namespace) -> None:
for field in vars(args):
if hasattr(self, field) and getattr(args, field) is not None:
setattr(self, field, getattr(args, field))
try:
setattr(self, field, getattr(args, field))
except (ValueError, TypeError) as e:
self.error(f"Warning: Illegal value for {field}: {e}")

def load_file(self, path: Path) -> None:
if path.exists():
Expand All @@ -153,7 +174,13 @@ def load_file(self, path: Path) -> None:
return
for field in config:
if hasattr(self, field):
setattr(self, field, config[field])
try:
setattr(self, field, config[field])
except (ValueError, TypeError) as e:
self.error(
f"Warning: Config {path} contains invalid value for"
f" setting: {field}\n{e}"
)
else:
self.error(
f"Warning: Config {path} contains unrecognized setting: {field}"
Expand Down
39 changes: 20 additions & 19 deletions mentat/conversation.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@

from openai.error import InvalidRequestError

from mentat.config import Config, user_config_path
from mentat.errors import MentatError
from mentat.llm_api import (
call_llm_api,
Expand All @@ -34,8 +33,7 @@ class MessageRole(Enum):
class Conversation:
max_tokens: int

def __init__(self, config: Config, parser: Parser):
self.model = config.model
def __init__(self, parser: Parser):
self.messages = list[dict[str, str]]()

# This contain the messages the user actually sends and the messages the model output
Expand All @@ -52,41 +50,41 @@ async def display_token_count(self):
code_context = session_context.code_context
parser = session_context.parser

if not is_model_available(self.model):
if not is_model_available(config.model):
raise MentatError(
f"Model {self.model} is not available. Please try again with a"
f"Model {config.model} is not available. Please try again with a"
" different model."
)
if "gpt-4" not in self.model:
if "gpt-4" not in config.model:
stream.send(
"Warning: Mentat has only been tested on GPT-4. You may experience"
" issues with quality. This model may not be able to respond in"
" mentat's edit format.",
color="yellow",
)
if "gpt-3.5" not in self.model:
if "gpt-3.5" not in config.model:
stream.send(
"Warning: Mentat does not know how to calculate costs or context"
" size for this model.",
color="yellow",
)
prompt = parser.get_system_prompt()
context_size = model_context_size(self.model)
context_size = model_context_size(config.model)
maximum_context = config.maximum_context
if maximum_context:
if context_size:
context_size = min(context_size, maximum_context)
else:
context_size = maximum_context
tokens = count_tokens(
await code_context.get_code_message("", self.model, max_tokens=0),
self.model,
) + count_tokens(prompt, self.model)
await code_context.get_code_message("", config.model, max_tokens=0),
config.model,
) + count_tokens(prompt, config.model)

if not context_size:
raise MentatError(
f"Context size for {self.model} is not known. Please set"
f" maximum-context in {user_config_path}."
f"Context size for {config.model} is not known. Please set"
" maximum-context with `/config maximum_context value`."
)
else:
self.max_tokens = context_size
Expand Down Expand Up @@ -136,8 +134,10 @@ async def _stream_model_response(
messages: list[dict[str, str]],
):
start_time = default_timer()
session_context = SESSION_CONTEXT.get()
config = session_context.config
try:
response = await call_llm_api(messages, self.model)
response = await call_llm_api(messages, config.model)
stream.send(
"Streaming... use control-c to interrupt the model at any point\n"
)
Expand All @@ -156,6 +156,7 @@ async def _stream_model_response(
async def get_model_response(self) -> list[FileEdit]:
session_context = SESSION_CONTEXT.get()
stream = session_context.stream
config = session_context.config
code_context = session_context.code_context
parser = session_context.parser
cost_tracker = session_context.cost_tracker
Expand All @@ -164,24 +165,24 @@ async def get_model_response(self) -> list[FileEdit]:

# Rebuild code context with active code and available tokens
conversation_history = "\n".join([m["content"] for m in messages_snapshot])
tokens = count_tokens(conversation_history, self.model)
tokens = count_tokens(conversation_history, config.model)
response_buffer = 1000
code_message = await code_context.get_code_message(
messages_snapshot[-1]["content"],
self.model,
config.model,
self.max_tokens - tokens - response_buffer,
)
messages_snapshot.append({"role": "system", "content": code_message})

code_context.display_features()
num_prompt_tokens = get_prompt_token_count(messages_snapshot, self.model)
num_prompt_tokens = get_prompt_token_count(messages_snapshot, config.model)
parsedLLMResponse, time_elapsed = await self._stream_model_response(
stream, parser, messages_snapshot
)
cost_tracker.display_api_call_stats(
num_prompt_tokens,
count_tokens(parsedLLMResponse.full_response, self.model),
self.model,
count_tokens(parsedLLMResponse.full_response, config.model),
config.model,
time_elapsed,
)

Expand Down
12 changes: 12 additions & 0 deletions mentat/parsers/parser_map.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
from mentat.parsers.block_parser import BlockParser
from mentat.parsers.parser import Parser
from mentat.parsers.replacement_parser import ReplacementParser
from mentat.parsers.split_diff_parser import SplitDiffParser
from mentat.parsers.unified_diff_parser import UnifiedDiffParser

# Registry mapping each valid `format` config value to the parser instance
# that implements it.
parser_map: dict[str, Parser] = {
    format_name: parser_cls()
    for format_name, parser_cls in (
        ("block", BlockParser),
        ("replacement", ReplacementParser),
        ("split-diff", SplitDiffParser),
        ("unified-diff", UnifiedDiffParser),
    )
}
15 changes: 2 additions & 13 deletions mentat/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,22 +16,11 @@
from mentat.git_handler import get_shared_git_root_for_paths
from mentat.llm_api import CostTracker, setup_api_key
from mentat.logging_config import setup_logging
from mentat.parsers.block_parser import BlockParser
from mentat.parsers.parser import Parser
from mentat.parsers.replacement_parser import ReplacementParser
from mentat.parsers.split_diff_parser import SplitDiffParser
from mentat.parsers.unified_diff_parser import UnifiedDiffParser
from mentat.parsers.parser_map import parser_map
from mentat.session_context import SESSION_CONTEXT, SessionContext
from mentat.session_input import collect_user_input
from mentat.session_stream import SessionStream

parser_map: dict[str, Parser] = {
"block": BlockParser(),
"replacement": ReplacementParser(),
"split-diff": SplitDiffParser(),
"unified-diff": UnifiedDiffParser(),
}


class Session:
def __init__(
Expand Down Expand Up @@ -62,7 +51,7 @@ def __init__(

code_file_manager = CodeFileManager()

conversation = Conversation(config, parser)
conversation = Conversation(parser)

session_context = SessionContext(
stream,
Expand Down
13 changes: 2 additions & 11 deletions scripts/translate_transcript.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,22 +6,13 @@
from unittest.mock import AsyncMock

from mentat.code_file_manager import CodeFileManager
from mentat.parsers.block_parser import BlockParser
from mentat.parsers.git_parser import GitParser
from mentat.parsers.parser import Parser
from mentat.parsers.replacement_parser import ReplacementParser
from mentat.parsers.split_diff_parser import SplitDiffParser
from mentat.parsers.unified_diff_parser import UnifiedDiffParser
from mentat.parsers.parser_map import parser_map
from mentat.session_context import SESSION_CONTEXT, SessionContext
from mentat.utils import convert_string_to_asyncgen

parser_map: dict[str, Parser | GitParser] = {
"block": BlockParser(),
"replacement": ReplacementParser(),
"split-diff": SplitDiffParser(),
"unified-diff": UnifiedDiffParser(),
"git": GitParser(),
}
parser_map["git"] = GitParser()


def translate_message(
Expand Down
Loading

0 comments on commit 6f27ae6

Please sign in to comment.