Skip to content

Commit

Permalink
Merge pull request stanfordnlp#393 from stanfordnlp/fix-dspy.assert
Browse files Browse the repository at this point in the history
fix dspy.assert backtracking
  • Loading branch information
Shangyint authored Feb 14, 2024
2 parents 8a457cc + 6e7be39 commit 0306ef7
Showing 1 changed file with 19 additions and 7 deletions.
26 changes: 19 additions & 7 deletions dspy/primitives/assertions.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,11 +40,17 @@ class DSPyAssertionError(AssertionError):
"""Custom exception raised when a DSPy `Assert` fails."""

def __init__(
self, id: str, msg: str, state: Any = None, is_metric: bool = False
self,
id: str,
msg: str,
target_module: Any = None,
state: Any = None,
is_metric: bool = False,
) -> None:
super().__init__(msg)
self.id = id
self.msg = msg
self.target_module = target_module
self.state = state
self.is_metric = is_metric

Expand Down Expand Up @@ -72,6 +78,7 @@ def __init__(


class Constraint:

def __init__(
self, result: bool, msg: str = "", target_module=None, is_metric: bool = False
):
Expand Down Expand Up @@ -99,6 +106,7 @@ def __call__(self) -> bool:
raise DSPyAssertionError(
id=self.id,
msg=self.msg,
target_module=self.target_module,
state=dsp.settings.trace,
is_metric=self.is_metric,
)
Expand Down Expand Up @@ -236,7 +244,7 @@ def wrapper(*args, **kwargs):
except (DSPySuggestionError, DSPyAssertionError) as e:
if not current_error:
current_error = e
suggest_id, error_msg, suggest_target_module, error_state = (
error_id, error_msg, error_target_module, error_state = (
e.id,
e.msg,
e.target_module,
Expand All @@ -250,11 +258,11 @@ def wrapper(*args, **kwargs):
dspy.settings.assert_failures += 1

if dsp.settings.trace:
if suggest_target_module:
if error_target_module:
for i in range(len(dsp.settings.trace) - 1, -1, -1):
trace_element = dsp.settings.trace[i]
mod = trace_element[0]
if mod.signature == suggest_target_module:
if mod.signature == error_target_module:
dspy.settings.backtrack_to = mod
break
else:
Expand Down Expand Up @@ -336,9 +344,13 @@ def assert_transform_module(
module
)

if all(map(lambda p: isinstance(p[1], dspy.retry.Retry), module.named_predictors())):
pass # we already applied the Retry mapping outside
elif all(map(lambda p: not isinstance(p[1], dspy.retry.Retry), module.named_predictors())):
if all(
map(lambda p: isinstance(p[1], dspy.retry.Retry), module.named_predictors())
):
pass # we already applied the Retry mapping outside
elif all(
map(lambda p: not isinstance(p[1], dspy.retry.Retry), module.named_predictors())
):
module.map_named_predictors(dspy.retry.Retry)
else:
raise RuntimeError("Module has mixed predictors, can't apply Retry mapping.")
Expand Down

0 comments on commit 0306ef7

Please sign in to comment.