Skip to content

Commit

Permalink
Support for explaining errors
Browse files Browse the repository at this point in the history
  • Loading branch information
thomasahle committed Mar 17, 2024
1 parent aee6baf commit 3d74de7
Showing 1 changed file with 64 additions and 14 deletions.
78 changes: 64 additions & 14 deletions dspy/functional/functional.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
import inspect
import json
import textwrap
import typing
from typing import Annotated, List, Tuple # noqa: UP035
from typing import Annotated, List, Tuple, Union # noqa: UP035

import pydantic
import ujson
from pydantic.fields import FieldInfo

import dspy
from dsp.templates import passages2text
Expand Down Expand Up @@ -68,7 +70,7 @@ def TypedChainOfThought(signature, max_retries=3) -> dspy.Module: # noqa: N802


class TypedPredictor(dspy.Module):
def __init__(self, signature, max_retries=3, wrap_json=False):
def __init__(self, signature, max_retries=3, wrap_json=False, explain_errors=False):
"""Like dspy.Predict, but enforces type annotations in the signature.
Args:
Expand All @@ -81,6 +83,7 @@ def __init__(self, signature, max_retries=3, wrap_json=False):
self.predictor = dspy.Predict(signature)
self.max_retries = max_retries
self.wrap_json = wrap_json
self.explain_errors = explain_errors

def copy(self) -> "TypedPredictor":
return TypedPredictor(self.signature, self.max_retries, self.wrap_json)
Expand Down Expand Up @@ -112,6 +115,55 @@ def _make_example(self, type_) -> str:
# TODO: Instead of using a language model to create the example, we can also just use a
# library like https://pypi.org/project/polyfactory/ that's made exactly to do this.

def _format_error(self, error: Exception, task_description: Union[str, FieldInfo], model_output: str) -> str:
if isinstance(error, pydantic.ValidationError):
errors = []
for e in error.errors():
fields = ", ".join(map(str, e["loc"]))
errors.append(f"{e['msg']}: {fields} (error type: {e['type']})")
error_text = "; ".join(errors)
else:
error_text = repr(error)

if self.explain_errors:
if isinstance(task_description, FieldInfo):
args = task_description.json_schema_extra
task_description = args["prefix"] + " " + args["desc"]
return (
error_text
+ "\n"
+ self._make_explanation(
task_description=task_description,
model_output=model_output,
error=error_text,
)
)

return error_text

def _make_explanation(self, task_description: str, model_output: str, error: str) -> str:
class Signature(dspy.Signature):
__doc__ = textwrap.dedent(
"""
I gave my language model a task, but it failed. Figure out what went wrong,
and write instructions to help it avoid the error next time.""",
)

task_description: str = dspy.InputField(desc="What I asked the model to do")
language_model_output: str = dspy.InputField(desc="The output of the model")
error: str = dspy.InputField(desc="The validation error trigged by the models output")
explanation: str = dspy.OutputField(desc="Explain what the model did wrong")
advice: str = dspy.OutputField(desc="Instructions for the model to do better next time")

# TODO: We could also try repair the output here. For example, if the output is a float, but the
# model returned a "float + explanation", the repair could be to remove the explanation.

return dspy.Predict(Signature)(
task_description=task_description,
language_model_output=model_output,
error=error,
).advice

def _prepare_signature(self) -> dspy.Signature:
"""Add formats and parsers to the signature fields, based on the type annotations of the fields."""
signature = self.signature
Expand Down Expand Up @@ -221,7 +273,8 @@ def forward(self, **kwargs) -> dspy.Prediction:
parser = field.json_schema_extra.get("parser", lambda x: x)
parsed[name] = parser(value)
except (pydantic.ValidationError, ValueError) as e:
errors[name] = _format_error(e)
errors[name] = self._format_error(e, signature.fields[name], value)

# If we can, we add an example to the error message
current_desc = field.json_schema_extra.get("desc", "")
i = current_desc.find("JSON Schema: ")
Expand All @@ -247,7 +300,14 @@ def forward(self, **kwargs) -> dspy.Prediction:
_ = self.signature(**kwargs, **parsed)
parsed_results.append(parsed)
except pydantic.ValidationError as e:
errors["general"] = _format_error(e)
errors["general"] = self._format_error(
e,
signature.instructions,
"\n\n".join(
"> " + field.json_schema_extra["prefix"] + " " + completion[name]
for name, field in signature.output_fields.items()
),
)
if errors:
# Add new fields for each error
for name, error in errors.items():
Expand Down Expand Up @@ -275,16 +335,6 @@ def forward(self, **kwargs) -> dspy.Prediction:
)


def _format_error(error: Exception):
if isinstance(error, pydantic.ValidationError):
errors = []
for e in error.errors():
fields = ", ".join(map(str, e["loc"]))
errors.append(f"{e['msg']}: {fields} (error type: {e['type']})")
return "; ".join(errors)
return repr(error)


def _func_to_signature(func):
"""Make a dspy.Signature based on a function definition."""
sig = inspect.signature(func)
Expand Down

0 comments on commit 3d74de7

Please sign in to comment.