Skip to content

Commit

Permalink
Merge pull request stanfordnlp#995 from stanfordnlp/fix-assertions
Browse files Browse the repository at this point in the history
fix: assertion and chain-of-thought
  • Loading branch information
Shangyint authored May 9, 2024
2 parents b05394b + be493e9 commit 339cb60
Show file tree
Hide file tree
Showing 3 changed files with 10 additions and 10 deletions.
7 changes: 4 additions & 3 deletions dspy/predict/retry.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,16 +10,17 @@ class Retry(Predict):
def __init__(self, module):
super().__init__(module.signature)
self.module = module
self.original_signature = module.signature
self.original_signature = module.extended_signature if isinstance(module, dspy.ChainOfThought) else module.signature
self.original_forward = module.forward
self.new_signature = self._create_new_signature(self.original_signature)

def _create_new_signature(self, signature):
# Add "Past" input fields for each output field
for key, value in signature.output_fields.items():
actual_prefix = value.json_schema_extra["prefix"].split(":")[0] + ":"
signature = signature.append(f"past_{key}", dspy.InputField(
prefix="Past " + value.json_schema_extra["prefix"],
desc="past output with errors",
prefix="Previous " + actual_prefix,
desc=f"past {actual_prefix} with errors",
format=value.json_schema_extra.get("format"),
))

Expand Down
7 changes: 3 additions & 4 deletions dspy/primitives/assertions.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,10 +105,10 @@ def __call__(self) -> Any:
if self.result:
return True
elif dspy.settings.bypass_suggest:
dspy.logger.error(f"SuggestionFailed: {self.msg}")
dspy.logger.info(f"SuggestionFailed: {self.msg}")
return True
else:
dspy.logger.error(f"SuggestionFailed: {self.msg}")
dspy.logger.info(f"SuggestionFailed: {self.msg}")
raise DSPySuggestionError(
id=self.id,
msg=self.msg,
Expand Down Expand Up @@ -257,8 +257,7 @@ def wrapper(*args, **kwargs):
):
dspy.settings.predictor_feedbacks[dspy.settings.backtrack_to].append(error_msg)

# assert isinstance(error_state[0].signature, dspy.Signature)
output_fields = error_state[0].signature.output_fields
output_fields = error_state[0].new_signature.output_fields
past_outputs = {}
for field_name in output_fields.keys():
past_outputs[field_name] = getattr(
Expand Down
6 changes: 3 additions & 3 deletions tests/predict/test_retry.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ def test_retry_simple():
print(lm.get_convo(-1))
assert lm.get_convo(-1).endswith(
"Question: What color is the sky?\n\n"
"Past Answer: red\n\n"
"Previous Answer: red\n\n"
"Instructions: Try harder\n\n"
"Answer: blue"
)
Expand Down Expand Up @@ -61,7 +61,7 @@ def forward(self, **kwargs):
print(lm.get_convo(-1))
assert lm.get_convo(-1).endswith(
"Question: What color is the sky?\n\n"
"Past Answer: red\n\n"
"Previous Answer: red\n\n"
"Instructions: Please think harder\n\n"
"Answer: blue"
)
Expand Down Expand Up @@ -105,7 +105,7 @@ def forward(self, **kwargs):
assert result.answer == "blue"
assert lm.get_convo(-1).endswith(
'Input: {"question":"What color is the sky?"}\n\n'
'Past Output: {"answer":"red"}\n\n'
'Previous Output: {"answer":"red"}\n\n'
'Instructions: Please think harder\n\n'
'Output: {"answer":"blue"}'
)

0 comments on commit 339cb60

Please sign in to comment.