Skip to content

Commit

Permalink
Merge pull request stanfordnlp#673 from thomasahle/main
Browse files Browse the repository at this point in the history
TypedPredictor improvements
  • Loading branch information
thomasahle authored Mar 18, 2024
2 parents eb2dd73 + 0adcc8c commit 3c2edf9
Show file tree
Hide file tree
Showing 3 changed files with 126 additions and 37 deletions.
100 changes: 82 additions & 18 deletions dspy/functional/functional.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
import inspect
import json
import textwrap
import typing
from typing import Annotated, List, Tuple # noqa: UP035
from typing import Annotated, List, Tuple, Union # noqa: UP035

import pydantic
import ujson
from pydantic.fields import FieldInfo

import dspy
from dsp.templates import passages2text
Expand Down Expand Up @@ -51,9 +53,9 @@ def __init__(self):
self.__dict__[name] = attr.copy()


def TypedChainOfThought(signature, max_retries=3) -> dspy.Module: # noqa: N802
def TypedChainOfThought(signature, instructions=None, *, max_retries=3) -> dspy.Module: # noqa: N802
"""Just like TypedPredictor, but adds a ChainOfThought OutputField."""
signature = ensure_signature(signature)
signature = ensure_signature(signature, instructions)
output_keys = ", ".join(signature.output_fields.keys())
return TypedPredictor(
signature.prepend(
Expand All @@ -68,7 +70,7 @@ def TypedChainOfThought(signature, max_retries=3) -> dspy.Module: # noqa: N802


class TypedPredictor(dspy.Module):
def __init__(self, signature, max_retries=3, wrap_json=False):
def __init__(self, signature, instructions=None, *, max_retries=3, wrap_json=False, explain_errors=False):
"""Like dspy.Predict, but enforces type annotations in the signature.
Args:
Expand All @@ -77,13 +79,14 @@ def __init__(self, signature, max_retries=3, wrap_json=False):
wrap_json: If True, json objects in the input will be wrapped in ```json ... ```
"""
super().__init__()
self.signature = ensure_signature(signature)
self.signature = ensure_signature(signature, instructions)
self.predictor = dspy.Predict(signature)
self.max_retries = max_retries
self.wrap_json = wrap_json
self.explain_errors = explain_errors

def copy(self) -> "TypedPredictor":
return TypedPredictor(self.signature, self.max_retries, self.wrap_json)
return TypedPredictor(self.signature, max_retries=self.max_retries, wrap_json=self.wrap_json)

def __repr__(self):
"""Return a string representation of the TypedPredictor object."""
Expand Down Expand Up @@ -112,6 +115,63 @@ def _make_example(self, type_) -> str:
# TODO: Instead of using a language model to create the example, we can also just use a
# library like https://pypi.org/project/polyfactory/ that's made exactly to do this.

def _format_error(
self,
error: Exception,
task_description: Union[str, FieldInfo],
model_output: str,
lm_explain: bool,
) -> str:
if isinstance(error, pydantic.ValidationError):
errors = []
for e in error.errors():
fields = ", ".join(map(str, e["loc"]))
errors.append(f"{e['msg']}: {fields} (error type: {e['type']})")
error_text = "; ".join(errors)
else:
error_text = repr(error)

if self.explain_errors and lm_explain:
if isinstance(task_description, FieldInfo):
args = task_description.json_schema_extra
task_description = args["prefix"] + " " + args["desc"]
return (
error_text
+ "\n"
+ self._make_explanation(
task_description=task_description,
model_output=model_output,
error=error_text,
)
)

return error_text

def _make_explanation(self, task_description: str, model_output: str, error: str) -> str:
class Signature(dspy.Signature):
__doc__ = textwrap.dedent(
"""
I gave my language model a task, but it failed. Figure out what went wrong,
and write instructions to help it avoid the error next time.""",
)

task_description: str = dspy.InputField(desc="What I asked the model to do")
language_model_output: str = dspy.InputField(desc="The output of the model")
error: str = dspy.InputField(desc="The validation error trigged by the models output")
explanation: str = dspy.OutputField(desc="Explain what the model did wrong")
advice: str = dspy.OutputField(
desc="Instructions for the model to do better next time. A single paragraph.",
)

# TODO: We could also try repair the output here. For example, if the output is a float, but the
# model returned a "float + explanation", the repair could be to remove the explanation.

return dspy.Predict(Signature)(
task_description=task_description,
language_model_output=model_output,
error=error,
).advice

def _prepare_signature(self) -> dspy.Signature:
"""Add formats and parsers to the signature fields, based on the type annotations of the fields."""
signature = self.signature
Expand Down Expand Up @@ -221,7 +281,13 @@ def forward(self, **kwargs) -> dspy.Prediction:
parser = field.json_schema_extra.get("parser", lambda x: x)
parsed[name] = parser(value)
except (pydantic.ValidationError, ValueError) as e:
errors[name] = _format_error(e)
errors[name] = self._format_error(
e,
signature.fields[name],
value,
lm_explain=try_i + 1 < self.max_retries,
)

# If we can, we add an example to the error message
current_desc = field.json_schema_extra.get("desc", "")
i = current_desc.find("JSON Schema: ")
Expand All @@ -247,7 +313,15 @@ def forward(self, **kwargs) -> dspy.Prediction:
_ = self.signature(**kwargs, **parsed)
parsed_results.append(parsed)
except pydantic.ValidationError as e:
errors["general"] = _format_error(e)
errors["general"] = self._format_error(
e,
signature.instructions,
"\n\n".join(
"> " + field.json_schema_extra["prefix"] + " " + completion[name]
for name, field in signature.output_fields.items()
),
lm_explain=try_i + 1 < self.max_retries,
)
if errors:
# Add new fields for each error
for name, error in errors.items():
Expand Down Expand Up @@ -275,16 +349,6 @@ def forward(self, **kwargs) -> dspy.Prediction:
)


def _format_error(error: Exception):
if isinstance(error, pydantic.ValidationError):
errors = []
for e in error.errors():
fields = ", ".join(map(str, e["loc"]))
errors.append(f"{e['msg']}: {fields} (error type: {e['type']})")
return "; ".join(errors)
return repr(error)


def _func_to_signature(func):
"""Make a dspy.Signature based on a function definition."""
sig = inspect.signature(func)
Expand Down
42 changes: 23 additions & 19 deletions dspy/signatures/signature.py
Original file line number Diff line number Diff line change
Expand Up @@ -206,11 +206,13 @@ class Signature(BaseModel, metaclass=SignatureMeta):
pass


def ensure_signature(signature: Union[str, Type[Signature]]) -> Signature:
def ensure_signature(signature: Union[str, Type[Signature]], instructions=None) -> Signature:
if signature is None:
return None
if isinstance(signature, str):
return Signature(signature)
return Signature(signature, instructions)
if instructions is not None:
raise ValueError("Don't specify instructions when initializing with a Signature")
return signature


Expand Down Expand Up @@ -277,27 +279,22 @@ def _parse_signature(signature: str) -> Tuple[Type, Field]:
if signature.count("->") != 1:
raise ValueError(f"Invalid signature format: '{signature}', must contain exactly one '->'.")

inputs_str, outputs_str = signature.split("->")

fields = {}
inputs_str, outputs_str = map(str.strip, signature.split("->"))
inputs = [v.strip() for v in inputs_str.split(",") if v.strip()]
outputs = [v.strip() for v in outputs_str.split(",") if v.strip()]
for name_type in inputs:
name, type_ = _parse_named_type_node(name_type)
for name, type_ in _parse_arg_string(inputs_str):
fields[name] = (type_, InputField())
for name_type in outputs:
name, type_ = _parse_named_type_node(name_type)
for name, type_ in _parse_arg_string(outputs_str):
fields[name] = (type_, OutputField())

return fields


def _parse_named_type_node(node, names=None) -> Any:
parts = node.split(":")
if len(parts) == 1:
return parts[0], str
name, type_str = parts
type_ = _parse_type_node(ast.parse(type_str), names)
return name, type_
def _parse_arg_string(string: str, names=None) -> Dict[str, str]:
args = ast.parse("def f(" + string + "): pass").body[0].args.args
names = [arg.arg for arg in args]
types = [str if arg.annotation is None else _parse_type_node(arg.annotation) for arg in args]
return zip(names, types)


def _parse_type_node(node, names=None) -> Any:
Expand All @@ -306,7 +303,7 @@ def _parse_type_node(node, names=None) -> Any:
without using structural pattern matching introduced in Python 3.10.
"""
if names is None:
names = {}
names = typing.__dict__

if isinstance(node, ast.Module):
body = node.body
Expand All @@ -325,16 +322,23 @@ def _parse_type_node(node, names=None) -> Any:
for type_ in [int, str, float, bool, list, tuple, dict]:
if type_.__name__ == id_:
return type_
raise ValueError(f"Unknown name: {id_}")

elif isinstance(node, ast.Subscript):
if isinstance(node, ast.Subscript):
base_type = _parse_type_node(node.value, names)
arg_type = _parse_type_node(node.slice, names)
return base_type[arg_type]

elif isinstance(node, ast.Tuple):
if isinstance(node, ast.Tuple):
elts = node.elts
return tuple(_parse_type_node(elt, names) for elt in elts)

if isinstance(node, ast.Call):
if node.func.id == "Field":
keys = [kw.arg for kw in node.keywords]
values = [kw.value.value for kw in node.keywords]
return Field(**dict(zip(keys, values)))

raise ValueError(f"Code is not syntactically valid: {node}")


Expand Down
21 changes: 21 additions & 0 deletions tests/functional/test_functional.py
Original file line number Diff line number Diff line change
Expand Up @@ -781,3 +781,24 @@ def _test_demos_missing_input():
Input: What is the capital of France?
Thoughts: My thoughts
Output: Paris""")


def test_conlist():
dspy.settings.configure(
lm=DummyLM(['{"value": []}', '{"value": [1]}', '{"value": [1, 2]}', '{"value": [1, 2, 3]}'])
)

@predictor
def make_numbers(input: str) -> Annotated[list[int], Field(min_items=2)]:
pass

assert make_numbers(input="What are the first two numbers?") == [1, 2]


def test_conlist2():
dspy.settings.configure(
lm=DummyLM(['{"value": []}', '{"value": [1]}', '{"value": [1, 2]}', '{"value": [1, 2, 3]}'])
)

make_numbers = TypedPredictor("input:str -> output:Annotated[List[int], Field(min_items=2)]")
assert make_numbers(input="What are the first two numbers?").output == [1, 2]

0 comments on commit 3c2edf9

Please sign in to comment.