Refactor context benchmark (AbanteAI#249)
granawkins authored Nov 7, 2023
1 parent f643337 commit 1587431
Showing 8 changed files with 290 additions and 195 deletions.
7 changes: 0 additions & 7 deletions mentat/auto_context.py
@@ -81,9 +81,6 @@ async def select(

class LLMFeatureSelector(FeatureSelector, selector_name="llm"):
feature_selection_prompt_path = Path("feature_selection_prompt.txt")
feature_selection_prompt_training_path = Path(
"feature_selection_prompt_training.txt"
)
feature_selection_response_buffer = 500

async def call_llm_api(self, model: str, messages: list[dict[str, str]]) -> str:
@@ -122,10 +119,6 @@ async def select(
f"{config.feature_selection_model}"
)
system_prompt = read_prompt(self.feature_selection_prompt_path)
training_prompt = read_prompt(self.feature_selection_prompt_training_path)
system_prompt = system_prompt.format(
training_prompt=training_prompt if expected_edits else ""
)
system_prompt_tokens = count_tokens(
system_prompt, config.feature_selection_model
)
2 changes: 2 additions & 0 deletions mentat/code_context.py
@@ -38,6 +38,7 @@ def _get_all_features(
diff_context: DiffContext,
code_map: bool,
level: CodeMessageLevel,
max_chars: int = 100000,
) -> list[CodeFeature]:
"""Return a list of all features in the git root with given properties."""

@@ -48,6 +49,7 @@ def _get_all_features(
abs_path.is_dir()
or not is_file_text_encoded(abs_path)
or abs_path in ignore_files
or os.path.getsize(abs_path) > max_chars
):
continue

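Aside: the new `max_chars` guard in `_get_all_features` skips any file whose on-disk size (bytes, used as a proxy for characters) exceeds the limit. A minimal standalone sketch of that filtering step, with a placeholder text check rather than Mentat's actual `is_file_text_encoded`:

```python
from __future__ import annotations

import os
from pathlib import Path


def filter_candidate_files(paths: list[Path], max_chars: int = 100_000) -> list[Path]:
    """Keep only small, text-encoded files, mirroring the new guard above."""
    kept: list[Path] = []
    for path in paths:
        if path.is_dir():
            continue
        try:
            # Stand-in for is_file_text_encoded: try decoding as UTF-8.
            path.read_text(encoding="utf-8")
        except (UnicodeDecodeError, OSError):
            continue
        # File size in bytes approximates character count for text files.
        if os.path.getsize(path) > max_chars:
            continue
        kept.append(path)
    return kept
```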
41 changes: 10 additions & 31 deletions mentat/code_feature.py
@@ -1,12 +1,13 @@
from __future__ import annotations

import asyncio
import logging
import math
import os
from collections import OrderedDict, defaultdict
from collections import OrderedDict
from enum import Enum
from pathlib import Path
from typing import Optional, Set
from typing import Optional

from mentat.code_map import get_code_map, get_ctags
from mentat.diff_context import annotate_file_message, parse_diff
@@ -138,8 +139,9 @@ def __init__(

def __repr__(self):
return (
f"CodeFeature(fname={self.path.name}, interval={self.interval},"
f" level={self.level}, diff={self.diff})"
f"CodeFeature(fname={self.path.name},"
f" interval={self.interval.start}-{self.interval.end},"
f" level={self.level.key}, diff={self.diff})"
)

def ref(self):
@@ -239,7 +241,10 @@ def get_code_message_from_intervals(features: list[CodeFeature]) -> list[str]:
for feature in features_sorted:
starting_line = feature.interval.start
if starting_line < next_line:
raise MentatError("Features overlap")
logging.warning(f"Features overlap: {feature}")
if feature.interval.end < next_line:
continue
feature.interval = Interval(next_line, feature.interval.end)
elif starting_line > next_line:
code_message += ["..."]
code_message += feature.get_code_message()[1:-1]
@@ -261,29 +266,3 @@ def get_code_message_from_features(features: list[CodeFeature]) -> list[str]:
else:
code_message += get_code_message_from_intervals(path_features)
return code_message


def code_features_difference(
a: list[CodeFeature], b: list[CodeFeature]
) -> dict[Path, Set[int]]:
def _get_lines(features: list[CodeFeature]) -> dict[Path, Set[int]]:
lines: defaultdict[Path, Set[int]] = defaultdict(set)
for feature in features:
if feature.interval.end == math.inf:
n_lines = len(feature.path.read_text().splitlines())
start, end = 1, n_lines + 1
else:
start, end = feature.interval.start, feature.interval.end
for line in range(int(start), int(end)):
lines[feature.path].add(line)
return dict(lines)

a_lines = _get_lines(a)
b_lines = _get_lines(b)
differences: defaultdict[Path, Set[int]] = defaultdict(set)
for file in a_lines:
difference = a_lines[file] - b_lines[file]
if len(difference) > 0:
differences[file] = difference

return dict(differences)
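On the overlap change above: `get_code_message_from_intervals` now warns and clips an overlapping feature to start at the next unused line (or skips it if it is already fully covered) instead of raising `MentatError`. A self-contained sketch of that clipping logic, using a simplified stand-in for Mentat's `Interval` and an assumed rule for advancing `next_line`:

```python
import logging
from dataclasses import dataclass


@dataclass
class Interval:
    # Simplified stand-in for Mentat's Interval; 1-indexed, inclusive line numbers.
    start: int
    end: int


def clip_overlaps(intervals: list[Interval]) -> list[Interval]:
    """Warn on overlaps, drop fully-covered intervals, and clip partial overlaps."""
    clipped: list[Interval] = []
    next_line = 1
    for interval in sorted(intervals, key=lambda i: i.start):
        if interval.start < next_line:
            logging.warning(f"Features overlap: {interval}")
            if interval.end < next_line:
                continue  # already fully covered by an earlier interval
            interval = Interval(next_line, interval.end)
        clipped.append(interval)
        next_line = interval.end + 1  # assumption: next usable line follows the interval
    return clipped


print(clip_overlaps([Interval(1, 10), Interval(8, 20), Interval(2, 5)]))
# [Interval(start=1, end=10), Interval(start=11, end=20)]
```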
23 changes: 8 additions & 15 deletions mentat/resources/prompts/feature_selection_prompt.txt
@@ -1,16 +1,9 @@
You are part of an automated coding system.
Below you will see the line 'User Query:', followed by a user query, followed by the line 'Code Files:' and then a pre-selected subset of a codebase.
Your job is to select portions of the code which are relevant to answering that query.
The process after you will read the lines of code you select, and then return a plaintext plan of action and a list of Code Edits.
{training_prompt}
Each item in the Code Files below will include a relative path and line numbers.
Identify lines of code which are relevant to the query.
Return a json-serializable list of relevant items following the same format you receive: <rel_path>:<start_line>-<end_line>,<start_line>-<end_line>.
It's important to include lines which would be edited in order to generate the answer as well as lines which are required to understand the context.
It's equally important to exclude irrelevant code, as it has a negative impact on the system performance and cost.
For example: if a question requires creating a new method related to a class, and the method uses an attribute of that
class, include the location for the edit as well as where the attribute is defined. If a typing system is used, include
the type definition as well, and the location of the expected import.
Prefer longer intervals with 5 lines of extra spacing added around target lines.
Make sure your response is valid json, for example:
["path/to/file1.py:1-10,53-60", "path/to/file2.py:10-20"]
Read the User Query, then select and return portions of the Code Files which are relevant to answering it.
Return these sections in a JSON-parsable format following the schema "path:startline-endline", e.g. "[mydir/file_a, mydir/file_b:10-34]".
Here are the steps to follow:
- Understand the User Query thoroughly. Select files and lines that would be edited, added, or deleted in response to the query.
- If an 'Expected Edits' list is provided, ensure all lines impacted by these edits are included in your selection.
- Identify variables and functions that interact with your selected code. If their behavior is crucial to implementing the expected edits, include them in your selection.
- Merge nearby selected sections (<5 lines apart) into larger sections or whole files for context.
- Verify that your response is JSON-parsable.
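To make the expected output concrete, here is an illustrative sketch (not Mentat's actual parser; the paths and line ranges are hypothetical) of a response that satisfies the "path:startline-endline" schema and one way it could be parsed:

```python
from __future__ import annotations

import json
import re

# Hypothetical response; a bare path with no interval selects the whole file.
example_response = '["src/parser.py:12-48", "src/utils.py:5-20", "README.md"]'


def parse_selection(response: str) -> dict[str, list[tuple[int, int]] | None]:
    """Parse a JSON list of refs; None means the whole file was selected."""
    selections: dict[str, list[tuple[int, int]] | None] = {}
    for ref in json.loads(response):
        if ":" not in ref:
            selections[ref] = None
            continue
        path, intervals = ref.split(":", 1)
        selections[path] = [
            (int(start), int(end))
            for start, end in re.findall(r"(\d+)-(\d+)", intervals)
        ]
    return selections


print(parse_selection(example_response))
# {'src/parser.py': [(12, 48)], 'src/utils.py': [(5, 20)], 'README.md': None}
```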

This file was deleted.

129 changes: 129 additions & 0 deletions scripts/evolve_llm_feature_selector.py
@@ -0,0 +1,129 @@
import argparse
import asyncio
import json
import logging
from pathlib import Path
from textwrap import dedent
from typing import Any, cast

import openai

from mentat.errors import ModelError
from mentat.prompts.prompts import read_prompt
from tests.benchmarks.context_benchmark import test_code_context_performance

prompts_dir = Path(__file__).parent.parent / "mentat/resources/prompts"


async def evaluate_prompt(prompt: str) -> dict[str, dict[str, float]]:
"""Evaluate feature_selection performance with the given prompt"""
print("\n" + "-" * 80)
print(f"Evaluating prompt: {prompt}")
print("-" * 80 + "\n")
with open(prompts_dir / "feature_selection_prompt.txt", "w") as f:
f.write(prompt)
return await test_code_context_performance([])


# Fitness Function
def recall_weighted_mean(scores: dict[str, dict[str, Any]], weight: int = 3) -> float:
"""Calculate score of prompt based on scores of all features"""
precision_scores = [
s["precision"] for s in scores.values() if s["precision"] is not None
]
recall_scores = [s["recall"] for s in scores.values() if s["recall"] is not None]

recall_score = sum(recall_scores) / len(recall_scores) # mean
precision_mean = sum(precision_scores) / len(precision_scores)
precision_target = 0.8 # "Add an additional 25% of related context"
precision_score = 1 - abs(precision_mean - precision_target)
return (weight * recall_score + precision_score) / (weight + 1)


async def generate_variations(
scores: dict[str, dict[str, float]], population: int
) -> list[dict[str, str]]:
"""Generate variations of the prompt based on the highest-scoring prompts"""
system_prompt = dedent("""\
You are part of an automated coding system, specifically the part of that system that selects code which is related to the user's query. \
The heart of this selection system is a large language model (LLM), like yourself. \
The system will use statistical methods to select a large sample of code for a given user query; \
that code, along with the query, is then sent to the LLM, and the LLM should return the lines/files which are relevant to that query. \
Your job is to help write variations on instructions for the code-selection LLM to maximize its accuracy. \
Below you will see a list of prompts and scores which have already been evaluated. \
Use what you see, together with your creativity, to write {population} new prompts that you think will be even better. \
Feel free to make significant changes to the prompt in your variations - take risks and be creative.
The goals of the task, as hopefully reflected by the scores below, are: \
1. To identify areas of code which would be edited, deleted, or added to in response to a user query. \
2. If an 'Expected Edits' list is provided to the code-selection LLM, it *must* include the lines which are expected to be edited. This is reflected in the scores below as 'Recall'. \
3. To also identify relevant context to the query, such as the type-definitions of variables which will be edited, or functions which would be directly affected by the edits. \
4. To NOT select irrelevant files or lines of code. \
5. It's critical to respond to this with a JSON-parsable list of strings (one for each prompt). \
""").format(population=population)
scores = [(prompt, recall_weighted_mean(scores[prompt])) for prompt in scores]
top_scores = sorted(scores, key=lambda x: x[1], reverse=True)[:population]
messages = [
{"role": "system", "content": system_prompt},
{"role": "system", "content": f"Scores: {top_scores}"},
]
response = await openai.ChatCompletion.acreate( # type: ignore
model="gpt-4",
messages=messages,
temperature=0.5,
)
response = cast(str, response["choices"][0]["message"]["content"]) # type: ignore
try:
prompts = json.loads(response)
assert isinstance(prompts, list)
assert all(isinstance(p, str) for p in prompts)
return prompts[:population]
except (json.JSONDecodeError, AssertionError):
logging.error(f"LLM response is not JSON-parsable: {response}")
raise ModelError(f"LLM response is not JSON-parsable: {response}")


async def evolve_prompt(population: int = 5) -> None:
"""Evolve prompt by a single generation"""

scores = {}
scores_path = prompts_dir / "feature_selection_prompt_scores.json"
if scores_path.exists():
with open(scores_path, "r") as f:
scores = json.load(f)

prompts_to_evaluate = []
active_prompt = read_prompt(Path("feature_selection_prompt.txt"))
if active_prompt not in scores: # First time
prompts_to_evaluate.append(active_prompt)
else:
prompts_to_evaluate = await generate_variations(scores, population)

for prompt in prompts_to_evaluate:
score = await evaluate_prompt(prompt)
scores[prompt] = score
with open(scores_path, "w") as f:
json.dump(scores, f)

best_prompt = max(scores, key=lambda k: recall_weighted_mean(scores[k]))
with open(prompts_dir / "feature_selection_prompt.txt", "w") as f:
f.write(str(best_prompt))


if __name__ == "__main__":
parser = argparse.ArgumentParser(
description="Evolve feature selection prompt based on current benchmarks"
)
parser.add_argument(
"--generations",
default=10,
help="The number of generations to evolve",
)
parser.add_argument(
"--population",
default=5,
help="The number of prompts to evaluate per generation",
)
args = parser.parse_args()
for i in range(args.generations):
print(f"Evolving Generation {i}")
asyncio.run(evolve_prompt(args.population))
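As a quick numerical sanity check on `recall_weighted_mean`: recall is averaged, precision is rewarded for landing near the 0.8 target, and recall carries three times the weight. A standalone example with made-up benchmark scores:

```python
# Made-up per-benchmark scores, purely for illustration.
scores = {
    "benchmark_a": {"precision": 0.70, "recall": 0.90},
    "benchmark_b": {"precision": 0.90, "recall": 1.00},
}

weight = 3
precision_mean = sum(s["precision"] for s in scores.values()) / len(scores)  # ≈ 0.80
recall_mean = sum(s["recall"] for s in scores.values()) / len(scores)        # ≈ 0.95
precision_score = 1 - abs(precision_mean - 0.8)                              # ≈ 1.00
fitness = (weight * recall_mean + precision_score) / (weight + 1)
print(round(fitness, 4))  # 0.9625
```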
73 changes: 20 additions & 53 deletions scripts/git_log_to_transcripts.py
@@ -12,13 +12,12 @@
from git import Repo

from mentat.code_context import CodeContext
from mentat.code_feature import CodeFeature, code_features_difference
from mentat.code_file_manager import CodeFileManager
from mentat.config import Config
from mentat.llm_api import CostTracker, count_tokens, model_context_size
from mentat.llm_api import CostTracker, count_tokens
from mentat.parsers.git_parser import GitParser
from mentat.parsers.replacement_parser import ReplacementParser
from mentat.session_context import SESSION_CONTEXT, SessionContext
from tests.benchmarks.context_benchmark import MockStream, select_features_for_benchmark
from tests.benchmarks.utils import clone_repo

system_prompt = dedent("""\
@@ -116,7 +115,7 @@ async def translate_commits_to_transcripts(repo, count=10):
code_context = session_context.code_context
config = session_context.config
git_root = session_context.git_root
parser = session_context.parser
parser = config.parser

for commit in repo.iter_commits("HEAD", max_count=count):
try:
@@ -160,50 +159,25 @@ async def translate_commits_to_transcripts(repo, count=10):
"args": {},
"prompt": prompt,
"expected_edits": llmResponse,
"edited_features": list(
{
str(f.relative_to(git_root))
for f in bound_files(parsedLLMResponse.file_edits, padding=0)
}
),
"selected_features": [],
}
# The longest context that could have been included to generate expected_edits
model = config.model
mentat_prompt_tokens = count_tokens(parser.get_system_prompt(), model)
expected_edits_tokens = count_tokens(benchmark["expected_edits"], model)
max_context_tokens = (
model_context_size(model) - mentat_prompt_tokens - expected_edits_tokens
)
# Fill-in available context
config.auto_tokens = 1e6
config.use_embeddings = True
code_context.use_llm = True
await code_context.get_code_message(
prompt, model, max_context_tokens, benchmark["expected_edits"]
)

# Compare auto-selected context against actual edits
edited_features = list[CodeFeature]()
for file_edit in parsedLLMResponse.file_edits:
if file_edit.is_creation or len(file_edit.replacements) == 0:
continue
path = Path(file_edit.file_path)
for repl in file_edit.replacements:
ref = f"{path}:{repl.starting_line}-{repl.ending_line}"
edited_features.append(CodeFeature(ref, None, None))
missing_context = code_features_difference(
edited_features, code_context.features
)
if missing_context:
print(
"Auto-context missing required files:", repr(dict(missing_context))
try:
result = await select_features_for_benchmark(
session_context,
benchmark,
eval=False,
use_expected=True,
use_llm=True,
)
print("Using edited_files instead")
edited_files = {
str(f.relative_to(git_root))
for f in bound_files(parsedLLMResponse.file_edits, padding=0)
}
benchmark["expected_features"] = list(edited_files)
else:
git_root_length = len(str(git_root)) + 1
selected_features = [
f.ref()[git_root_length:] for f in code_context.features
]
benchmark["expected_features"] = list(selected_features)
benchmark["selected_features"] = result["features"]
except Exception as e:
print(f"Failed to select features: {e}")

benchmarks[sha] = benchmark
except Exception as e:
Expand All @@ -214,12 +188,6 @@ async def translate_commits_to_transcripts(repo, count=10):
return transcripts, benchmarks


class MockStream:
def send(self, message, **kwargs):
end = kwargs.get("end", "\n")
print(message, end=end)


if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Convert git commits into transcripts")
parser.add_argument(
@@ -257,7 +225,6 @@ def send(self, message, **kwargs):
CostTracker(),
Path.cwd(),
config,
ReplacementParser(),
code_context,
CodeFileManager(),
None,