From bb074a71a807f92f2910055b180fe8cb66f10b1a Mon Sep 17 00:00:00 2001 From: Shangyint Date: Tue, 7 May 2024 23:31:44 -0700 Subject: [PATCH 1/2] fix: assertion and chain-of-thought bugs --- dspy/predict/retry.py | 7 ++++--- dspy/primitives/assertions.py | 7 +++---- 2 files changed, 7 insertions(+), 7 deletions(-) diff --git a/dspy/predict/retry.py b/dspy/predict/retry.py index f851addd3..38651186b 100644 --- a/dspy/predict/retry.py +++ b/dspy/predict/retry.py @@ -10,16 +10,17 @@ class Retry(Predict): def __init__(self, module): super().__init__(module.signature) self.module = module - self.original_signature = module.signature + self.original_signature = module.extended_signature if isinstance(module, dspy.ChainOfThought) else module.signature self.original_forward = module.forward self.new_signature = self._create_new_signature(self.original_signature) def _create_new_signature(self, signature): # Add "Past" input fields for each output field for key, value in signature.output_fields.items(): + actual_prefix = value.json_schema_extra["prefix"].split(":")[0] + ":" signature = signature.append(f"past_{key}", dspy.InputField( - prefix="Past " + value.json_schema_extra["prefix"], - desc="past output with errors", + prefix="Previous " + actual_prefix, + desc=f"past {actual_prefix} with errors", format=value.json_schema_extra.get("format"), )) diff --git a/dspy/primitives/assertions.py b/dspy/primitives/assertions.py index 5d319437e..771525c3b 100644 --- a/dspy/primitives/assertions.py +++ b/dspy/primitives/assertions.py @@ -105,10 +105,10 @@ def __call__(self) -> Any: if self.result: return True elif dspy.settings.bypass_suggest: - dspy.logger.error(f"SuggestionFailed: {self.msg}") + dspy.logger.info(f"SuggestionFailed: {self.msg}") return True else: - dspy.logger.error(f"SuggestionFailed: {self.msg}") + dspy.logger.info(f"SuggestionFailed: {self.msg}") raise DSPySuggestionError( id=self.id, msg=self.msg, @@ -257,8 +257,7 @@ def wrapper(*args, **kwargs): ): dspy.settings.predictor_feedbacks[dspy.settings.backtrack_to].append(error_msg) - # assert isinstance(error_state[0].signature, dspy.Signature) - output_fields = error_state[0].signature.output_fields + output_fields = error_state[0].new_signature.output_fields past_outputs = {} for field_name in output_fields.keys(): past_outputs[field_name] = getattr( From be493e99a290bf0e1f15699404ba3859d331f503 Mon Sep 17 00:00:00 2001 From: Shangyint Date: Tue, 7 May 2024 23:41:11 -0700 Subject: [PATCH 2/2] update test_retry --- tests/predict/test_retry.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/tests/predict/test_retry.py b/tests/predict/test_retry.py index 1e24d191d..5ceee8a62 100644 --- a/tests/predict/test_retry.py +++ b/tests/predict/test_retry.py @@ -26,7 +26,7 @@ def test_retry_simple(): print(lm.get_convo(-1)) assert lm.get_convo(-1).endswith( "Question: What color is the sky?\n\n" - "Past Answer: red\n\n" + "Previous Answer: red\n\n" "Instructions: Try harder\n\n" "Answer: blue" ) @@ -61,7 +61,7 @@ def forward(self, **kwargs): print(lm.get_convo(-1)) assert lm.get_convo(-1).endswith( "Question: What color is the sky?\n\n" - "Past Answer: red\n\n" + "Previous Answer: red\n\n" "Instructions: Please think harder\n\n" "Answer: blue" ) @@ -105,7 +105,7 @@ def forward(self, **kwargs): assert result.answer == "blue" assert lm.get_convo(-1).endswith( 'Input: {"question":"What color is the sky?"}\n\n' - 'Past Output: {"answer":"red"}\n\n' + 'Previous Output: {"answer":"red"}\n\n' 'Instructions: Please think harder\n\n' 'Output: {"answer":"blue"}' )