Created typed signature optimizer
thomasahle committed Mar 6, 2024
1 parent d633a77 commit 7283847
Showing 11 changed files with 902 additions and 355 deletions.
1 change: 1 addition & 0 deletions dspy/__init__.py
@@ -5,6 +5,7 @@
 from .retrieve import *
 from .predict import *
 from .primitives import *
+from .functional import *
 
 # from .evaluation import *
 
22 changes: 17 additions & 5 deletions dspy/functional/functional.py
@@ -118,20 +118,32 @@ def _prepare_signature(self) -> dspy.Signature:
                     format=lambda x: x if isinstance(x, str) else str(x),
                     parser=type_,
                 )
+            elif False:
+                # TODO: I don't like forcing the model to write "value" in the output.
+                if not (inspect.isclass(type_) and issubclass(type_, pydantic.BaseModel)):
+                    type_ = pydantic.create_model("Output", value=(type_, ...), __base__=pydantic.BaseModel)
+                    to_json = lambda x, type_=type_: type_(value=x).model_dump_json()[9:-1]  # {"value":"123"}
+                    from_json = lambda x, type_=type_: type_.model_validate_json('{"value":' + x + "}").value
+                    schema = json.dumps(type_.model_json_schema()["properties"]["value"])
+                else:
+                    to_json = lambda x: x.model_dump_json()
+                    from_json = lambda x, type_=type_: type_.model_validate_json(x)
+                    schema = json.dumps(type_.model_json_schema())
             else:
                 # Anything else we wrap in a pydantic object
-                to_json = lambda x: x.model_dump_json()
-                from_json = lambda x, type_=type_: type_.model_validate_json(x)
                 if not (inspect.isclass(type_) and issubclass(type_, pydantic.BaseModel)):
                     type_ = pydantic.create_model("Output", value=(type_, ...), __base__=pydantic.BaseModel)
                     to_json = lambda x, type_=type_: type_(value=x).model_dump_json()
                     from_json = lambda x, type_=type_: type_.model_validate_json(x).value
+                    schema = json.dumps(type_.model_json_schema())
+                else:
+                    to_json = lambda x: x.model_dump_json()
+                    from_json = lambda x, type_=type_: type_.model_validate_json(x)
+                    schema = json.dumps(type_.model_json_schema())
                 signature = signature.with_updated_fields(
                     name,
                     desc=field.json_schema_extra.get("desc", "")
-                    + (
-                        ". Respond with a single JSON object. JSON Schema: " + json.dumps(type_.model_json_schema())
-                    ),
+                    + (". Respond with a single JSON object. JSON Schema: " + schema),
                     format=lambda x, to_json=to_json: (x if isinstance(x, str) else to_json(x)),
                     parser=lambda x, from_json=from_json: from_json(_unwrap_json(x)),
                     type_=type_,
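The branch kept here wraps any annotation that is not already a pydantic model in a single-field model, so the LM always emits one JSON object and the parser unwraps it again. A minimal sketch of that round trip, assuming pydantic v2 (the name `Wrapped` is illustrative, not from the diff):

    import json

    import pydantic

    # Wrap a bare annotation such as `int` in a one-field model, as the diff does.
    Wrapped = pydantic.create_model("Output", value=(int, ...), __base__=pydantic.BaseModel)

    to_json = lambda x: Wrapped(value=x).model_dump_json()
    from_json = lambda s: Wrapped.model_validate_json(s).value
    schema = json.dumps(Wrapped.model_json_schema())

    print(to_json(123))                 # {"value":123}
    print(from_json('{"value": 123}'))  # 123
    print(schema)                       # the schema appended to the field description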
7 changes: 3 additions & 4 deletions dspy/primitives/module.py
@@ -42,12 +42,11 @@ def add_parameter(param_name, param_value):
 
         return named_parameters
 
-    def named_sub_modules(self) -> Generator[tuple[str, "BaseModule"], None, None]:
-        yield "", self
+    def named_sub_modules(self, root_name="base") -> Generator[tuple[str, "BaseModule"], None, None]:
+        yield root_name, self
         for name, value in self.__dict__.items():
             if isinstance(value, BaseModule):
-                for sub_name, sub_value in value.named_sub_modules():
-                    yield f"{name}.{sub_name}", sub_value
+                yield from value.named_sub_modules(root_name=f"{root_name}.{name}")
 
     def parameters(self):
         return [param for _, param in self.named_parameters()]
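The rewritten `named_sub_modules` yields fully qualified dotted names rooted at "base" and recurses with `yield from` instead of assembling names from empty-string prefixes. A small sketch of the expected traversal (the `Outer`/`Inner` classes are made up for illustration):

    from dspy.primitives.module import BaseModule

    class Inner(BaseModule):
        pass

    class Outer(BaseModule):
        def __init__(self):
            self.inner = Inner()

    for name, sub in Outer().named_sub_modules():
        print(name, type(sub).__name__)
    # base Outer
    # base.inner Inner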
46 changes: 30 additions & 16 deletions dspy/signatures/signature.py
@@ -42,6 +42,17 @@ def __new__(mcs, signature_name, bases, namespace, **kwargs):  # noqa: N804
         # Let Pydantic do its thing
         cls = super().__new__(mcs, signature_name, bases, namespace, **kwargs)
 
+        # If we don't have instructions, it might be because we are a derived generic type.
+        # In that case, we should inherit the instructions from the base class.
+        if cls.__doc__ is None:
+            for base in bases:
+                if isinstance(base, SignatureMeta):
+                    doc = getattr(base, "__doc__", "")
+                    if doc != "":
+                        cls.__doc__ = doc
+
+        # The more likely case is that the user has just not given us a type.
+        # In that case, we should default to the input/output format.
         if cls.__doc__ is None:
             cls.__doc__ = _default_instructions(cls)
 
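A sketch of what the new inheritance block buys, assuming `instructions` mirrors `__doc__` as it does elsewhere in this file (the QA fields are illustrative):

    import dspy

    class BaseQA(dspy.Signature):
        """Answer the question concisely."""
        question: str = dspy.InputField()
        answer: str = dspy.OutputField()

    class DerivedQA(BaseQA):  # defines no docstring of its own
        pass

    # Previously this fell back to the auto-generated default instructions;
    # with the change above it inherits the parent's.
    assert DerivedQA.instructions == "Answer the question concisely."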
@@ -168,24 +179,27 @@ def __repr__(cls):
         return f"{cls.__name__}({cls.signature}\n    instructions={repr(cls.instructions)}\n    {field_repr}\n)"
 
 
+# A signature for a predictor.
+#
+# You typically subclass it, like this:
+#     class MySignature(Signature):
+#         input: str = InputField(desc="...")  # noqa: ERA001
+#         output: int = OutputField(desc="...")  # noqa: ERA001
+#
+# You can call Signature("input1, input2 -> output1, output2") to create a new signature type.
+# You can also include instructions, Signature("input -> output", "This is a test").
+# But it's generally better to use the make_signature function.
+#
+# If you are not sure if your input is a string representation, (like "input1, input2 -> output1, output2"),
+# or a signature, you can use the ensure_signature function.
+#
+# For compatibility with the legacy dsp format, you can use the signature_to_template function.
+#
 class Signature(BaseModel, metaclass=SignatureMeta):
-    """A signature for a predictor.
-    You typically subclass it, like this:
-        class MySignature(Signature):
-            input: str = InputField(desc="...")
-            output: int = OutputField(desc="...")
-    You can call Signature("input1, input2 -> output1, output2") to create a new signature type.
-    You can also include instructions, Signature("input -> output", "This is a test").
-    But it's generally better to use the make_signature function.
-    If you are not sure if your input is a string representation, (like "input1, input2 -> output1, output2"),
-    or a signature, you can use the ensure_signature function.
-    For compatibility with the legacy dsp format, you can use the signature_to_template function.
-    """
+    ""  # noqa: D419
 
+    # Note: Don't put a docstring here, as it will become the default instructions
+    # for any signature that doesn't define its own instructions.
     pass
 
 
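The relocated comment block describes three ways to obtain a signature; in code form, a sketch using only the API the comment itself names:

    import dspy
    from dspy.signatures.signature import Signature

    # 1. Subclass with typed fields:
    class MySignature(Signature):
        input: str = dspy.InputField(desc="...")
        output: int = dspy.OutputField(desc="...")

    # 2. Build one from a string spec, optionally with instructions:
    sig = Signature("input1, input2 -> output1, output2")
    sig2 = Signature("input -> output", "This is a test")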
@@ -1,13 +1,17 @@
-import random
-from typing import Generic, Literal, TypeVar, Type
+import textwrap
+from typing import Generic, Literal, TypeVar
 
 import pydantic
 import dspy
-from dspy.functional.functional import TypedChainOfThought
+from dspy.functional.functional import TypedChainOfThought, TypedPredictor
 from dspy.signatures import Signature
 from dspy import BaseModel
+from dspy.signatures.field import InputField, OutputField
 
-# TODO: Consider using the prompt optimizer to optimize the prompt optimizer :O
+# TODO:
+# - Parallelize the generation of new signatures when we have multiple predictors
+# - Consider generating multiple new signatures at once, which we can test in parallel
+# - Consider using the prompt optimizer to optimize the prompt optimizer :O
 
 
 def make_info(signature: type[Signature]) -> BaseModel:
@@ -55,32 +59,37 @@ def to_signature(info):
 # Note: This function wouldn't be necessary if we could make the number of prompts a generic parameter of the class,
 # but alas it seems like this isn't possible in Python right now. The main reason being that types and generics only
 # live inside the type system, and can't be used to generate code at runtime.
-def make_initial_signature(n_prompts: int) -> Type[Signature]:
+def make_initial_signature(n_prompts: int) -> type[Signature]:
     """Creates a GenerateInstructionInitial signature with the given number of initial prompts."""
 
     class GenerateInstructionInitial(Signature, Generic[T]):
-        """You are a creative instruction optimizer for large language models.
+        # TODO: Can we make textwrap default/automatic in all signatures?
+        __doc__ = textwrap.dedent("""\
+            You are a creative instruction optimizer for large language models.
             I will give you a ``signature`` of fields (inputs and outputs) in English.
-            Your task is to propose variations of the signature that will lead a good language model to perform the task well.
-            Be very creative and think out of the box. Consider using inspiration such as:
+            Your task is to propose variations of the signature that will lead a good language model to perform the task well.
+            Be very creative and think out of the box.
+            You can use as long instructions as you want.
+            Consider using inspiration such as:
             Openers:
-            # You are as smart as ChatGPT.
-            # You are highly intelligent.
-            # You are an expert mathematician.
-            # You are a professor of mathematics.
+            - You are as smart as ChatGPT.
+            - You are highly intelligent.
+            - You are an expert mathematician.
+            - You are a professor of mathematics.
             Task Descriptions:
-            # Solve the following math question.
-            # Answer the following math question.
+            - Be concise in your answer.
+            - Be as clear as possible.
+            - Use lots of creativity.
             Closers:
-            # This will be fun!
-            # Take a deep breath and think carefully.
-            # I really need your help!
-            """
+            - This will be fun!
+            - Take a deep breath and think carefully.
+            - I really need your help!
+            """)
 
-        basic_signature: T = dspy.InputField()
-        proposed_signatures: list[T] = dspy.OutputField(
+        basic_signature: T = InputField()
+        proposed_signatures: list[T] = OutputField(
            desc=f"A list of {n_prompts} very different variations of the basic signature",
            min_items=n_prompts,
            max_items=n_prompts,
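The `__doc__ = textwrap.dedent("""\ ...""")` pattern above exists because signature docstrings become LM instructions verbatim, and an ordinarily indented class docstring would drag its source indentation into the prompt. What `dedent` does, in isolation:

    import textwrap

    class WithDedent:
        __doc__ = textwrap.dedent("""\
            Line one of the instructions.
            Line two, flush left in the prompt.
            """)

    print(WithDedent.__doc__)
    # Line one of the instructions.
    # Line two, flush left in the prompt.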
@@ -89,35 +98,31 @@ class GenerateInstructionInitial(Signature, Generic[T]):
     return GenerateInstructionInitial
 
 
-class ScoredSignature(BaseModel, Generic[T]):
-    signature: T
-    score: float = dspy.Field(gt=0, lt=100)
-
 
-class GenerateInstructionGivenAttempts(dspy.Signature, Generic[T]):
-    """You are an instruction optimizer for large language models.
+class GenerateSignature(dspy.Signature, Generic[T]):
+    __doc__ = textwrap.dedent("""\
+        You are an instruction optimizer for large language models.
         I will give some task instructions I've tried, along with their corresponding validation scores.
-        - The instructions are arranged in increasing order based on their scores, where higher scores indicate better quality.
+        - The instructions are arranged in order based on their scores, where higher scores indicate better quality.
         - Your task is to propose a new instruction that will lead a good language model to perform the task even better.
         - Be creative, and think out of the box.
        - Don't repeat instructions, descriptions and prefixes that have already been attempted.
-    """
+        """)
 
-    attempted_signatures: list[ScoredSignature[T]] = dspy.InputField()
-    proposed_signature: T = dspy.OutputField(desc="Next signature to try")
-    # expected_score: float = dspy.OutputField(desc="The expected score for the new signature")
+    analysis: str = OutputField(desc="Consider what made the previous instructions good or bad.")
+    proposed_signature: T = OutputField(desc="A signature that will likely lead to a high score.")
+    score: float = OutputField(desc="The expected score for the new signature. Don't write anything after this number.")
 
 
 def optimize_signature(
     student,
     evaluator,
     n_iterations=10,
     strategy: Literal["best", "last"] = "best",
+    sorted_order: Literal["increasing", "decreasing"] = "increasing",
     # Formerly part of the constructor
     prompt_model=None,
     initial_prompts=2,
-    temperature=1.4,
     verbose=False,
 ) -> dspy.Program:
     """Create a new program that is optimized for the given task.
@@ -135,25 +140,39 @@
         The number of iterations to run, by default 10
     strategy : Literal["best", "last"], optional
         The strategy to use to select the final program, by default "best"
+    sorted_order : Literal["increasing", "decreasing"], optional
+        The order in which to sort the scores, by default "increasing"
     prompt_model : dspy.LanguageModel, optional
         The language model to use to generate prompts, by default None
     initial_prompts : int, optional
         The number of initial prompts to generate, by default 2.
         Note that we also use the "plain" signature as a prompt, so the total number of prompts is initial_prompts + 1.
-    temperature : float, optional
-        The temperature to use when generating new prompts, by default 1.4
     verbose : bool, optional
         Whether to print debug information, by default False
+    Notes:
+    -----
+    We don't support temperatures, since it tends to break the typed generation.
     """
+    if n_iterations < 1 + initial_prompts:
+        raise ValueError("n_iterations must be at least 1 + initial_prompts")
+
     prompt_model = prompt_model or dspy.settings.lm
     MyGenerateInstructionInitial = make_initial_signature(initial_prompts)  # noqa: N806
 
     module = student.deepcopy()
-    # For some reason named_predictors sometimes returns an empty list, so we use named_parameters instead
-    named_predictors = module.named_parameters()
+    # In contrast to the original implementation, we don't want the Predict's, but the TypedPredictor's.
+    # This is because TypedPredictor changes the signature before it runs forward. So changing the signature
+    # on the Predict's won't work.
+    named_predictors = [
+        (name, module)
+        for name, module in module.named_sub_modules()
+        if isinstance(module, TypedPredictor) and not getattr(module, "_compiled", False)
+    ]
+    if not named_predictors:
+        raise ValueError("No unfrozen/uncompiled TypedPredictors found in the module.")
     if verbose:
-        print("All predictors:")
-        print(f"{named_predictors=}")
+        print(f"Found {len(named_predictors)} typed predictors to optimize.")
 
     candidates = {}
     scores = []
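A hypothetical end-to-end call of the new entry point; `MyTypedProgram`, `devset`, and `my_metric` are placeholders, not names from this commit:

    from dspy.evaluate import Evaluate

    program = MyTypedProgram()  # any dspy.Module built from TypedPredictor's
    evaluator = Evaluate(devset=devset, metric=my_metric)

    optimized = optimize_signature(
        student=program,
        evaluator=evaluator,
        n_iterations=10,
        initial_prompts=2,
        verbose=True,
    )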
@@ -165,45 +184,30 @@
     # Make some initial candidates
     with dspy.settings.context(lm=prompt_model):
         # TODO: Parallelize this
-        for name, p in named_predictors:
+        for name, _p in named_predictors:
             if verbose:
-                print(f"Generating new signature for {p}...")
+                print(f"Generating {initial_prompts} initial signatures for {name}...")
             info = candidates[name][0]  # Use initial info, to make sure types are identical
             generator = TypedChainOfThought(MyGenerateInstructionInitial[type(info)])
             candidates[name] += generator(
                 basic_signature=info,
-                config={"temperature": temperature},
             ).proposed_signatures
+            assert len(candidates[name]) == initial_prompts + 1  # Basic signature + initial prompts
 
-            candidates[name] = [
-                info.model_copy(update={"instructions": info.instructions + f"({i})"})
-                for i, info in enumerate(candidates[name])
-            ]
-
-            for i, c in enumerate(candidates[name]):
-                print(f"Generated candidate {i}:")
-                print(c.to_signature())
 
     # Main loop of scoring + generating new candidates
     for i in range(n_iterations):
         if verbose:
             print("\n" + "=" * 80)
             print(f"Running eval iteration {i}...")
 
-        # Test candidate i
-        for p in module.predictors():
-            print(f"Installing signature {i}: ")
-            print(candidates[name][i].to_signature())
+        # Install signatures
+        for name, p in named_predictors:
             p.signature = candidates[name][i].to_signature()
 
         # Run evaluator given by user
         score = evaluator(module)
-        score += random.random() * 10
         scores.append(score)
 
         if verbose:
             print(f"Scores for iteration {i}: {score}")
 
         # If we are still testing initial prompts, continue
         if i + 1 < len(next(iter(candidates.values()))):
             continue
@@ -215,25 +219,23 @@
         # Otherwise generate the next candidate
         with dspy.settings.context(lm=prompt_model):
             # TODO: Parallelize this
-            for name, p in named_predictors:
+            for name, _p in named_predictors:
                 SignatureInfo = type(candidates[name][0])  # noqa: N806
-                generator = TypedChainOfThought(GenerateInstructionGivenAttempts[SignatureInfo])
-                attempted_signatures = [
-                    ScoredSignature[SignatureInfo](signature=info, score=sc)
+                generator = TypedPredictor(GenerateSignature[SignatureInfo])
+
+                demos = [
+                    dspy.Example(
+                        proposed_signature=info,
+                        score=sc,
+                    )
                     for info, sc in zip(candidates[name], scores)
                 ]
-                attempted_signatures.sort(key=lambda x: x.score)
-                if verbose:
-                    print(
-                        f"Generating new signature for {name} based on {len(attempted_signatures)} previous signatures..."
-                    )
-                new_signature = generator(
-                    attempted_signatures=attempted_signatures,
-                    config={"temperature": temperature},
-                ).proposed_signature
+                demos.sort(key=(lambda x: x.score), reverse=(sorted_order == "decreasing"))
+                generator.predictor.demos = demos
 
                 if verbose:
-                    print("Generated candidate:")
-                    print(new_signature.to_signature())
+                    print(f"Generating new signature for {name}...")
+                new_signature = generator().proposed_signature
                 candidates[name].append(new_signature)
 
     if strategy == "last":
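The mechanical heart of the change above: past attempts no longer travel through an `attempted_signatures` input field. Instead, each (signature, score) pair becomes a `dspy.Example` installed as a demo on the generator, so the model sees the history as few-shot examples and, with the best attempt sorted to sit last, is nudged to continue the pattern with something better. Condensed from the diff:

    demos = [
        dspy.Example(proposed_signature=info, score=sc)
        for info, sc in zip(candidates[name], scores)
    ]
    # With sorted_order="increasing", the strongest attempt sits closest to
    # the model's own continuation.
    demos.sort(key=lambda d: d.score, reverse=(sorted_order == "decreasing"))
    generator.predictor.demos = demos
    new_signature = generator().proposed_signature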