Skip to content

Commit

Permalink
Support for n=... and type-strings
Browse files Browse the repository at this point in the history
  • Loading branch information
thomasahle committed Mar 3, 2024
1 parent 6af5366 commit 9a1c189
Show file tree
Hide file tree
Showing 4 changed files with 130 additions and 55 deletions.
61 changes: 34 additions & 27 deletions dspy/functional/functional.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from collections import defaultdict
import inspect
import os
import openai
Expand All @@ -7,6 +8,7 @@
from typing import Annotated, List, Tuple # noqa: UP035
from dsp.templates import passages2text
import json
from dspy.primitives.prediction import Prediction

from dspy.signatures.signature import ensure_signature

Expand Down Expand Up @@ -71,7 +73,7 @@ def TypedChainOfThought(signature) -> dspy.Module: # noqa: N802
class TypedPredictor(dspy.Module):
def __init__(self, signature):
super().__init__()
self.signature = signature
self.signature = ensure_signature(signature)
self.predictor = dspy.Predict(signature)

def copy(self) -> "TypedPredictor":
Expand Down Expand Up @@ -127,8 +129,7 @@ def _prepare_signature(self) -> dspy.Signature:
name,
desc=field.json_schema_extra.get("desc", "")
+ (
". Respond with a single JSON object. JSON Schema: "
+ json.dumps(type_.model_json_schema())
". Respond with a single JSON object. JSON Schema: " + json.dumps(type_.model_json_schema())
),
format=lambda x, to_json=to_json: (x if isinstance(x, str) else to_json(x)),
parser=lambda x, from_json=from_json: from_json(_unwrap_json(x)),
Expand All @@ -152,28 +153,33 @@ def forward(self, **kwargs) -> dspy.Prediction:
for try_i in range(MAX_RETRIES):
result = self.predictor(**modified_kwargs, new_signature=signature)
errors = {}
parsed_results = {}
parsed_results = defaultdict(list)
# Parse the outputs
for name, field in signature.output_fields.items():
try:
value = getattr(result, name)
parser = field.json_schema_extra.get("parser", lambda x: x)
parsed_results[name] = parser(value)
except (pydantic.ValidationError, ValueError) as e:
errors[name] = _format_error(e)
# If we can, we add an example to the error message
current_desc = field.json_schema_extra.get("desc", "")
i = current_desc.find("JSON Schema: ")
if i == -1:
continue # Only add examples to JSON objects
suffix, current_desc = current_desc[i:], current_desc[:i]
prefix = "You MUST use this format: "
if try_i + 1 < MAX_RETRIES \
and prefix not in current_desc \
and (example := self._make_example(field.annotation)):
signature = signature.with_updated_fields(
name, desc=current_desc + "\n" + prefix + example + "\n" + suffix,
)
for i, completion in enumerate(result.completions):
try:
value = completion[name]
parser = field.json_schema_extra.get("parser", lambda x: x)
completion[name] = parser(value)
parsed_results[name].append(parser(value))
except (pydantic.ValidationError, ValueError) as e:
errors[name] = _format_error(e)
# If we can, we add an example to the error message
current_desc = field.json_schema_extra.get("desc", "")
i = current_desc.find("JSON Schema: ")
if i == -1:
continue # Only add examples to JSON objects
suffix, current_desc = current_desc[i:], current_desc[:i]
prefix = "You MUST use this format: "
if (
try_i + 1 < MAX_RETRIES
and prefix not in current_desc
and (example := self._make_example(field.annotation))
):
signature = signature.with_updated_fields(
name,
desc=current_desc + "\n" + prefix + example + "\n" + suffix,
)
if errors:
# Add new fields for each error
for name, error in errors.items():
Expand All @@ -187,11 +193,12 @@ def forward(self, **kwargs) -> dspy.Prediction:
)
else:
# If there are no errors, we return the parsed results
for name, value in parsed_results.items():
setattr(result, name, value)
return result
# for name, value in parsed_results.items():
# setattr(result, name, value)
return Prediction.from_completions(parsed_results)
raise ValueError(
"Too many retries trying to get the correct output format. " + "Try simplifying the requirements.", errors,
"Too many retries trying to get the correct output format. " + "Try simplifying the requirements.",
errors,
)


