diff --git a/dspy/predict/retry.py b/dspy/predict/retry.py index dcef1488f..f851addd3 100644 --- a/dspy/predict/retry.py +++ b/dspy/predict/retry.py @@ -32,6 +32,12 @@ def _create_new_signature(self, signature): return signature def forward(self, *, past_outputs, **kwargs): + # Take into account the possible new signature, as in TypedPredictor + new_signature = kwargs.pop("new_signature", None) + if new_signature: + self.original_signature = new_signature + self.new_signature = self._create_new_signature(self.original_signature) + # Convert the dict past_outputs={"answer": ...} to kwargs # {past_answer=..., ...} for key, value in past_outputs.items(): @@ -42,7 +48,7 @@ def forward(self, *, past_outputs, **kwargs): # Note: This only works if the wrapped module is a Predict or ChainOfThought. kwargs["new_signature"] = self.new_signature return self.original_forward(**kwargs) - + def __call__(self, **kwargs): cached_kwargs = copy.deepcopy(kwargs) kwargs["_trace"] = False @@ -59,7 +65,7 @@ def __call__(self, **kwargs): # now pop multiple reserved keys # NOTE(shangyin) past_outputs seems not useful to include in demos, # therefore dropped - for key in ["_trace", "demos", "signature", "config", "lm", "past_outputs"]: + for key in ["_trace", "demos", "signature", "new_signature", "config", "lm", "past_outputs"]: kwargs.pop(key, None) if dsp.settings.trace is not None: diff --git a/tests/predict/test_retry.py b/tests/predict/test_retry.py index a125dde29..1e24d191d 100644 --- a/tests/predict/test_retry.py +++ b/tests/predict/test_retry.py @@ -2,6 +2,7 @@ import dspy from dspy.utils import DummyLM from dspy.primitives.assertions import assert_transform_module, backtrack_handler +import pydantic def test_retry_simple(): @@ -64,3 +65,47 @@ def forward(self, **kwargs): "Instructions: Please think harder\n\n" "Answer: blue" ) + + +def test_retry_forward_with_typed_predictor(): + # First we make a mistake, then we fix it + lm = DummyLM(['{"answer":"red"}', '{"answer":"blue"}']) + dspy.settings.configure(lm=lm, trace=[]) + + class AnswerQuestion(dspy.Signature): + """Answer questions with succint responses.""" + + class Input(pydantic.BaseModel): + question: str + + class Output(pydantic.BaseModel): + answer: str + + input: Input = dspy.InputField() + output: Output = dspy.OutputField() + + class QuestionAnswerer(dspy.Module): + def __init__(self): + super().__init__() + self.answer_question = dspy.TypedPredictor(AnswerQuestion) + + def forward(self, **kwargs): + result = self.answer_question(input=AnswerQuestion.Input(**kwargs)).output + dspy.Suggest(result.answer == "blue", "Please think harder") + return result + + program = QuestionAnswerer() + program = assert_transform_module( + program.map_named_predictors(dspy.Retry), + functools.partial(backtrack_handler, max_backtracks=1), + ) + + result = program(question="What color is the sky?") + + assert result.answer == "blue" + assert lm.get_convo(-1).endswith( + 'Input: {"question":"What color is the sky?"}\n\n' + 'Past Output: {"answer":"red"}\n\n' + 'Instructions: Please think harder\n\n' + 'Output: {"answer":"blue"}' + )