From 982db444322cbf6a545a2504ce162b15dafee975 Mon Sep 17 00:00:00 2001 From: Thomas D Ahle Date: Thu, 14 Mar 2024 15:35:25 -0700 Subject: [PATCH] Fixed issue with boolean outputs --- dspy/functional/functional.py | 17 +++++- tests/functional/test_functional.py | 89 +++++++++++++++++++++++++++++ 2 files changed, 105 insertions(+), 1 deletion(-) diff --git a/dspy/functional/functional.py b/dspy/functional/functional.py index 048cd9c1a..dab08eb43 100644 --- a/dspy/functional/functional.py +++ b/dspy/functional/functional.py @@ -119,7 +119,22 @@ def _prepare_signature(self) -> dspy.Signature: is_output = field.json_schema_extra["__dspy_field_type"] == "output" type_ = field.annotation if is_output: - if type_ in (str, int, float, bool): + if type_ is bool: + + def parse(x): + x = x.strip().lower() + if x not in ("true", "false"): + raise ValueError("Respond with true or false") + return x == "true" + + signature = signature.with_updated_fields( + name, + desc=field.json_schema_extra.get("desc", "") + + (" (Respond with true or false)" if type_ != str else ""), + format=lambda x: x if isinstance(x, str) else str(x), + parser=parse, + ) + elif type_ in (str, int, float, bool): signature = signature.with_updated_fields( name, desc=field.json_schema_extra.get("desc", "") diff --git a/tests/functional/test_functional.py b/tests/functional/test_functional.py index f0e22e4be..2bedfa356 100644 --- a/tests/functional/test_functional.py +++ b/tests/functional/test_functional.py @@ -653,6 +653,28 @@ def space_in_a(cls, a: str) -> str: _ = ValidatedSignature(a="with space") +def test_lm_as_validator(): + @predictor + def is_square(n: int) -> bool: + """Is n a square number?""" + + def check_square(n): + assert is_square(n=n) + return n + + @predictor + def next_square(n: int) -> Annotated[int, AfterValidator(check_square)]: + """What is the next square number after n?""" + + lm = DummyLM(["3", "False", "4", "True"]) + dspy.settings.configure(lm=lm) + + m = next_square(n=2) + lm.inspect_history(n=2) + + assert m == 4 + + def test_annotated_validator(): def is_square(n: int) -> int: root = n**0.5 @@ -692,3 +714,70 @@ def next_square(n: int) -> Annotated[int, AfterValidator(is_square)]: lm.inspect_history(n=2) assert m == 4 + + +def test_demos(): + demos = [ + dspy.Example(input="What is the speed of light?", output="3e8"), + ] + program = LabeledFewShot(k=len(demos)).compile( + student=dspy.TypedPredictor("input -> output"), + trainset=[ex.with_inputs("input") for ex in demos], + ) + + lm = DummyLM(["Paris"]) + dspy.settings.configure(lm=lm) + + assert program(input="What is the capital of France?").output == "Paris" + + assert lm.get_convo(-1) == textwrap.dedent("""\ + Given the fields `input`, produce the fields `output`. + + --- + + Follow the following format. + + Input: ${input} + Output: ${output} + + --- + + Input: What is the speed of light? + Output: 3e8 + + --- + + Input: What is the capital of France? + Output: Paris""") + + +def _test_demos_missing_input(): + demos = [dspy.Example(input="What is the speed of light?", output="3e8")] + program = LabeledFewShot(k=len(demos)).compile( + student=dspy.TypedPredictor("input -> output, thoughts"), + trainset=[ex.with_inputs("input") for ex in demos], + ) + dspy.settings.configure(lm=DummyLM(["My thoughts", "Paris"])) + assert program(input="What is the capital of France?").output == "Paris" + + assert dspy.settings.lm.get_convo(-1) == textwrap.dedent("""\ + Given the fields `input`, produce the fields `output`. + + --- + + Follow the following format. + + Input: ${input} + Thoughts: ${thoughts} + Output: ${output} + + --- + + Input: What is the speed of light? + Output: 3e8 + + --- + + Input: What is the capital of France? + Thoughts: My thoughts + Output: Paris""")