Expand Down
20 changes: 10 additions & 10 deletions dspy/primitives/prediction.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,29 +4,29 @@
class Prediction(Example):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)

del self._demos
del self._input_keys

self._completions = None

@classmethod
def from_completions(cls, list_or_dict, signature=None):
obj = cls()
obj._completions = Completions(list_or_dict, signature=signature)
obj._store = {k: v[0] for k, v in obj._completions.items()}

return obj

def __repr__(self):
store_repr = ',\n '.join(f"{k}={repr(v)}" for k, v in self._store.items())
store_repr = ",\n ".join(f"{k}={repr(v)}" for k, v in self._store.items())

if self._completions is None or len(self._completions) == 1:
return f"Prediction(\n {store_repr}\n)"

num_completions = len(self._completions)
return f"Prediction(\n {store_repr},\n completions=Completions(...)\n) ({num_completions-1} completions omitted)"

def __str__(self):
return self.__repr__()

Expand Down Expand Up @@ -62,15 +62,15 @@ def __getitem__(self, key):
if isinstance(key, int):
if key < 0 or key >= len(self):
raise IndexError("Index out of range")

return Prediction(**{k: v[key] for k, v in self._completions.items()})

return self._completions[key]

def __getattr__(self, name):
if name in self._completions:
return self._completions[name]

raise AttributeError(f"'{type(self).__name__}' object has no attribute '{name}'")

def __len__(self):
Expand All @@ -82,7 +82,7 @@ def __contains__(self, key):
return key in self._completions

def __repr__(self):
items_repr = ',\n '.join(f"{k}={repr(v)}" for k, v in self._completions.items())
items_repr = ",\n ".join(f"{k}={repr(v)}" for k, v in self._completions.items())
return f"Completions(\n {items_repr}\n)"

def __str__(self):
Expand Down
53 changes: 47 additions & 6 deletions dspy/signatures/signature.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
import ast
from copy import deepcopy
import typing
import dsp
from pydantic import BaseModel, Field, create_model
from pydantic.fields import FieldInfo
from typing import Type, Union, Dict, Tuple # noqa: UP035
from typing import Any, Type, Union, Dict, Tuple # noqa: UP035
import re

