Skip to content

Commit

Permalink
Validate main signature fields
Browse files Browse the repository at this point in the history
  • Loading branch information
thomasahle committed Mar 3, 2024
1 parent dab44cd commit fe321da
Show file tree
Hide file tree
Showing 2 changed files with 44 additions and 21 deletions.
48 changes: 27 additions & 21 deletions dspy/functional/functional.py
Original file line number Diff line number Diff line change
Expand Up @@ -155,31 +155,37 @@ def forward(self, **kwargs) -> dspy.Prediction:
errors = {}
parsed_results = defaultdict(list)
# Parse the outputs
for name, field in signature.output_fields.items():
for i, completion in enumerate(result.completions):
try:
for i, completion in enumerate(result.completions):
try:
for name, field in signature.output_fields.items():
value = completion[name]
parser = field.json_schema_extra.get("parser", lambda x: x)
completion[name] = parser(value)
parsed_results[name].append(parser(value))
except (pydantic.ValidationError, ValueError) as e:
errors[name] = _format_error(e)
# If we can, we add an example to the error message
current_desc = field.json_schema_extra.get("desc", "")
i = current_desc.find("JSON Schema: ")
if i == -1:
continue # Only add examples to JSON objects
suffix, current_desc = current_desc[i:], current_desc[:i]
prefix = "You MUST use this format: "
if (
try_i + 1 < MAX_RETRIES
and prefix not in current_desc
and (example := self._make_example(field.annotation))
):
signature = signature.with_updated_fields(
name,
desc=current_desc + "\n" + prefix + example + "\n" + suffix,
)
# Instantiate the actual signature with the parsed values.
# This allow pydantic to validate the fields defined in the signature.
_dummy = self.signature(
**kwargs,
**{key: value[i] for key, value in parsed_results.items()},
)
except (pydantic.ValidationError, ValueError) as e:
errors[name] = _format_error(e)
# If we can, we add an example to the error message
current_desc = field.json_schema_extra.get("desc", "")
i = current_desc.find("JSON Schema: ")
if i == -1:
continue # Only add examples to JSON objects
suffix, current_desc = current_desc[i:], current_desc[:i]
prefix = "You MUST use this format: "
if (
try_i + 1 < MAX_RETRIES
and prefix not in current_desc
and (example := self._make_example(field.annotation))
):
signature = signature.with_updated_fields(
name,
desc=current_desc + "\n" + prefix + example + "\n" + suffix,
)
if errors:
# Add new fields for each error
for name, error in errors.items():
Expand Down
17 changes: 17 additions & 0 deletions tests/functional/test_functional.py
Original file line number Diff line number Diff line change
Expand Up @@ -474,3 +474,20 @@ def test_parse_type_string():

output = test(input=8, config=dict(n=3)).completions.output
assert output == [0, 1, 2]


def test_fields_on_base_signature():
class SimpleOutput(dspy.Signature):
output: float = dspy.OutputField(gt=0, lt=1)

lm = DummyLM(
[
"2.1", # Bad output
"0.5", # Good output
]
)
dspy.settings.configure(lm=lm)

predictor = TypedPredictor(SimpleOutput)

assert predictor().output == 0.5

0 comments on commit fe321da

Please sign in to comment.