Skip to content

Commit

Permalink
Don't call "make explain" in the last retry
Browse files Browse the repository at this point in the history
  • Loading branch information
thomasahle committed Mar 17, 2024
1 parent 3d74de7 commit a9ee8ed
Showing 1 changed file with 18 additions and 4 deletions.
22 changes: 18 additions & 4 deletions dspy/functional/functional.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,7 +115,13 @@ def _make_example(self, type_) -> str:
# TODO: Instead of using a language model to create the example, we can also just use a
# library like https://pypi.org/project/polyfactory/ that's made exactly to do this.

def _format_error(self, error: Exception, task_description: Union[str, FieldInfo], model_output: str) -> str:
def _format_error(
self,
error: Exception,
task_description: Union[str, FieldInfo],
model_output: str,
lm_explain: bool,
) -> str:
if isinstance(error, pydantic.ValidationError):
errors = []
for e in error.errors():
Expand All @@ -125,7 +131,7 @@ def _format_error(self, error: Exception, task_description: Union[str, FieldInfo
else:
error_text = repr(error)

if self.explain_errors:
if self.explain_errors and lm_explain:
if isinstance(task_description, FieldInfo):
args = task_description.json_schema_extra
task_description = args["prefix"] + " " + args["desc"]
Expand Down Expand Up @@ -153,7 +159,9 @@ class Signature(dspy.Signature):
language_model_output: str = dspy.InputField(desc="The output of the model")
error: str = dspy.InputField(desc="The validation error trigged by the models output")
explanation: str = dspy.OutputField(desc="Explain what the model did wrong")
advice: str = dspy.OutputField(desc="Instructions for the model to do better next time")
advice: str = dspy.OutputField(
desc="Instructions for the model to do better next time. A single paragraph."
)

# TODO: We could also try repair the output here. For example, if the output is a float, but the
# model returned a "float + explanation", the repair could be to remove the explanation.
Expand Down Expand Up @@ -273,7 +281,12 @@ def forward(self, **kwargs) -> dspy.Prediction:
parser = field.json_schema_extra.get("parser", lambda x: x)
parsed[name] = parser(value)
except (pydantic.ValidationError, ValueError) as e:
errors[name] = self._format_error(e, signature.fields[name], value)
errors[name] = self._format_error(
e,
signature.fields[name],
value,
lm_explain=try_i + 1 < self.max_retries,
)

# If we can, we add an example to the error message
current_desc = field.json_schema_extra.get("desc", "")
Expand Down Expand Up @@ -307,6 +320,7 @@ def forward(self, **kwargs) -> dspy.Prediction:
"> " + field.json_schema_extra["prefix"] + " " + completion[name]
for name, field in signature.output_fields.items()
),
lm_explain=try_i + 1 < self.max_retries,
)
if errors:
# Add new fields for each error
Expand Down

0 comments on commit a9ee8ed

Please sign in to comment.