Skip to content

Commit

Permalink
Fixed issue with boolean outputs
Browse files Browse the repository at this point in the history
  • Loading branch information
thomasahle committed Mar 14, 2024
1 parent 259d6b0 commit 982db44
Show file tree
Hide file tree
Showing 2 changed files with 105 additions and 1 deletion.
17 changes: 16 additions & 1 deletion dspy/functional/functional.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,7 +119,22 @@ def _prepare_signature(self) -> dspy.Signature:
is_output = field.json_schema_extra["__dspy_field_type"] == "output"
type_ = field.annotation
if is_output:
if type_ in (str, int, float, bool):
if type_ is bool:

def parse(x):
x = x.strip().lower()
if x not in ("true", "false"):
raise ValueError("Respond with true or false")
return x == "true"

signature = signature.with_updated_fields(
name,
desc=field.json_schema_extra.get("desc", "")
+ (" (Respond with true or false)" if type_ != str else ""),
format=lambda x: x if isinstance(x, str) else str(x),
parser=parse,
)
elif type_ in (str, int, float, bool):
signature = signature.with_updated_fields(
name,
desc=field.json_schema_extra.get("desc", "")
Expand Down
89 changes: 89 additions & 0 deletions tests/functional/test_functional.py
Original file line number Diff line number Diff line change
Expand Up @@ -653,6 +653,28 @@ def space_in_a(cls, a: str) -> str:
_ = ValidatedSignature(a="with space")


def test_lm_as_validator():
@predictor
def is_square(n: int) -> bool:
"""Is n a square number?"""

def check_square(n):
assert is_square(n=n)
return n

@predictor
def next_square(n: int) -> Annotated[int, AfterValidator(check_square)]:
"""What is the next square number after n?"""

lm = DummyLM(["3", "False", "4", "True"])
dspy.settings.configure(lm=lm)

m = next_square(n=2)
lm.inspect_history(n=2)

assert m == 4


def test_annotated_validator():
def is_square(n: int) -> int:
root = n**0.5
Expand Down Expand Up @@ -692,3 +714,70 @@ def next_square(n: int) -> Annotated[int, AfterValidator(is_square)]:
lm.inspect_history(n=2)

assert m == 4


def test_demos():
demos = [
dspy.Example(input="What is the speed of light?", output="3e8"),
]
program = LabeledFewShot(k=len(demos)).compile(
student=dspy.TypedPredictor("input -> output"),
trainset=[ex.with_inputs("input") for ex in demos],
)

lm = DummyLM(["Paris"])
dspy.settings.configure(lm=lm)

assert program(input="What is the capital of France?").output == "Paris"

assert lm.get_convo(-1) == textwrap.dedent("""\
Given the fields `input`, produce the fields `output`.
---
Follow the following format.
Input: ${input}
Output: ${output}
---
Input: What is the speed of light?
Output: 3e8
---
Input: What is the capital of France?
Output: Paris""")


def _test_demos_missing_input():
demos = [dspy.Example(input="What is the speed of light?", output="3e8")]
program = LabeledFewShot(k=len(demos)).compile(
student=dspy.TypedPredictor("input -> output, thoughts"),
trainset=[ex.with_inputs("input") for ex in demos],
)
dspy.settings.configure(lm=DummyLM(["My thoughts", "Paris"]))
assert program(input="What is the capital of France?").output == "Paris"

assert dspy.settings.lm.get_convo(-1) == textwrap.dedent("""\
Given the fields `input`, produce the fields `output`.
---
Follow the following format.
Input: ${input}
Thoughts: ${thoughts}
Output: ${output}
---
Input: What is the speed of light?
Output: 3e8
---
Input: What is the capital of France?
Thoughts: My thoughts
Output: Paris""")

0 comments on commit 982db44

Please sign in to comment.