From 6ba9600e1c7325c3eec01266ddbdfad386650c5e Mon Sep 17 00:00:00 2001 From: Thomas D Ahle Date: Wed, 20 Mar 2024 16:20:40 -0700 Subject: [PATCH] Fixed Claude --- dsp/modules/lm.py | 35 ++++++++++++++--------------------- 1 file changed, 14 insertions(+), 21 deletions(-) diff --git a/dsp/modules/lm.py b/dsp/modules/lm.py index 5fbf5ec5e..bcb6c3791 100644 --- a/dsp/modules/lm.py +++ b/dsp/modules/lm.py @@ -33,7 +33,8 @@ def print_red(self, text: str, end: str = "\n"): def inspect_history(self, n: int = 1, skip: int = 0): """Prints the last n prompts and their completions. - TODO: print the valid choice that contains filled output field instead of the first + + TODO: print the valid choice that contains filled output field instead of the first. """ provider: str = self.provider @@ -45,23 +46,15 @@ def inspect_history(self, n: int = 1, skip: int = 0): prompt = x["prompt"] if prompt != last_prompt: - - if provider == "clarifai" or provider == "google" or provider == "claude": - printed.append( - ( - prompt, - x['response'], - ), - ) + if provider == "clarifai" or provider == "google": + printed.append((prompt, x["response"])) + elif provider == "anthropic": + blocks = [{"text": block.text} for block in x["response"].content if block.type == "text"] + printed.append((prompt, blocks)) + elif provider == "cohere": + printed.append((prompt, x["response"].generations)) else: - printed.append( - ( - prompt, - x["response"].generations - if provider == "cohere" - else x["response"]["choices"], - ), - ) + printed.append((prompt, x["response"]["choices"])) last_prompt = prompt @@ -79,9 +72,9 @@ def inspect_history(self, n: int = 1, skip: int = 0): if provider == "cohere": text = choices[0].text elif provider == "openai" or provider == "ollama": - text = ' ' + self._get_choice_text(choices[0]).strip() - elif provider == "clarifai" or provider == "claude" : - text=choices + text = " " + self._get_choice_text(choices[0]).strip() + elif provider == "clarifai": + text = choices elif provider == "google": text = choices[0].parts[0].text else: @@ -99,6 +92,6 @@ def __call__(self, prompt, only_completed=True, return_sorted=False, **kwargs): def copy(self, **kwargs): """Returns a copy of the language model with the same parameters.""" kwargs = {**self.kwargs, **kwargs} - model = kwargs.pop('model') + model = kwargs.pop("model") return self.__class__(model=model, **kwargs)