Skip to content

Commit

Permalink
Small fixes: embeddings and auto-context (AbanteAI#433)
Browse files Browse the repository at this point in the history
  • Loading branch information
granawkins authored Jan 1, 2024
1 parent ea55e95 commit cb48a99
Show file tree
Hide file tree
Showing 4 changed files with 182 additions and 21 deletions.
5 changes: 4 additions & 1 deletion mentat/embeddings.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,7 +135,10 @@ async def get_feature_similarity_scores(
return [0.0 for _ in checksums]

# Fetch embeddings in batches
n_batches = len(embed_texts) // EMBEDDINGS_API_BATCH_SIZE + 1
if len(embed_texts) == 0:
n_batches = 0
else:
n_batches = len(embed_texts) // EMBEDDINGS_API_BATCH_SIZE + 1
for batch in range(n_batches):
if loading_multiplier:
stream.send(
Expand Down
31 changes: 16 additions & 15 deletions mentat/sampler/sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -170,21 +170,22 @@ def _rp(f: str | Path) -> str:
)
context.update(_rp(f) for f in feature_refs)
# Undo adds/removes/renames to match pre-diff_edit state
file_edits = GitParser().parse_string(diff_edit).file_edits
for file_edit in file_edits:
file_path = _rp(file_edit.file_path)
rename_path = (
""
if not file_edit.rename_file_path
else _rp(file_edit.rename_file_path)
)
if file_edit.is_deletion or rename_path or file_edit.replacements:
context.add(file_path)
if file_edit.is_creation and file_path in context:
context.remove(file_path)
if rename_path and rename_path in context:
context.remove(rename_path)
# TODO: Prompt User to modify/approve context
if diff_edit:
file_edits = GitParser().parse_string(diff_edit).file_edits
for file_edit in file_edits:
file_path = _rp(file_edit.file_path)
rename_path = (
""
if not file_edit.rename_file_path
else _rp(file_edit.rename_file_path)
)
if file_edit.is_deletion or rename_path or file_edit.replacements:
context.add(file_path)
if file_edit.is_creation and file_path in context:
context.remove(file_path)
if rename_path and rename_path in context:
context.remove(rename_path)
# TODO: Prompt User to modify/approve context

sample = Sample(
title=title,
Expand Down
165 changes: 161 additions & 4 deletions scripts/evaluate_samples.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import argparse
import asyncio
import json
import os
from pathlib import Path
from typing import Any
Expand All @@ -16,6 +17,7 @@

from mentat.errors import SampleError
from mentat.git_handler import get_git_diff
from mentat.parsers.git_parser import GitParser
from mentat.python_client.client import PythonClient
from mentat.sampler.sample import Sample
from mentat.sampler.utils import get_active_snapshot_commit
Expand All @@ -33,6 +35,9 @@ def warn(msg: Any):


# Storage locations under the mentat directory. Both directories are created
# (including any missing parents) when this module is executed, so later code
# can read and write them without checking for existence.
SAMPLES_DIR = mentat_dir_path / "samples"
FINETUNE_DIR = mentat_dir_path / "finetune"
SAMPLES_DIR.mkdir(parents=True, exist_ok=True)
FINETUNE_DIR.mkdir(parents=True, exist_ok=True)


def apply_diff_to_repo(diff: str, repo: Repo, commit: bool = False) -> str | None:
Expand Down Expand Up @@ -112,9 +117,120 @@ async def evaluate_sample(sample, cwd: Path | str | None = None):
return diff_eval


async def validate_sample(sample, cwd: Path | str | None = None) -> tuple[bool, str]:
"""Validate a sample by applying diffs and checking sample fields."""
try:
required_fields = ["id", "repo", "merge_base", "message_prompt"]
for field in required_fields:
if not getattr(sample, field):
return False, f"Missing required field: {field}"
if not sample.message_edit and not sample.diff_edit:
return False, "Samples must include either diff_edit or message_edit."

# Setup repo
if cwd is None:
cwd = clone_repo(
url=sample.repo,
local_dir_name=sample.repo.split("/")[-1],
refresh=False,
)
if cwd is None:
return False, f"Error cloning repo: {sample.repo}"
else:
cwd = Path(cwd)
os.chdir(cwd)
repo = Repo(".")
repo.head.reset(index=True, working_tree=True) # reset tracked files
repo.git.execute(["git", "clean", "-fd"]) # remove untracked files/directories
repo.git.fetch("--all")
repo.git.checkout(sample.merge_base)
if sample.diff_merge_base:
errors = apply_diff_to_repo(sample.diff_merge_base, repo, commit=True)
if errors:
return False, f"Error applying diff_merge_base: {errors}"
if sample.diff_active:
errors = apply_diff_to_repo(sample.diff_active, repo)
if errors:
return False, f"Error applying diff_active: {errors}"
# TODO: Validate context (paths)
if sample.diff_edit:
errors = apply_diff_to_repo(sample.diff_edit, repo)
if errors:
return False, f"Error applying diff_edit: {errors}"

return True, ""
except Exception as e:
return False, f"Error validating sample: {e}"


async def generate_finetune_gpt(sample, cwd: Path | str | None = None):
    """Generate a fine-tuning example from the sample for GPT-3.5.

    The result has the chat fine-tuning shape:
    {"messages": [{"role": "user", "content": "Hello, world!"}, ...]}

    The conversation consists of an optional system message holding the code
    context (only when the sample lists context paths), the user prompt, and
    the expected answer as the final assistant message.

    Args:
        sample: The sample to convert.
        cwd: Path to an existing checkout of the sample's repo. When ``None``,
            the repo is obtained via ``clone_repo``.

    Returns:
        A dict with a single ``"messages"`` key holding the conversation.

    Raises:
        SampleError: If the repo cannot be cloned, or if ``diff_merge_base`` or
            ``diff_active`` fails to apply.

    Note:
        This changes the process working directory to the repo, and resets and
        cleans that repo's working tree before checking out ``merge_base``.
    """
    # Setup repo, including diff_merge_base and diff_active
    if cwd is None:
        cwd = clone_repo(
            url=sample.repo,
            local_dir_name=sample.repo.split("/")[-1],
            refresh=False,
        )
        if cwd is None:
            raise SampleError(f"Error cloning {sample.repo}")
    else:
        cwd = Path(cwd)
    os.chdir(cwd)
    repo = Repo(".")
    repo.head.reset(index=True, working_tree=True)  # reset tracked files
    repo.git.execute(["git", "clean", "-fd"])  # remove untracked files/directories
    repo.git.fetch("--all")
    repo.git.checkout(sample.merge_base)
    if sample.diff_merge_base:
        errors = apply_diff_to_repo(sample.diff_merge_base, repo, commit=True)
        if errors:
            raise SampleError(f"Error applying diff_merge_base: {errors}")
    if sample.diff_active:
        errors = apply_diff_to_repo(sample.diff_active, repo)
        if errors:
            raise SampleError(f"Error applying diff_active: {errors}")

    # Run sample in PythonClient
    paths = [Path(a) for a in sample.context]
    python_client = PythonClient(cwd=cwd, paths=paths)
    await python_client.startup()
    try:
        ctx = SESSION_CONTEXT.get()

        # Build the conversation
        conversation: list[dict[str, str]] = []
        if paths:
            code_message = await ctx.code_context.get_code_message(0)
            conversation.append({"role": "system", "content": code_message})
        # TODO: Ignore conversation_history for now because file_edits are not yet included
        # conversation += sample.message_history[::-1]
        conversation.append({"role": "user", "content": sample.message_prompt})
        message_example = sample.message_edit or ""
        if sample.diff_edit:  # Convert any diff_edit to block format for answer
            parsed_llm_response = GitParser().parse_string(sample.diff_edit)
            message_example += ctx.config.parser.file_edits_to_llm_message(
                parsed_llm_response
            )
        # The expected answer is the training target, so it must carry the
        # "assistant" role; a "system" role here leaves the example without
        # any completion for the model to learn from.
        conversation.append({"role": "assistant", "content": message_example})
    finally:
        # Always shut the client down, even if building the conversation fails.
        await python_client.shutdown()
    return {"messages": conversation}


async def main():
parser = argparse.ArgumentParser(description="Evaluate code samples.")
parser.add_argument("sample_ids", nargs="*", help="Optional sample IDs to evaluate")
parser.add_argument(
"--validate", action="store_true", help="Validate samples instead of evaluating"
)
parser.add_argument(
"--finetune", action="store_true", help="Generate fine-tuning examples"
)
args = parser.parse_args()
sample_files = []
if args.sample_ids:
Expand All @@ -126,9 +242,34 @@ async def main():
print(f"No {'matching ' if args.sample_ids else ''}sample files found.")
return

logs = []
for sample_file in sample_files:
if sample_file.exists():
sample = Sample.load(sample_file)
if not sample_file.exists():
warn(f"Sample file {sample_file} does not exist.")
continue
sample = Sample.load(sample_file)
if args.validate:
is_valid, reason = await validate_sample(sample)
status = (
"\033[92mPASSED\033[0m"
if is_valid
else f"\033[91mFAILED: {reason}\033[0m"
)
print(f"[{sample.id[:8]}] {sample.title}: {status}")
logs.append({"id": sample.id, "is_valid": is_valid, "reason": reason})
elif args.finetune:
print(f"Generating fine-tuning example for sample {sample.id[:8]}")
try:
example = await generate_finetune_gpt(sample)
example_file = FINETUNE_DIR / f"finetune_{sample.id}.json"
with open(example_file, "w") as f:
json.dump(example, f, indent=4)
logs.append({"id": sample.id, "example_file": example_file})
except Exception as e:
warn(
f"Error generating fine-tuning example for sample {sample.id}: {e}"
)
else:
print(f"Evaluating sample {sample.id[:8]}")
print(f" Prompt: {sample.message_prompt}")
diff_eval = await evaluate_sample(sample)
Expand All @@ -140,8 +281,24 @@ async def main():
print(f" Response Grade: {response_grade}")
comparison_grade = await compare_diffs(sample.diff_edit, diff_eval)
print(f" Comparison Grade: {comparison_grade}")
else:
print(f"Sample file {sample_file} does not exist.")
logs.append(
{
"id": sample.id,
"title": sample.title,
"prompt": sample.message_prompt,
"diff_grade": diff_grade,
"response_grade": response_grade,
"comparison_grade": comparison_grade,
}
)

if args.validate:
print(
f"{sum([log['is_valid'] for log in logs])}/{len(logs)} samples passed"
" validation."
)
elif args.finetune:
print(f"{len(logs)} fine-tuning examples generated.")


if __name__ == "__main__":
Expand Down
2 changes: 1 addition & 1 deletion tests/benchmarks/benchmarks/mentat/license_update.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
minimum_context = ["tests/license_check.py:11-22"]

config = Config(
auto_context=True,
auto_context_tokens=True,
maximum_context=8000,
)

Expand Down

0 comments on commit cb48a99

Please sign in to comment.