Skip to content

Commit

Permalink
Merge pull request stanfordnlp#540 from thomasahle/main
Browse files Browse the repository at this point in the history
Improvements to Signature
  • Loading branch information
thomasahle authored Mar 3, 2024
2 parents 16fdf50 + 6da72e3 commit 71ac993
Show file tree
Hide file tree
Showing 7 changed files with 384 additions and 220 deletions.
48 changes: 30 additions & 18 deletions dspy/functional/functional.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from collections import defaultdict
import inspect
import os
import openai
Expand All @@ -7,8 +8,9 @@
from typing import Annotated, List, Tuple # noqa: UP035
from dsp.templates import passages2text
import json
from dspy.primitives.prediction import Prediction

from dspy.signatures.signature import ensure_signature
from dspy.signatures.signature import ensure_signature, make_signature


MAX_RETRIES = 3
Expand Down Expand Up @@ -71,7 +73,7 @@ def TypedChainOfThought(signature) -> dspy.Module: # noqa: N802
class TypedPredictor(dspy.Module):
def __init__(self, signature):
super().__init__()
self.signature = signature
self.signature = ensure_signature(signature)
self.predictor = dspy.Predict(signature)

def copy(self) -> "TypedPredictor":
Expand All @@ -81,7 +83,7 @@ def copy(self) -> "TypedPredictor":
def _make_example(type_) -> str:
# Note: DSPy will cache this call so we only pay the first time TypedPredictor is called.
json_object = dspy.Predict(
dspy.Signature(
make_signature(
"json_schema -> json_object",
"Make a very succinct json object that validates with the following schema",
),
Expand Down Expand Up @@ -127,8 +129,7 @@ def _prepare_signature(self) -> dspy.Signature:
name,
desc=field.json_schema_extra.get("desc", "")
+ (
". Respond with a single JSON object. JSON Schema: "
+ json.dumps(type_.model_json_schema())
". Respond with a single JSON object. JSON Schema: " + json.dumps(type_.model_json_schema())
),
format=lambda x, to_json=to_json: (x if isinstance(x, str) else to_json(x)),
parser=lambda x, from_json=from_json: from_json(_unwrap_json(x)),
Expand All @@ -152,13 +153,20 @@ def forward(self, **kwargs) -> dspy.Prediction:
for try_i in range(MAX_RETRIES):
result = self.predictor(**modified_kwargs, new_signature=signature)
errors = {}
parsed_results = {}
parsed_results = []
# Parse the outputs
for name, field in signature.output_fields.items():
for i, completion in enumerate(result.completions):
try:
value = getattr(result, name)
parser = field.json_schema_extra.get("parser", lambda x: x)
parsed_results[name] = parser(value)
parsed = {}
for name, field in signature.output_fields.items():
value = completion[name]
parser = field.json_schema_extra.get("parser", lambda x: x)
completion[name] = parser(value)
parsed[name] = parser(value)
# Instantiate the actual signature with the parsed values.
# This allow pydantic to validate the fields defined in the signature.
_dummy = self.signature(**kwargs, **parsed)
parsed_results.append(parsed)
except (pydantic.ValidationError, ValueError) as e:
errors[name] = _format_error(e)
# If we can, we add an example to the error message
Expand All @@ -168,11 +176,14 @@ def forward(self, **kwargs) -> dspy.Prediction:
continue # Only add examples to JSON objects
suffix, current_desc = current_desc[i:], current_desc[:i]
prefix = "You MUST use this format: "
if try_i + 1 < MAX_RETRIES \
and prefix not in current_desc \
and (example := self._make_example(field.annotation)):
if (
try_i + 1 < MAX_RETRIES
and prefix not in current_desc
and (example := self._make_example(field.annotation))
):
signature = signature.with_updated_fields(
name, desc=current_desc + "\n" + prefix + example + "\n" + suffix,
name,
desc=current_desc + "\n" + prefix + example + "\n" + suffix,
)
if errors:
# Add new fields for each error
Expand All @@ -187,11 +198,12 @@ def forward(self, **kwargs) -> dspy.Prediction:
)
else:
# If there are no errors, we return the parsed results
for name, value in parsed_results.items():
setattr(result, name, value)
return result
return Prediction.from_completions(
{key: [r[key] for r in parsed_results] for key in signature.output_fields}
)
raise ValueError(
"Too many retries trying to get the correct output format. " + "Try simplifying the requirements.", errors,
"Too many retries trying to get the correct output format. " + "Try simplifying the requirements.",
errors,
)


Expand Down
20 changes: 10 additions & 10 deletions dspy/primitives/prediction.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,29 +4,29 @@
class Prediction(Example):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)

del self._demos
del self._input_keys

self._completions = None

@classmethod
def from_completions(cls, list_or_dict, signature=None):
obj = cls()
obj._completions = Completions(list_or_dict, signature=signature)
obj._store = {k: v[0] for k, v in obj._completions.items()}

return obj

def __repr__(self):
store_repr = ',\n '.join(f"{k}={repr(v)}" for k, v in self._store.items())
store_repr = ",\n ".join(f"{k}={repr(v)}" for k, v in self._store.items())

if self._completions is None or len(self._completions) == 1:
return f"Prediction(\n {store_repr}\n)"

num_completions = len(self._completions)
return f"Prediction(\n {store_repr},\n completions=Completions(...)\n) ({num_completions-1} completions omitted)"

def __str__(self):
return self.__repr__()

Expand Down Expand Up @@ -62,15 +62,15 @@ def __getitem__(self, key):
if isinstance(key, int):
if key < 0 or key >= len(self):
raise IndexError("Index out of range")

return Prediction(**{k: v[key] for k, v in self._completions.items()})

return self._completions[key]

def __getattr__(self, name):
if name in self._completions:
return self._completions[name]

raise AttributeError(f"'{type(self).__name__}' object has no attribute '{name}'")

def __len__(self):
Expand All @@ -82,7 +82,7 @@ def __contains__(self, key):
return key in self._completions

def __repr__(self):
items_repr = ',\n '.join(f"{k}={repr(v)}" for k, v in self._completions.items())
items_repr = ",\n ".join(f"{k}={repr(v)}" for k, v in self._completions.items())
return f"Completions(\n {items_repr}\n)"

def __str__(self):
Expand Down
Loading

0 comments on commit 71ac993

Please sign in to comment.