Skip to content

Commit

Permalink
Merge pull request stanfordnlp#693 from Neoxelox/fix-retry-typed-pred…
Browse files Browse the repository at this point in the history
…ictors

Fix Retry for TypedPredictors
  • Loading branch information
thomasahle authored Mar 26, 2024
2 parents c4e72be + 40eb491 commit 49bf927
Show file tree
Hide file tree
Showing 2 changed files with 53 additions and 2 deletions.
10 changes: 8 additions & 2 deletions dspy/predict/retry.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,12 @@ def _create_new_signature(self, signature):
return signature

def forward(self, *, past_outputs, **kwargs):
# Take into account the possible new signature, as in TypedPredictor
new_signature = kwargs.pop("new_signature", None)
if new_signature:
self.original_signature = new_signature
self.new_signature = self._create_new_signature(self.original_signature)

# Convert the dict past_outputs={"answer": ...} to kwargs
# {past_answer=..., ...}
for key, value in past_outputs.items():
Expand All @@ -42,7 +48,7 @@ def forward(self, *, past_outputs, **kwargs):
# Note: This only works if the wrapped module is a Predict or ChainOfThought.
kwargs["new_signature"] = self.new_signature
return self.original_forward(**kwargs)

def __call__(self, **kwargs):
cached_kwargs = copy.deepcopy(kwargs)
kwargs["_trace"] = False
Expand All @@ -59,7 +65,7 @@ def __call__(self, **kwargs):
# now pop multiple reserved keys
# NOTE(shangyin) past_outputs seems not useful to include in demos,
# therefore dropped
for key in ["_trace", "demos", "signature", "config", "lm", "past_outputs"]:
for key in ["_trace", "demos", "signature", "new_signature", "config", "lm", "past_outputs"]:
kwargs.pop(key, None)

if dsp.settings.trace is not None:
Expand Down
45 changes: 45 additions & 0 deletions tests/predict/test_retry.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import dspy
from dspy.utils import DummyLM
from dspy.primitives.assertions import assert_transform_module, backtrack_handler
import pydantic


def test_retry_simple():
Expand Down Expand Up @@ -64,3 +65,47 @@ def forward(self, **kwargs):
"Instructions: Please think harder\n\n"
"Answer: blue"
)


def test_retry_forward_with_typed_predictor():
# First we make a mistake, then we fix it
lm = DummyLM(['{"answer":"red"}', '{"answer":"blue"}'])
dspy.settings.configure(lm=lm, trace=[])

class AnswerQuestion(dspy.Signature):
"""Answer questions with succint responses."""

class Input(pydantic.BaseModel):
question: str

class Output(pydantic.BaseModel):
answer: str

input: Input = dspy.InputField()
output: Output = dspy.OutputField()

class QuestionAnswerer(dspy.Module):
def __init__(self):
super().__init__()
self.answer_question = dspy.TypedPredictor(AnswerQuestion)

def forward(self, **kwargs):
result = self.answer_question(input=AnswerQuestion.Input(**kwargs)).output
dspy.Suggest(result.answer == "blue", "Please think harder")
return result

program = QuestionAnswerer()
program = assert_transform_module(
program.map_named_predictors(dspy.Retry),
functools.partial(backtrack_handler, max_backtracks=1),
)

result = program(question="What color is the sky?")

assert result.answer == "blue"
assert lm.get_convo(-1).endswith(
'Input: {"question":"What color is the sky?"}\n\n'
'Past Output: {"answer":"red"}\n\n'
'Instructions: Please think harder\n\n'
'Output: {"answer":"blue"}'
)

0 comments on commit 49bf927

Please sign in to comment.