From 852e3bf0f7818b402a32cf2905ee462310057cd3 Mon Sep 17 00:00:00 2001 From: Anindyadeep Date: Thu, 16 May 2024 09:55:51 +0000 Subject: [PATCH] fix(dspy): lint and run --- dsp/modules/lm.py | 14 +++++++++----- 1 file changed, 9 insertions(+), 5 deletions(-) diff --git a/dsp/modules/lm.py b/dsp/modules/lm.py index a610d3463..8bed65aa3 100644 --- a/dsp/modules/lm.py +++ b/dsp/modules/lm.py @@ -56,7 +56,11 @@ def inspect_history(self, n: int = 1, skip: int = 0): ): printed.append((prompt, x["response"])) elif provider == "anthropic": - blocks = [{"text": block.text} for block in x["response"].content if block.type == "text"] + blocks = [ + {"text": block.text} + for block in x["response"].content + if block.type == "text" + ] printed.append((prompt, blocks)) elif provider == "cohere": printed.append((prompt, x["response"].text)) @@ -97,16 +101,16 @@ def inspect_history(self, n: int = 1, skip: int = 0): text = choices[0].message.content elif provider == "cloudflare": text = choices[0] - elif provider == "ibm": - text = choices - elif provider == "premai": + elif provider == "ibm" or provider == "premai": text = choices else: text = choices[0]["text"] printing_value += self.print_green(text, end="") if len(choices) > 1 and isinstance(choices, list): - printing_value += self.print_red(f" \t (and {len(choices)-1} other completions)", end="") + printing_value += self.print_red( + f" \t (and {len(choices)-1} other completions)", end="", + ) printing_value += "\n\n\n"