Linting of dummy
thomasahle committed Mar 3, 2024
1 parent 9a1c189 commit a021bd2
Showing 1 changed file with 32 additions and 30 deletions.
62 changes: 32 additions & 30 deletions dspy/utils/dummies.py
@@ -1,6 +1,5 @@
 import random
 from dsp.modules import LM
-from typing import List, Union, Dict
 import numpy as np
 from dsp.utils.utils import dotdict
 import re
@@ -9,9 +8,9 @@
 class DummyLM(LM):
     """Dummy language model for unit testing purposes."""

-    def __init__(self, answers: Union[List[str], Dict[str,str]], follow_examples: bool = False):
-        """
-        Initializes the dummy language model.
+    def __init__(self, answers: list[str] | dict[str, str], follow_examples: bool = False):
+        """Initializes the dummy language model.
+
         Parameters:
         - answers: A list of strings or a dictionary with string keys and values.
         - follow_examples: If True, and the prompt contains an example exactly equal to the prompt,
@@ -24,7 +23,7 @@ def __init__(self, answers: Union[List[str], Dict[str,str]], follow_examples: bool = False):
         self.answers = answers
         self.follow_examples = follow_examples

-    def basic_request(self, prompt, n=1, **kwargs):
+    def basic_request(self, prompt, n=1, **kwargs) -> dict[str, list[dict[str, str]]]:
         """Generates a dummy response based on the prompt."""
         dummy_response = {"choices": []}
         for _ in range(n):
@@ -55,12 +54,14 @@ def basic_request(self, prompt, n=1, **kwargs):
                 answer = "No more responses"

             # Mimic the structure of a real language model response.
-            dummy_response["choices"].append({
-                "text": answer,
-                "finish_reason": "simulated completion",
-            })
-
-        RED, GREEN, RESET = '\033[91m', '\033[92m', '\033[0m'
+            dummy_response["choices"].append(
+                {
+                    "text": answer,
+                    "finish_reason": "simulated completion",
+                },
+            )
+
+        RED, GREEN, RESET = "\033[91m", "\033[92m", "\033[0m"
         print("=== DummyLM ===")
         print(prompt, end="")
         print(f"{RED}{answer}{RESET}")
@@ -77,67 +78,68 @@ def basic_request(self, prompt, n=1, **kwargs):

         return dummy_response

-    def __call__(self, prompt, only_completed=True, return_sorted=False, **kwargs):
+    def __call__(self, prompt, _only_completed=True, _return_sorted=False, **kwargs):
         """Retrieves dummy completions."""
         response = self.basic_request(prompt, **kwargs)
         choices = response["choices"]

         # Filter choices and return text completions.
-        completions = [choice["text"] for choice in choices]
-
-        return completions
+        return [choice["text"] for choice in choices]

-    def get_convo(self, index):
-        """Get the prompt + anwer from the ith message"""
-        return self.history[index]['prompt'] \
-            + " " \
-            + self.history[index]['response']['choices'][0]['text']
+    def get_convo(self, index) -> str:
+        """Get the prompt + anwer from the ith message."""
+        return self.history[index]["prompt"] + " " + self.history[index]["response"]["choices"][0]["text"]


-def dummy_rm(passages=()):
+def dummy_rm(passages=()) -> callable:
     if not passages:
-        def inner(query:str, *, k:int, **kwargs):
+
+        def inner(query: str, *, k: int, **kwargs):
             assert False, "No passages defined"

         return inner
     max_length = max(map(len, passages)) + 100
     vectorizer = DummyVectorizer(max_length)
     passage_vecs = vectorizer(passages)
-    def inner(query:str, *, k:int, **kwargs):
+
+    def inner(query: str, *, k: int, **kwargs):
         assert k <= len(passages)
         query_vec = vectorizer([query])[0]
         scores = passage_vecs @ query_vec
         largest_idx = (-scores).argsort()[:k]
-        #return dspy.Prediction(passages=[passages[i] for i in largest_idx])
+        # return dspy.Prediction(passages=[passages[i] for i in largest_idx])
         return [dotdict(dict(long_text=passages[i])) for i in largest_idx]

     return inner


 class DummyVectorizer:
-    """Simple vectorizer based on n-grams"""
+    """Simple vectorizer based on n-grams."""

     def __init__(self, max_length=100, n_gram=2):
         self.max_length = max_length
         self.n_gram = n_gram
         self.P = 10**9 + 7  # A large prime number
         random.seed(123)
         self.coeffs = [random.randrange(1, self.P) for _ in range(n_gram)]

     def _hash(self, gram):
-        """Hashes a string using a polynomial hash function"""
+        """Hashes a string using a polynomial hash function."""
         h = 1
         for coeff, c in zip(self.coeffs, gram):
             h = h * coeff + ord(c)
             h %= self.P
         return h % self.max_length

-    def __call__(self, texts: List[str]) -> np.ndarray:
+    def __call__(self, texts: list[str]) -> np.ndarray:
         vecs = []
         for text in texts:
-            grams = [text[i:i+self.n_gram] for i in range(len(text) - self.n_gram + 1)]
+            grams = [text[i : i + self.n_gram] for i in range(len(text) - self.n_gram + 1)]
             vec = [0] * self.max_length
             for gram in grams:
                 vec[self._hash(gram)] += 1
             vecs.append(vec)

         vecs = np.array(vecs, dtype=np.float32)
         vecs -= np.mean(vecs, axis=1, keepdims=True)
         vecs /= np.linalg.norm(vecs, axis=1, keepdims=True) + 1e-10  # Added epsilon to avoid division by zero
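Editor's note, not part of the commit: a minimal sketch of how DummyLM is typically exercised in a unit test. The answer-selection logic sits in the collapsed part of basic_request, so the assumption that list answers are consumed in order (and that a dict keys answers by prompt content) is illustrative rather than confirmed by the diff above.

    # Hypothetical test snippet; prompts and answers are made up.
    from dspy.utils.dummies import DummyLM

    lm = DummyLM(["4", "6"])             # assumed: list answers returned in order
    print(lm("What is 2 + 2?"))          # expected: ["4"]
    print(lm("What is 3 + 3?"))          # expected: ["6"]
    print(lm("One more?"))               # expected: ["No more responses"] once answers run out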

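Likewise hedged, a sketch of the retrieval helpers touched by this file: dummy_rm builds an in-memory retriever over a fixed passage list, scoring passages with DummyVectorizer, which hashes character n-grams into a fixed-length count vector and then mean-centers and L2-normalizes it.

    # Hypothetical snippet; passage texts are made up for illustration.
    from dspy.utils.dummies import DummyVectorizer, dummy_rm

    passages = ["cats sit on mats", "dogs chase cats", "quantum computing basics"]
    rm = dummy_rm(passages)
    top = rm("cats and dogs", k=2)           # top-k passages ranked by n-gram similarity
    print([p.long_text for p in top])        # dotdict results expose .long_text

    vec = DummyVectorizer(max_length=200, n_gram=2)
    print(vec(["abc", "abd"]).shape)         # (2, 200): hashed bigram counts, centered and normalized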