Skip to content

Commit

Permalink
Merge pull request stanfordnlp#689 from thomasahle/main
Browse files Browse the repository at this point in the history
Fixing claude
  • Loading branch information
thomasahle authored Mar 20, 2024
2 parents 71dbf30 + 36a72a7 commit 7227e70
Showing 1 changed file with 14 additions and 21 deletions.
35 changes: 14 additions & 21 deletions dsp/modules/lm.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,8 @@ def print_red(self, text: str, end: str = "\n"):

def inspect_history(self, n: int = 1, skip: int = 0):
"""Prints the last n prompts and their completions.
TODO: print the valid choice that contains filled output field instead of the first
TODO: print the valid choice that contains filled output field instead of the first.
"""
provider: str = self.provider

Expand All @@ -45,23 +46,15 @@ def inspect_history(self, n: int = 1, skip: int = 0):
prompt = x["prompt"]

if prompt != last_prompt:

if provider == "clarifai" or provider == "google" or provider == "claude":
printed.append(
(
prompt,
x['response'],
),
)
if provider == "clarifai" or provider == "google":
printed.append((prompt, x["response"]))
elif provider == "anthropic":
blocks = [{"text": block.text} for block in x["response"].content if block.type == "text"]
printed.append((prompt, blocks))
elif provider == "cohere":
printed.append((prompt, x["response"].generations))
else:
printed.append(
(
prompt,
x["response"].generations
if provider == "cohere"
else x["response"]["choices"],
),
)
printed.append((prompt, x["response"]["choices"]))

last_prompt = prompt

Expand All @@ -79,9 +72,9 @@ def inspect_history(self, n: int = 1, skip: int = 0):
if provider == "cohere":
text = choices[0].text
elif provider == "openai" or provider == "ollama":
text = ' ' + self._get_choice_text(choices[0]).strip()
elif provider == "clarifai" or provider == "claude" :
text=choices
text = " " + self._get_choice_text(choices[0]).strip()
elif provider == "clarifai":
text = choices
elif provider == "google":
text = choices[0].parts[0].text
else:
Expand All @@ -99,6 +92,6 @@ def __call__(self, prompt, only_completed=True, return_sorted=False, **kwargs):
def copy(self, **kwargs):
    """Return a copy of the language model with the same parameters.

    Per-call overrides in ``kwargs`` take precedence over the stored
    ``self.kwargs``. The merged mapping must contain a ``model`` entry;
    it is popped out and passed explicitly to the constructor, with the
    remaining entries forwarded as keyword arguments.

    Raises:
        KeyError: if neither ``self.kwargs`` nor ``kwargs`` supplies
            ``"model"``.
    """
    # Merge stored parameters with overrides (overrides win); this builds
    # a new dict, so neither self.kwargs nor the caller's dict is mutated.
    kwargs = {**self.kwargs, **kwargs}
    # 'model' is a dedicated constructor argument, not a free-form kwarg.
    # NOTE: pop it exactly once — popping twice would raise KeyError.
    model = kwargs.pop("model")

    return self.__class__(model=model, **kwargs)

0 comments on commit 7227e70

Please sign in to comment.