Embeddings (AbanteAI#144)
* basic implementation and benchmark
* separate llm_api.raise_if_in_test_environment
granawkins authored Oct 14, 2023
1 parent a9f055e commit 47e75b1
Showing 15 changed files with 463 additions and 69 deletions.
75 changes: 53 additions & 22 deletions mentat/code_context.py
@@ -1,6 +1,5 @@
from __future__ import annotations

import asyncio
import os
from contextvars import ContextVar
from pathlib import Path
@@ -9,10 +8,11 @@

import attr

from .code_file import CodeFile, CodeMessageLevel
from .code_file import CodeFile, CodeMessageLevel, count_feature_tokens
from .code_file_manager import CODE_FILE_MANAGER
from .code_map import check_ctags_disabled
from .diff_context import DiffContext
from .embeddings import get_feature_similarity_scores
from .git_handler import GIT_ROOT, get_non_gitignored_files, get_paths_with_git_diffs
from .include_files import (
build_path_tree,
@@ -25,23 +25,12 @@
from .utils import sha256


async def _count_tokens_in_features(features: list[CodeFile], model: str) -> int:
sem = asyncio.Semaphore(10)

async def _count_tokens(feature: CodeFile) -> int:
async with sem:
return await feature.count_tokens(model)

tasks = [_count_tokens(f) for f in features]
results = await asyncio.gather(*tasks)
return sum(results)


@attr.define
class CodeContextSettings:
diff: Optional[str] = None
pr_diff: Optional[str] = None
no_code_map: bool = False
use_embedding: bool = False
auto_tokens: Optional[int] = None


@@ -166,6 +155,7 @@ def _get_code_message_checksum(self, max_tokens: Optional[int] = None) -> str:

async def get_code_message(
self,
prompt: str,
model: str,
max_tokens: int,
) -> str:
@@ -174,12 +164,13 @@ async def get_code_message
self._code_message is None
or code_message_checksum != self._code_message_checksum
):
self._code_message = await self._get_code_message(model, max_tokens)
self._code_message = await self._get_code_message(prompt, model, max_tokens)
self._code_message_checksum = self._get_code_message_checksum(max_tokens)
return self._code_message

async def _get_code_message(
self,
prompt: str,
model: str,
max_tokens: int,
) -> str:
@@ -197,16 +188,17 @@ async def _get_code_message
code_message += ["Code Files:\n"]

features = self._get_include_features()
include_feature_tokens = await _count_tokens_in_features(
features, model
) - count_tokens("\n".join(code_message), model)
include_feature_tokens = sum(await count_feature_tokens(features, model))
include_feature_tokens -= count_tokens("\n".join(code_message), model)
_max_auto = max(0, max_tokens - include_feature_tokens)
_max_user = self.settings.auto_tokens
if _max_auto == 0 or _max_user == 0:
self.features = features
else:
auto_tokens = _max_auto if _max_user is None else min(_max_auto, _max_user)
self.features = await self._get_auto_features(model, features, auto_tokens)
self.features = await self._get_auto_features(
prompt, model, features, auto_tokens
)

for f in self.features:
code_message += await f.get_code_message()
@@ -232,15 +224,16 @@ def _feature_relative_path(f: CodeFile) -> str:

async def _get_auto_features(
self,
prompt: str,
model: str,
include_features: list[CodeFile],
max_tokens: int,
) -> list[CodeFile]:
git_root = GIT_ROOT.get()

# Find the first (longest) level that fits
include_features_tokens = await _count_tokens_in_features(
include_features, model
include_features_tokens = sum(
await count_feature_tokens(include_features, model)
)
max_auto_tokens = max_tokens - include_features_tokens
all_features = include_features.copy()
@@ -263,14 +256,52 @@ async def _get_auto_features(
)
feature = CodeFile(path, level=level, diff=diff_target)
_features.append(feature)
level_length = await _count_tokens_in_features(_features, model)
level_length = sum(await count_feature_tokens(_features, model))
if level_length < max_auto_tokens:
all_features += _features
break

# Sort by relative path
def _feature_relative_path(f: CodeFile) -> str:
return os.path.relpath(f.path, git_root)

all_features = sorted(all_features, key=_feature_relative_path)

# If there's room, convert cmap features to code features (full text)
# starting with the highest-scoring.
cmap_features_tokens = sum(await count_feature_tokens(all_features, model))
max_sim_tokens = max_tokens - cmap_features_tokens
if self.settings.auto_tokens is not None:
max_sim_tokens = min(max_sim_tokens, self.settings.auto_tokens)

if self.settings.use_embedding and max_sim_tokens > 0:
sim_tokens = 0

# Get embedding-similarity scores for all files
all_code_features = [
CodeFile(f.path, CodeMessageLevel.CODE, f.diff)
for f in all_features
if f.path not in self.include_files
]
sim_scores = await get_feature_similarity_scores(prompt, all_code_features)
all_code_features_scored = zip(all_code_features, sim_scores)
all_code_features_sorted = sorted(
all_code_features_scored, key=lambda x: x[1], reverse=True
)
for code_feature, _ in all_code_features_sorted:
# Calculate the total change in length
i_cmap, cmap_feature = next(
(i, f)
for i, f in enumerate(all_features)
if f.path == code_feature.path
)
recovered_tokens = await cmap_feature.count_tokens(model)
new_tokens = await code_feature.count_tokens(model)
forecast = max_sim_tokens - sim_tokens + recovered_tokens - new_tokens
if forecast > 0:
sim_tokens = sim_tokens + new_tokens - recovered_tokens
all_features[i_cmap] = code_feature

return sorted(all_features, key=_feature_relative_path)

def include_file(self, path: Path):
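For orientation, the selection step added to _get_auto_features works like this: start from summary-level (cmap) features, then, highest similarity score first, swap individual files up to full-text features for as long as the remaining token budget allows. A minimal standalone sketch of that loop follows (illustrative names only, not identifiers from this commit):

    def upgrade_by_similarity(paths, scores, cmap_tokens, full_tokens, budget):
        """Swap summaries for full files, best score first, while the budget holds.

        paths, scores, cmap_tokens and full_tokens are parallel lists; every
        name here is a stand-in for the real CodeFile machinery.
        """
        chosen = {p: "cmap" for p in paths}
        spent = 0
        order = sorted(range(len(paths)), key=lambda i: scores[i], reverse=True)
        for i in order:
            recovered = cmap_tokens[i]  # tokens freed by dropping the summary
            added = full_tokens[i]      # tokens the full file would cost
            if budget - spent + recovered - added > 0:
                spent += added - recovered
                chosen[paths[i]] = "code"
        return chosen

    # Example: with a 500-token budget only the best-scoring file is upgraded.
    # upgrade_by_similarity(["a.py", "b.py"], [0.9, 0.2], [50, 40], [400, 900], 500)
    # -> {"a.py": "code", "b.py": "cmap"}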
31 changes: 25 additions & 6 deletions mentat/code_file.py
@@ -1,8 +1,11 @@
import asyncio
import math
import os
from enum import Enum
from pathlib import Path

from mentat.utils import sha256

from .code_file_manager import CODE_FILE_MANAGER
from .code_map import get_code_map
from .diff_context import annotate_file_message, parse_diff
@@ -139,19 +142,35 @@ async def _get_code_message(self) -> list[str]:
code_message += section.message
return code_message

_file_checksum: str | None = None
_code_message: list[str] | None = None

async def get_code_message(self) -> list[str]:
def get_checksum(self) -> str:
git_root = GIT_ROOT.get()
code_file_manager = CODE_FILE_MANAGER.get()
abs_path = git_root / self.path
file_checksum = code_file_manager.get_file_checksum(Path(abs_path))
if file_checksum != self._file_checksum or self._code_message is None:
self._file_checksum = file_checksum
return sha256(f"{file_checksum}{self.level.key}{self.diff}")

_feature_checksum: str | None = None
_code_message: list[str] | None = None

async def get_code_message(self) -> list[str]:
feature_checksum = self.get_checksum()
if feature_checksum != self._feature_checksum or self._code_message is None:
self._feature_checksum = feature_checksum
self._code_message = await self._get_code_message()
return self._code_message

async def count_tokens(self, model: str) -> int:
code_message = await self.get_code_message()
return count_tokens("\n".join(code_message), model)


async def count_feature_tokens(features: list[CodeFile], model: str) -> list[int]:
"""Return the number of tokens in each feature."""
sem = asyncio.Semaphore(10)

async def _count_tokens(feature: CodeFile) -> int:
async with sem:
return await feature.count_tokens(model)

tasks = [_count_tokens(f) for f in features]
return await asyncio.gather(*tasks)
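A small usage sketch of the relocated helper (hypothetical caller code, assuming count_feature_tokens is imported from mentat.code_file): it runs one count_tokens task per feature, capped at ten concurrent tasks by the semaphore, and returns counts in input order, which is why callers such as code_context.py can simply sum() the result.

    import asyncio

    class _StubFeature:
        """Stand-in for CodeFile, for illustration only."""
        def __init__(self, tokens: int):
            self._tokens = tokens

        async def count_tokens(self, model: str) -> int:
            return self._tokens

    async def _demo():
        counts = await count_feature_tokens([_StubFeature(120), _StubFeature(34)], "gpt-4")
        assert counts == [120, 34] and sum(counts) == 154

    # asyncio.run(_demo())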
3 changes: 2 additions & 1 deletion mentat/conversation.py
@@ -71,7 +71,7 @@ async def display_token_count(self):
else:
context_size = maximum_context
tokens = count_tokens(
await code_context.get_code_message(self.model, max_tokens=0),
await code_context.get_code_message("", self.model, max_tokens=0),
self.model,
) + count_tokens(prompt, self.model)

@@ -151,6 +151,7 @@ async def get_model_response(self) -> list[FileEdit]:
tokens = count_tokens(conversation_history, self.model)
response_buffer = 1000
code_message = await code_context.get_code_message(
messages[-1]["content"],
self.model,
self.max_tokens - tokens - response_buffer,
)
121 changes: 121 additions & 0 deletions mentat/embeddings.py
@@ -0,0 +1,121 @@
import gzip
import json
import os
from pathlib import Path

import numpy as np

from .code_file import CodeFile, count_feature_tokens
from .config_manager import mentat_dir_path
from .llm_api import call_embedding_api, count_tokens
from .session_stream import SESSION_STREAM
from .utils import sha256

EMBEDDING_MODEL = "text-embedding-ada-002"
EMBEDDING_MAX_TOKENS = 8192


class EmbeddingsDatabase:
# { sha256 : [ 1536 floats ] }
_dict: dict[str, list[float]] = dict[str, list[float]]()

def __init__(self, output_dir: Path | None = None):
if output_dir is None:
output_dir = mentat_dir_path
os.makedirs(output_dir, exist_ok=True)
self.path = Path(output_dir) / "embeddings.json.gz"
if self.path.exists():
with gzip.open(self.path, "rt") as f:
self._dict = json.load(f)

def save(self):
with gzip.open(self.path, "wt") as f:
json.dump(self._dict, f)

def __getitem__(self, key: str) -> list[float]:
return self._dict[key]

def __setitem__(self, key: str, value: list[float]):
self._dict[key] = value

def __contains__(self, key: str) -> bool:
return key in self._dict


database = EmbeddingsDatabase()


def _batch_ffd(data: dict[str, int], batch_size: int) -> list[list[str]]:
"""Batch files using the First Fit Decreasing algorithm."""
# Sort the data by the length of the strings in descending order
sorted_data = sorted(data.items(), key=lambda x: x[1], reverse=True)
batches = list[list[str]]()
for key, value in sorted_data:
# Place each item in the first batch that it fits in
placed = False
for batch in batches:
if sum(data[k] for k in batch) + value <= batch_size:
batch.append(key)
placed = True
break
if not placed:
batches.append([key])
return batches
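    # Illustrative example (not part of this commit): with token counts
    # {"a": 5, "b": 4, "c": 3, "d": 2} and batch_size=7, items are taken
    # largest-first and each is placed into the first batch with room:
    #   _batch_ffd({"a": 5, "b": 4, "c": 3, "d": 2}, 7) -> [["a", "d"], ["b", "c"]]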


def _cosine_similarity(v1: list[float], v2: list[float]) -> float:
"""Calculate the cosine similarity between two vectors."""
dot_product = np.dot(v1, v2)
norm_v1 = np.linalg.norm(v1)
norm_v2 = np.linalg.norm(v2)
return dot_product / (norm_v1 * norm_v2)


async def get_feature_similarity_scores(
prompt: str, features: list[CodeFile]
) -> list[float]:
"""Return the similarity scores for a given prompt and list of features."""
global database
stream = SESSION_STREAM.get()

# Keep things in the same order
checksums: list[str] = [f.get_checksum() for f in features]
tokens: list[int] = await count_feature_tokens(features, EMBEDDING_MODEL)

# Make a checksum:content dict of all items that need to be embedded
items_to_embed = dict[str, str]()
items_to_embed_tokens = dict[str, int]()
prompt_checksum = sha256(prompt)
if prompt_checksum not in database:
items_to_embed[prompt_checksum] = prompt
items_to_embed_tokens[prompt_checksum] = count_tokens(prompt, EMBEDDING_MODEL)
for feature, checksum, token in zip(features, checksums, tokens):
if token > EMBEDDING_MAX_TOKENS:
continue
if checksum not in database:
feature_content = await feature.get_code_message()
# Remove line numbering
items_to_embed[checksum] = "\n".join(feature_content)
items_to_embed_tokens[checksum] = token

# Fetch embeddings in batches
batches = _batch_ffd(items_to_embed_tokens, EMBEDDING_MAX_TOKENS)
for i, batch in enumerate(batches):
batch_content = [items_to_embed[k] for k in batch]
await stream.send(f"Embedding batch {i}/{len(batches)}...")
response = call_embedding_api(batch_content, EMBEDDING_MODEL)
for k, v in zip(batch, response):
database[k] = v
if len(batches) > 0:
database.save()

# Calculate similarity score for each feature
prompt_embedding = database[prompt_checksum]
scores = [0.0 for _ in checksums]
for i, checksum in enumerate(checksums):
if checksum not in database:
continue
feature_embedding = database[checksum]
scores[i] = _cosine_similarity(prompt_embedding, feature_embedding)

return scores
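A rough usage sketch of the scoring path (hypothetical caller code, mirroring what code_context.py now does with the result): score every candidate against the current prompt, then walk the features best first. Because embeddings are cached in embeddings.json.gz under mentat_dir_path and keyed by feature checksum, only new or changed content actually reaches the API.

    async def rank_by_similarity(prompt, candidate_features):
        # Hypothetical helper, illustrative only.
        scores = await get_feature_similarity_scores(prompt, candidate_features)
        ranked = sorted(zip(candidate_features, scores), key=lambda pair: pair[1], reverse=True)
        return [feature for feature, _ in ranked]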
19 changes: 16 additions & 3 deletions mentat/llm_api.py
@@ -66,13 +66,17 @@ async def _add_newline(response: AsyncGenerator[Any, None]):
yield {"choices": [{"delta": {"content": "\n"}}]}


async def call_llm_api(
messages: list[dict[str, str]], model: str
) -> AsyncGenerator[Any, None]:
def raise_if_in_test_environment():
if is_test_environment():
logging.critical("OpenAI call attempted in non benchmark test environment!")
raise MentatError("OpenAI call attempted in non benchmark test environment!")


async def call_llm_api(
messages: list[dict[str, str]], model: str
) -> AsyncGenerator[Any, None]:
raise_if_in_test_environment()

response: AsyncGenerator[Any, None] = cast(
AsyncGenerator[Any, None],
await openai.ChatCompletion.acreate( # type: ignore
@@ -86,6 +90,15 @@ async def call_llm_api(
return _add_newline(response)


def call_embedding_api(
input: list[str], model: str = "text-embedding-ada-002"
) -> list[list[float]]:
raise_if_in_test_environment()

response = openai.Embedding.create(input=input, model=model) # type: ignore
return [i["embedding"] for i in response["data"]] # type: ignore


# Ensures that each chunk will have at most one newline character
def chunk_to_lines(chunk: Any) -> list[str]:
return chunk["choices"][0]["delta"].get("content", "").splitlines(keepends=True)
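Finally, a minimal usage sketch of the new embedding entry point (hypothetical code; note that raise_if_in_test_environment turns any such call into a hard error inside non-benchmark tests):

    # Hypothetical usage, illustrative only:
    vectors = call_embedding_api(["first snippet", "second snippet"])
    assert len(vectors) == 2
    assert len(vectors[0]) == 1536  # text-embedding-ada-002 returns 1536-dim vectors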