Skip to content

Commit

Permalink
Merge pull request stanfordnlp#222 from stanfordnlp/knn_fewshot_update
Browse files Browse the repository at this point in the history
updated knn_fewshot with fixes
  • Loading branch information
detaos authored Nov 21, 2023
2 parents 6291f75 + 9a67ba9 commit bfab04c
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 9 deletions.
12 changes: 5 additions & 7 deletions dsp/modules/sentence_vectorizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,12 +22,10 @@ def __init__(self) -> None:
def __call__(self, inp_examples: List["Example"]) -> np.ndarray:
pass

def _extract_text_from_examples(self, inp_examples: List["Example"]) -> List[str]:
text_to_vectorize = [
getattr(example, self.field_to_vectorize)
for example in inp_examples
]
return text_to_vectorize
def _extract_text_from_examples(self, inp_examples: List) -> List[str]:
if isinstance(inp_examples[0], str):
return inp_examples
return [" ".join([example[key] for key in example._input_keys]) for example in inp_examples]


class SentenceTransformersVectorizer(BaseSentenceVectorizer):
Expand Down Expand Up @@ -67,7 +65,7 @@ def __init__(
self.vectorize_bs = vectorize_bs
self.normalize_embeddings = normalize_embeddings

def __call__(self, inp_examples: List["Example"]) -> np.ndarray:
def __call__(self, inp_examples: List) -> np.ndarray:
text_to_vectorize = self._extract_text_from_examples(inp_examples)

if self.is_gpu and self.num_devices > 1:
Expand Down
5 changes: 3 additions & 2 deletions dspy/teleprompt/knn_fewshot.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import dsp

from .teleprompt import Teleprompter
from dspy.teleprompt import BootstrapFewShot

class KNNFewShot(Teleprompter):
def __init__(self, KNN, k: int, trainset: List[dsp.Example]):
Expand All @@ -11,11 +12,11 @@ def __init__(self, KNN, k: int, trainset: List[dsp.Example]):
def compile(self, student, *, teacher=None, trainset, valset=None):
student_copy = student.reset_copy()

def forward_pass(**kwargs):
def forward_pass(*args, **kwargs):
knn_trainset = self.KNN(**kwargs)
few_shot_bootstrap = BootstrapFewShot()
compiled_program = few_shot_bootstrap.compile(student, teacher=teacher, trainset=knn_trainset, valset=valset)
return compiled_program
return compiled_program(**kwargs)

student_copy.forward = types.MethodType(forward_pass, student_copy)
return student_copy

0 comments on commit bfab04c

Please sign in to comment.