from dspy.signatures.field import InputField, OutputField, new_to_old_field
Expand Down Expand Up @@ -254,22 +255,62 @@ def make_signature(


def _parse_signature(signature: str) -> Tuple[Type, Field]:
pattern = r"^\s*[\w\s,]+\s*->\s*[\w\s,]+\s*$"
pattern = r"^\s*[\w\s,:]+\s*->\s*[\w\s,:]+\s*$"
if not re.match(pattern, signature):
raise ValueError(f"Invalid signature format: '{signature}'")

fields = {}
inputs_str, outputs_str = map(str.strip, signature.split("->"))
inputs = [v.strip() for v in inputs_str.split(",") if v.strip()]
outputs = [v.strip() for v in outputs_str.split(",") if v.strip()]
for name in inputs:
fields[name] = (str, InputField())
for name in outputs:
fields[name] = (str, OutputField())
for name_type in inputs:
name, type_ = _parse_named_type_node(name_type)
fields[name] = (type_, InputField())
for name_type in outputs:
name, type_ = _parse_named_type_node(name_type)
fields[name] = (type_, OutputField())

return fields


def _parse_named_type_node(node, names=None) -> Any:
parts = node.split(":")
if len(parts) == 1:
return parts[0], str
name, type_str = parts
type_ = _parse_type_node(ast.parse(type_str), names)
return name, type_


def _parse_type_node(node, names=None) -> Any:
"""Recursively parse an AST node representing a type annotation.
using structural pattern matching introduced in Python 3.10.
"""
if names is None:
names = {}
match node:
case ast.Module(body=body):
if len(body) != 1:
raise ValueError(f"Code is not syntactically valid: {node}")
return _parse_type_node(body[0], names)
case ast.Expr(value=value):
return _parse_type_node(value, names)
case ast.Name(id=id):
if id in names:
return names[id]
for type_ in [int, str, float, bool, list, tuple, dict]:
if type_.__name__ == id:
return type_
case ast.Subscript(value=value, slice=slice):
base_type = _parse_type_node(value, names)
arg_type = _parse_type_node(slice, names)
return base_type[arg_type]
case ast.Tuple(elts=elts):
return tuple(_parse_type_node(elt, names) for elt in elts)
raise ValueError(f"Code is not syntactically valid: {node}")


def infer_prefix(attribute_name: str) -> str:
"""Infer a prefix from an attribute name."""
# Convert camelCase to snake_case, but handle sequences of capital letters properly
Expand Down
51 changes: 39 additions & 12 deletions tests/functional/test_functional.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,9 +35,7 @@ def hard_questions(topics: List[str]) -> List[str]:
pass

expected = ["What is the speed of light?", "What is the speed of sound?"]
lm = DummyLM(
['{"value": ["What is the speed of light?", "What is the speed of sound?"]}']
)
lm = DummyLM(['{"value": ["What is the speed of light?", "What is the speed of sound?"]}'])
dspy.settings.configure(lm=lm)

question = hard_questions(topics=["Physics", "Music"])
Expand Down Expand Up @@ -88,9 +86,7 @@ def test_simple_class():
class Answer(pydantic.BaseModel):
value: float
certainty: float
comments: List[str] = pydantic.Field(
description="At least two comments about the answer"
)
comments: List[str] = pydantic.Field(description="At least two comments about the answer")

class QA(FunctionalModule):
@predictor
Expand Down Expand Up @@ -229,9 +225,7 @@ def simple_metric(example, prediction, trace=None):
lm = DummyLM(["blue", "Ring-ding-ding-ding-dingeringeding!"], follow_examples=True)
dspy.settings.configure(lm=lm, trace=[])

bootstrap = BootstrapFewShot(
metric=simple_metric, max_bootstrapped_demos=1, max_labeled_demos=1
)
bootstrap = BootstrapFewShot(metric=simple_metric, max_bootstrapped_demos=1, max_labeled_demos=1)
compiled_student = bootstrap.compile(student, teacher=teacher, trainset=trainset)

lm.inspect_history(n=2)
Expand Down Expand Up @@ -295,7 +289,7 @@ def flight_information(email: str) -> TravelInformation:
# Example with a bad origin code.
'{"origin": "JF0", "destination": "LAX", "date": "2022-12-25"}',
# Example to help the model understand
'{...}',
"{...}",
# Fixed
'{"origin": "JFK", "destination": "LAX", "date": "2022-12-25"}',
]
Expand Down Expand Up @@ -344,9 +338,9 @@ def flight_information(email: str) -> TravelInformation:
[
# First origin is wrong, then destination, then all is good
'{"origin": "JF0", "destination": "LAX", "date": "2022-12-25"}',
'{...}', # Example to help the model understand
"{...}", # Example to help the model understand
'{"origin": "JFK", "destination": "LA0", "date": "2022-12-25"}',
'{...}', # Example to help the model understand
"{...}", # Example to help the model understand
'{"origin": "JFK", "destination": "LAX", "date": "2022-12-25"}',
]
)
Expand Down Expand Up @@ -447,3 +441,36 @@ def test(input: Annotated[str, Field(description="description")]) -> Annotated[f
output = test(input="input")

assert output == 0.5


def test_multiple_outputs():
lm = DummyLM([str(i) for i in range(100)])
dspy.settings.configure(lm=lm)

test = TypedPredictor("input -> output")
output = test(input="input", config=dict(n=3)).completions.output
assert output == ["0", "1", "2"]


def test_multiple_outputs_int():
lm = DummyLM([str(i) for i in range(100)])
dspy.settings.configure(lm=lm)

class TestSignature(dspy.Signature):
input: int = dspy.InputField()
output: int = dspy.OutputField()

test = TypedPredictor(TestSignature)

output = test(input=8, config=dict(n=3)).completions.output
assert output == [0, 1, 2]


def test_parse_type_string():
lm = DummyLM([str(i) for i in range(100)])
dspy.settings.configure(lm=lm)

test = TypedPredictor("input:int -> output:int")

output = test(input=8, config=dict(n=3)).completions.output
assert output == [0, 1, 2]

0 comments on commit 9a1c189

Please sign in to comment.