diff --git a/dspy/primitives/assertions.py b/dspy/primitives/assertions.py index 2dcae4ff7..5f89896e2 100644 --- a/dspy/primitives/assertions.py +++ b/dspy/primitives/assertions.py @@ -40,11 +40,17 @@ class DSPyAssertionError(AssertionError): """Custom exception raised when a DSPy `Assert` fails.""" def __init__( - self, id: str, msg: str, state: Any = None, is_metric: bool = False + self, + id: str, + msg: str, + target_module: Any = None, + state: Any = None, + is_metric: bool = False, ) -> None: super().__init__(msg) self.id = id self.msg = msg + self.target_module = target_module self.state = state self.is_metric = is_metric @@ -72,6 +78,7 @@ def __init__( class Constraint: + def __init__( self, result: bool, msg: str = "", target_module=None, is_metric: bool = False ): @@ -99,6 +106,7 @@ def __call__(self) -> bool: raise DSPyAssertionError( id=self.id, msg=self.msg, + target_module=self.target_module, state=dsp.settings.trace, is_metric=self.is_metric, ) @@ -236,7 +244,7 @@ def wrapper(*args, **kwargs): except (DSPySuggestionError, DSPyAssertionError) as e: if not current_error: current_error = e - suggest_id, error_msg, suggest_target_module, error_state = ( + error_id, error_msg, error_target_module, error_state = ( e.id, e.msg, e.target_module, @@ -250,11 +258,11 @@ def wrapper(*args, **kwargs): dspy.settings.assert_failures += 1 if dsp.settings.trace: - if suggest_target_module: + if error_target_module: for i in range(len(dsp.settings.trace) - 1, -1, -1): trace_element = dsp.settings.trace[i] mod = trace_element[0] - if mod.signature == suggest_target_module: + if mod.signature == error_target_module: dspy.settings.backtrack_to = mod break else: @@ -336,9 +344,13 @@ def assert_transform_module( module ) - if all(map(lambda p: isinstance(p[1], dspy.retry.Retry), module.named_predictors())): - pass # we already applied the Retry mapping outside - elif all(map(lambda p: not isinstance(p[1], dspy.retry.Retry), module.named_predictors())): + if all( + map(lambda p: isinstance(p[1], dspy.retry.Retry), module.named_predictors()) + ): + pass # we already applied the Retry mapping outside + elif all( + map(lambda p: not isinstance(p[1], dspy.retry.Retry), module.named_predictors()) + ): module.map_named_predictors(dspy.retry.Retry) else: raise RuntimeError("Module has mixed predictors, can't apply Retry mapping.")