diff --git a/mentat/embeddings.py b/mentat/embeddings.py
index 5c2b46d88..ff49715aa 100644
--- a/mentat/embeddings.py
+++ b/mentat/embeddings.py
@@ -135,7 +135,10 @@ async def get_feature_similarity_scores(
         return [0.0 for _ in checksums]
 
     # Fetch embeddings in batches
-    n_batches = len(embed_texts) // EMBEDDINGS_API_BATCH_SIZE + 1
+    if len(embed_texts) == 0:
+        n_batches = 0
+    else:
+        n_batches = len(embed_texts) // EMBEDDINGS_API_BATCH_SIZE + 1
     for batch in range(n_batches):
         if loading_multiplier:
            stream.send(
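Note on the batching arithmetic above: the new guard fixes the empty-input case, but floor division plus one still yields a trailing empty batch whenever len(embed_texts) is an exact multiple of EMBEDDINGS_API_BATCH_SIZE. A minimal sketch of the ceiling-division alternative, which covers both cases without a branch (the batch size shown is illustrative, not mentat's actual constant):

```python
EMBEDDINGS_API_BATCH_SIZE = 2048  # illustrative value only


def num_batches(n_texts: int) -> int:
    # Integer ceiling division: ceil(n / size) without floats.
    return (n_texts + EMBEDDINGS_API_BATCH_SIZE - 1) // EMBEDDINGS_API_BATCH_SIZE


assert num_batches(0) == 0     # no texts -> no batches
assert num_batches(1) == 1
assert num_batches(2048) == 1  # exact multiple -> one batch, not two
assert num_batches(2049) == 2
```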
diff --git a/mentat/sampler/sampler.py b/mentat/sampler/sampler.py
index de5bcd480..b4b70ad0d 100644
--- a/mentat/sampler/sampler.py
+++ b/mentat/sampler/sampler.py
@@ -170,21 +170,22 @@ def _rp(f: str | Path) -> str:
             )
             context.update(_rp(f) for f in feature_refs)
         # Undo adds/removes/renames to match pre-diff_edit state
-        file_edits = GitParser().parse_string(diff_edit).file_edits
-        for file_edit in file_edits:
-            file_path = _rp(file_edit.file_path)
-            rename_path = (
-                ""
-                if not file_edit.rename_file_path
-                else _rp(file_edit.rename_file_path)
-            )
-            if file_edit.is_deletion or rename_path or file_edit.replacements:
-                context.add(file_path)
-            if file_edit.is_creation and file_path in context:
-                context.remove(file_path)
-            if rename_path and rename_path in context:
-                context.remove(rename_path)
-        # TODO: Prompt User to modify/approve context
+        if diff_edit:
+            file_edits = GitParser().parse_string(diff_edit).file_edits
+            for file_edit in file_edits:
+                file_path = _rp(file_edit.file_path)
+                rename_path = (
+                    ""
+                    if not file_edit.rename_file_path
+                    else _rp(file_edit.rename_file_path)
+                )
+                if file_edit.is_deletion or rename_path or file_edit.replacements:
+                    context.add(file_path)
+                if file_edit.is_creation and file_path in context:
+                    context.remove(file_path)
+                if rename_path and rename_path in context:
+                    context.remove(rename_path)
+        # TODO: Prompt User to modify/approve context
 
         sample = Sample(
             title=title,
diff --git a/scripts/evaluate_samples.py b/scripts/evaluate_samples.py
index 89e6f92f7..6f516d97a 100644
--- a/scripts/evaluate_samples.py
+++ b/scripts/evaluate_samples.py
@@ -2,6 +2,7 @@
 
 import argparse
 import asyncio
+import json
 import os
 from pathlib import Path
 from typing import Any
@@ -16,6 +17,7 @@
 
 from mentat.errors import SampleError
 from mentat.git_handler import get_git_diff
+from mentat.parsers.git_parser import GitParser
 from mentat.python_client.client import PythonClient
 from mentat.sampler.sample import Sample
 from mentat.sampler.utils import get_active_snapshot_commit
@@ -33,6 +35,9 @@ def warn(msg: Any):
 
 
 SAMPLES_DIR = mentat_dir_path / "samples"
+os.makedirs(SAMPLES_DIR, exist_ok=True)
+FINETUNE_DIR = mentat_dir_path / "finetune"
+os.makedirs(FINETUNE_DIR, exist_ok=True)
 
 
 def apply_diff_to_repo(diff: str, repo: Repo, commit: bool = False) -> str | None:
@@ -112,9 +117,120 @@ async def evaluate_sample(sample, cwd: Path | str | None = None):
     return diff_eval
 
 
+async def validate_sample(sample, cwd: Path | str | None = None) -> tuple[bool, str]:
+    """Validate a sample by applying diffs and checking sample fields."""
+    try:
+        required_fields = ["id", "repo", "merge_base", "message_prompt"]
+        for field in required_fields:
+            if not getattr(sample, field):
+                return False, f"Missing required field: {field}"
+        if not sample.message_edit and not sample.diff_edit:
+            return False, "Samples must include either diff_edit or message_edit."
+
+        # Setup repo
+        if cwd is None:
+            cwd = clone_repo(
+                url=sample.repo,
+                local_dir_name=sample.repo.split("/")[-1],
+                refresh=False,
+            )
+            if cwd is None:
+                return False, f"Error cloning repo: {sample.repo}"
+        else:
+            cwd = Path(cwd)
+        os.chdir(cwd)
+        repo = Repo(".")
+        repo.head.reset(index=True, working_tree=True)  # reset tracked files
+        repo.git.execute(["git", "clean", "-fd"])  # remove untracked files/directories
+        repo.git.fetch("--all")
+        repo.git.checkout(sample.merge_base)
+        if sample.diff_merge_base:
+            errors = apply_diff_to_repo(sample.diff_merge_base, repo, commit=True)
+            if errors:
+                return False, f"Error applying diff_merge_base: {errors}"
+        if sample.diff_active:
+            errors = apply_diff_to_repo(sample.diff_active, repo)
+            if errors:
+                return False, f"Error applying diff_active: {errors}"
+        # TODO: Validate context (paths)
+        if sample.diff_edit:
+            errors = apply_diff_to_repo(sample.diff_edit, repo)
+            if errors:
+                return False, f"Error applying diff_edit: {errors}"
+
+        return True, ""
+    except Exception as e:
+        return False, f"Error validating sample: {e}"
+
+
+async def generate_finetune_gpt(sample, cwd: Path | str | None = None):
+    """Generate a fine-tuning example from the sample for GPT-3.5
+
+    {"messages": [{"role": "user", "content": "Hello, world!"}, ...]}
+    """
+    # Setup repo, including diff_merge_base and diff_active
+    if cwd is None:
+        cwd = clone_repo(
+            url=sample.repo,
+            local_dir_name=sample.repo.split("/")[-1],
+            refresh=False,
+        )
+        if cwd is None:
+            raise SampleError(f"Error cloning {sample.repo}")
+    else:
+        cwd = Path(cwd)
+    os.chdir(cwd)
+    repo = Repo(".")
+    repo.head.reset(index=True, working_tree=True)  # reset tracked files
+    repo.git.execute(["git", "clean", "-fd"])  # remove untracked files/directories
+    repo.git.fetch("--all")
+    repo.git.checkout(sample.merge_base)
+    if sample.diff_merge_base:
+        errors = apply_diff_to_repo(sample.diff_merge_base, repo, commit=True)
+        if errors:
+            raise SampleError(f"Error applying diff_merge_base: {errors}")
+    if sample.diff_active:
+        errors = apply_diff_to_repo(sample.diff_active, repo)
+        if errors:
+            raise SampleError(f"Error applying diff_active: {errors}")
+
+    # Run sample in PythonClient
+    paths = list[Path]()
+    for a in sample.context:
+        paths.append(Path(a))
+    python_client = PythonClient(cwd=cwd, paths=paths)
+    await python_client.startup()
+    ctx = SESSION_CONTEXT.get()
+
+    # Build the conversation
+    conversation = list[dict[str, str]]()
+    if paths:
+        code_message = await ctx.code_context.get_code_message(0)
+        conversation.append({"role": "system", "content": code_message})
+    # TODO: Ignore conversation_history for now because file_edits are not yet included
+    # conversation += sample.message_history[::-1]
+    conversation.append({"role": "user", "content": sample.message_prompt})
+    message_example = sample.message_edit or ""
+    if sample.diff_edit:  # Convert any diff_edit to block format for answer
+        parsed_llm_response = GitParser().parse_string(sample.diff_edit)
+        message_example += ctx.config.parser.file_edits_to_llm_message(
+            parsed_llm_response
+        )
+    conversation.append({"role": "system", "content": message_example})
+
+    await python_client.shutdown()
+    return {"messages": conversation}
+
+
 async def main():
     parser = argparse.ArgumentParser(description="Evaluate code samples.")
     parser.add_argument("sample_ids", nargs="*", help="Optional sample IDs to evaluate")
+    parser.add_argument(
+        "--validate", action="store_true", help="Validate samples instead of evaluating"
+    )
+    parser.add_argument(
+        "--finetune", action="store_true", help="Generate fine-tuning examples"
+    )
     args = parser.parse_args()
     sample_files = []
     if args.sample_ids:
@@ -126,9 +242,34 @@
         print(f"No {'matching ' if args.sample_ids else ''}sample files found.")
         return
 
+    logs = []
     for sample_file in sample_files:
-        if sample_file.exists():
-            sample = Sample.load(sample_file)
+        if not sample_file.exists():
+            warn(f"Sample file {sample_file} does not exist.")
+            continue
+        sample = Sample.load(sample_file)
+        if args.validate:
+            is_valid, reason = await validate_sample(sample)
+            status = (
+                "\033[92mPASSED\033[0m"
+                if is_valid
+                else f"\033[91mFAILED: {reason}\033[0m"
+            )
+            print(f"[{sample.id[:8]}] {sample.title}: {status}")
+            logs.append({"id": sample.id, "is_valid": is_valid, "reason": reason})
+        elif args.finetune:
+            print(f"Generating fine-tuning example for sample {sample.id[:8]}")
+            try:
+                example = await generate_finetune_gpt(sample)
+                example_file = FINETUNE_DIR / f"finetune_{sample.id}.json"
+                with open(example_file, "w") as f:
+                    json.dump(example, f, indent=4)
+                logs.append({"id": sample.id, "example_file": example_file})
+            except Exception as e:
+                warn(
+                    f"Error generating fine-tuning example for sample {sample.id}: {e}"
+                )
+        else:
             print(f"Evaluating sample {sample.id[:8]}")
             print(f"  Prompt: {sample.message_prompt}")
             diff_eval = await evaluate_sample(sample)
@@ -140,8 +281,24 @@
             print(f"  Response Grade: {response_grade}")
             comparison_grade = await compare_diffs(sample.diff_edit, diff_eval)
             print(f"  Comparison Grade: {comparison_grade}")
-        else:
-            print(f"Sample file {sample_file} does not exist.")
+            logs.append(
+                {
+                    "id": sample.id,
+                    "title": sample.title,
+                    "prompt": sample.message_prompt,
+                    "diff_grade": diff_grade,
+                    "response_grade": response_grade,
+                    "comparison_grade": comparison_grade,
+                }
+            )
+
+    if args.validate:
+        print(
+            f"{sum([log['is_valid'] for log in logs])}/{len(logs)} samples passed"
+            " validation."
+        )
+    elif args.finetune:
+        print(f"{len(logs)} fine-tuning examples generated.")
 
 
 if __name__ == "__main__":
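For reference, the --finetune path writes one JSON file per sample with the shape sketched below (contents illustrative, not taken from a real sample). One detail worth flagging: generate_finetune_gpt emits both the code message and the answer with role "system", whereas chat fine-tuning data conventionally marks the target completion with role "assistant".

```python
# Illustrative shape of a generated finetune_<sample.id>.json payload.
example = {
    "messages": [
        # Code message built from sample.context, included when paths exist.
        {"role": "system", "content": "<code message>"},
        {"role": "user", "content": "<sample.message_prompt>"},
        # The answer: message_edit and/or the rendered diff_edit. Emitted
        # with role "system" above; "assistant" is the conventional role
        # for the completion a chat model is fine-tuned to produce.
        {"role": "system", "content": "<edit message>"},
    ]
}
```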
diff --git a/tests/benchmarks/benchmarks/mentat/license_update.py b/tests/benchmarks/benchmarks/mentat/license_update.py
index 7e63f2744..5e8ca50a6 100644
--- a/tests/benchmarks/benchmarks/mentat/license_update.py
+++ b/tests/benchmarks/benchmarks/mentat/license_update.py
@@ -24,7 +24,7 @@
 minimum_context = ["tests/license_check.py:11-22"]
 
 config = Config(
-    auto_context=True,
+    auto_context_tokens=True,
     maximum_context=8000,
 )
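A caveat on this last hunk: if auto_context_tokens is a token budget rather than an on/off flag (the rename suggests it counts tokens; this is an assumption, not verified against mentat's Config), then passing True leans on Python's bool-is-an-int behavior and requests a 1-token budget:

```python
# bool is a subclass of int in Python, so True behaves as 1 in arithmetic
# and comparisons; an explicit count (e.g. 8000, matching maximum_context)
# may be what a token-budget field actually expects.
assert isinstance(True, int)
assert True == 1
```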