diff --git a/dsp/modules/lm.py b/dsp/modules/lm.py index a610d3463..8bed65aa3 100644 --- a/dsp/modules/lm.py +++ b/dsp/modules/lm.py @@ -56,7 +56,11 @@ def inspect_history(self, n: int = 1, skip: int = 0): ): printed.append((prompt, x["response"])) elif provider == "anthropic": - blocks = [{"text": block.text} for block in x["response"].content if block.type == "text"] + blocks = [ + {"text": block.text} + for block in x["response"].content + if block.type == "text" + ] printed.append((prompt, blocks)) elif provider == "cohere": printed.append((prompt, x["response"].text)) @@ -97,16 +101,16 @@ def inspect_history(self, n: int = 1, skip: int = 0): text = choices[0].message.content elif provider == "cloudflare": text = choices[0] - elif provider == "ibm": - text = choices - elif provider == "premai": + elif provider == "ibm" or provider == "premai": text = choices else: text = choices[0]["text"] printing_value += self.print_green(text, end="") if len(choices) > 1 and isinstance(choices, list): - printing_value += self.print_red(f" \t (and {len(choices)-1} other completions)", end="") + printing_value += self.print_red( + f" \t (and {len(choices)-1} other completions)", end="", + ) printing_value += "\n\n\n"