Skip to content

Commit

Permalink
Richer output from signature optimizer
Browse files Browse the repository at this point in the history
  • Loading branch information
thomasahle committed Mar 6, 2024
1 parent 26bb358 commit a82862a
Show file tree
Hide file tree
Showing 3 changed files with 2,205 additions and 91 deletions.
54 changes: 36 additions & 18 deletions dspy/teleprompt/signature_opt_typed.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import textwrap
from dataclasses import dataclass
from typing import Generic, Literal, TypeVar

import pydantic
Expand Down Expand Up @@ -115,12 +116,20 @@ class GenerateSignature(dspy.Signature, Generic[T]):
score: float = OutputField(desc="The expected score for the new signature. Don't write anything after this number.")


@dataclass
class OptimizerResult:
program: dspy.Program
signatures: list[dict[str, Signature]]
scores: list[float]


def optimize_signature(
student,
evaluator,
n_iterations=10,
strategy: Literal["best", "last"] = "best",
sorted_order: Literal["increasing", "decreasing"] = "increasing",
sorted_order: Literal["increasing", "decreasing", "chronological"] = "increasing",
strategy: Literal["last", "best"] = "best",
max_examples=20,
# Formerly part of the constructor
prompt_model=None,
initial_prompts=2,
Expand All @@ -139,10 +148,12 @@ def optimize_signature(
The evaluator to use to score the program.
n_iterations : int, optional
The number of iterations to run, by default 10
strategy : Literal["best", "last"], optional
The strategy to use to select the final program, by default "best"
sorted_order : Literal["increasing", "decreasing"], optional
max_examples : int, optional
The maximum number of examples to use for the evaluator, by default 20
sorted_order : Literal["increasing", "decreasing", "chronological"], optional
The order in which to sort the scores, by default "increasing"
strategy : Literal["last", "best"], optional
The strategy to use to select the final program, by default "best"
prompt_model : dspy.LanguageModel, optional
The language model to use to generate prompts, by default None
initial_prompts : int, optional
Expand Down Expand Up @@ -224,14 +235,17 @@ def optimize_signature(
SignatureInfo = type(candidates[name][0]) # noqa: N806
generator = TypedPredictor(GenerateSignature[SignatureInfo])

demos = [
dspy.Example(
proposed_signature=info,
score=sc,
)
for info, sc in zip(candidates[name], scores)
]
demos.sort(key=(lambda x: x.score), reverse=(sorted_order == "decreasing"))
demos = [dspy.Example(proposed_signature=info, score=sc) for info, sc in zip(candidates[name], scores)]
if sorted_order == "chronological":
demos = demos[-max_examples:]
elif sorted_order == "increasing":
demos.sort(key=(lambda x: x.score), reverse=False)
demos = demos[-max_examples:]
elif sorted_order == "decreasing":
demos.sort(key=(lambda x: x.score), reverse=True)
demos = demos[:max_examples]
else:
raise ValueError(f"Invalid sorted_order: {sorted_order}")
generator.predictor.demos = demos

if verbose:
Expand All @@ -240,12 +254,16 @@ def optimize_signature(
candidates[name].append(new_signature)

if strategy == "last":
return module

if strategy == "best":
pass
elif strategy == "best":
i = scores.index(max(scores))
for name, p in named_predictors:
p.signature = candidates[name][i].to_signature()
return module
else:
raise ValueError(f"Invalid strategy: {strategy}")

raise ValueError(f"Invalid strategy: {strategy}")
return OptimizerResult(
program=module,
signatures=[{name: sigs[i].to_signature()} for name, sigs in candidates.items() for i in range(n_iterations)],
scores=scores,
)
Loading

0 comments on commit a82862a

Please sign in to comment.