Skip to content

Commit

Permalink
Fixed but with prefixes and annotated types
Browse files Browse the repository at this point in the history
  • Loading branch information
thomasahle committed Mar 14, 2024
1 parent 0d4e094 commit 69cfd9f
Show file tree
Hide file tree
Showing 2 changed files with 92 additions and 169 deletions.
195 changes: 40 additions & 155 deletions dspy/functional/functional.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,8 @@
import inspect
import json
import os
import typing
from typing import Annotated, List, Tuple # noqa: UP035

import openai
import pydantic
import ujson

Expand All @@ -13,9 +11,6 @@
from dspy.primitives.prediction import Prediction
from dspy.signatures.signature import ensure_signature, make_signature

# Some improvement ideas:
# - Increase the temperature on error


def predictor(func) -> dspy.Module:
"""Decorator that creates a predictor module based on the provided function."""
Expand Down Expand Up @@ -73,27 +68,38 @@ def TypedChainOfThought(signature, max_retries=3) -> dspy.Module: # noqa: N802


class TypedPredictor(dspy.Module):
def __init__(self, signature, max_retries=3):
def __init__(self, signature, max_retries=3, wrap_json=False):
"""Like dspy.Predict, but enforces type annotations in the signature.
Args:
signature: The signature of the module. Can use type annotations.
max_retries: The number of times to retry the prediction if the output is invalid.
wrap_json: If True, json objects in the input will be wrapped in ```json ... ```
"""
super().__init__()
self.signature = ensure_signature(signature)
self.predictor = dspy.Predict(signature)
self.max_retries = max_retries
self.wrap_json = wrap_json

def copy(self) -> "TypedPredictor":
return TypedPredictor(self.signature, self.max_retries)
return TypedPredictor(self.signature, self.max_retries, self.wrap_json)

def __repr__(self):
"""Return a string representation of the TypedPredictor object."""
return f"TypedPredictor({self.signature})"

@staticmethod
def _make_example(type_) -> str:
def _make_example(self, type_) -> str:
# Note: DSPy will cache this call so we only pay the first time TypedPredictor is called.
schema = json.dumps(type_.model_json_schema())
if self.wrap_json:
schema = "```json\n" + schema + "\n```\n"
json_object = dspy.Predict(
make_signature(
"json_schema -> json_object",
"Make a very succinct json object that validates with the following schema",
),
)(json_schema=json.dumps(type_.model_json_schema())).json_object
)(json_schema=schema).json_object
# We use the model_validate_json method to make sure the example is valid
try:
type_.model_validate_json(_unwrap_json(json_object))
Expand Down Expand Up @@ -147,6 +153,9 @@ def _prepare_signature(self) -> dspy.Signature:
to_json = lambda x: x.model_dump_json()
from_json = lambda x, type_=type_: type_.model_validate_json(x)
schema = json.dumps(type_.model_json_schema())
if self.wrap_json:
to_json = lambda x, inner=to_json: "```json\n" + inner(x) + "\n```\n"
schema = "```json\n" + schema + "\n```"
signature = signature.with_updated_fields(
name,
desc=field.json_schema_extra.get("desc", "")
Expand All @@ -156,6 +165,7 @@ def _prepare_signature(self) -> dspy.Signature:
type_=type_,
)
else: # If input field
is_json = False
format_ = lambda x: x if isinstance(x, str) else str(x)
if type_ in (List[str], list[str], Tuple[str], tuple[str]):
format_ = passages2text
Expand All @@ -168,8 +178,12 @@ def _prepare_signature(self) -> dspy.Signature:
)
else:
format_ = lambda x: x if isinstance(x, str) else json.dumps(x)
is_json = True
elif inspect.isclass(type_) and issubclass(type_, pydantic.BaseModel):
format_ = lambda x: x if isinstance(x, str) else x.model_dump_json()
is_json = True
if self.wrap_json and is_json:
format_ = lambda x, inner=format_: x if isinstance(x, str) else "```json\n" + inner(x) + "\n```\n"
signature = signature.with_updated_fields(name, format=format_)

return signature
Expand Down Expand Up @@ -215,18 +229,23 @@ def forward(self, **kwargs) -> dspy.Prediction:
# Instantiate the actual signature with the parsed values.
# This allow pydantic to validate the fields defined in the signature.
try:
_dummy = self.signature(**kwargs, **parsed)
_ = self.signature(**kwargs, **parsed)
parsed_results.append(parsed)
except pydantic.ValidationError as e:
errors["general"] = _format_error(e)
if errors:
# Add new fields for each error
for name, error in errors.items():
modified_kwargs[f"error_{name}_{try_i}"] = error
if name == "general":
error_prefix = "General:"
else:
error_prefix = signature.output_fields[name].json_schema_extra["prefix"]
number = "" if try_i == 0 else f" ({try_i+1})"
signature = signature.append(
f"error_{name}_{try_i}",
dspy.InputField(
prefix="Past Error " + (f"({name}):" if try_i == 0 else f"({name}, {try_i+1}):"),
prefix=f"Past Error{number} in {error_prefix}",
desc="An error to avoid in the future",
),
)
Expand All @@ -248,14 +267,13 @@ def _format_error(error: Exception):
fields = ", ".join(map(str, e["loc"]))
errors.append(f"{e['msg']}: {fields} (error type: {e['type']})")
return "; ".join(errors)
else:
return repr(error)
return repr(error)


def _func_to_signature(func):
"""Make a dspy.Signature based on a function definition."""
sig = inspect.signature(func)
annotations = typing.get_type_hints(func)
annotations = typing.get_type_hints(func, include_extras=True)
output_key = func.__name__
instructions = func.__doc__
fields = {}
Expand All @@ -268,14 +286,18 @@ def _func_to_signature(func):
annotation = annotations.get(param.name, str)
kwargs = {}
if typing.get_origin(annotation) is Annotated:
annotation, kwargs["desc"] = typing.get_args(annotation)
desc = next((arg for arg in typing.get_args(annotation) if isinstance(arg, str)), None)
if desc is not None:
kwargs["desc"] = desc
fields[param.name] = (annotation, dspy.InputField(**kwargs))

# Output field
kwargs = {}
annotation = annotations.get("return", str)
if typing.get_origin(annotation) is Annotated:
annotation, kwargs["desc"] = typing.get_args(annotation)
desc = next((arg for arg in typing.get_args(annotation) if isinstance(arg, str)), None)
if desc is not None:
kwargs["desc"] = desc
fields[output_key] = (annotation, dspy.OutputField(**kwargs))

return dspy.Signature(fields, instructions)
Expand All @@ -287,145 +309,8 @@ def _unwrap_json(output):
if not output.startswith("```json"):
raise ValueError("json output should start with ```json")
if not output.endswith("```"):
raise ValueError("json output should end with ```")
raise ValueError("Don't write anything after the final json ```")
output = output[7:-3].strip()
if not output.startswith("{") or not output.endswith("}"):
raise ValueError("json output should start and end with { and }")
return ujson.dumps(ujson.loads(output)) # ujson is a bit more robust than the standard json


################################################################################
# Example usage
################################################################################


def main() -> None:
class Answer(pydantic.BaseModel):
value: float
certainty: float
comments: list[str] = pydantic.Field(description="At least two comments about the answer")

class QA(dspy.Module):
@predictor
def hard_question(self, topic: str) -> str:
"""Think of a hard factual question about a topic. It should be answerable with a number."""

@cot
def answer(self, question: Annotated[str, "Question to answer"]) -> Answer:
pass

def forward(self, **kwargs):
question = self.hard_question(**kwargs)
return (question, self.answer(question=question))

openai.api_key = os.getenv("OPENAI_API_KEY")
lm = dspy.OpenAI(model="gpt-3.5-turbo", max_tokens=4000)
# lm = dspy.OpenAI(model="gpt-4", max_tokens=4000)
# lm = dspy.OpenAI(model="gpt-4-preview-1106", max_tokens=4000)
with dspy.context(lm=lm):
qa = QA()
question, answer = qa(topic="Physics")
# lm.inspect_history(n=5)

print("Question:", question) # noqa: T201
print("Answer:", answer) # noqa: T201


################################################################################
# HotpotQA example with SimpleBaleen
################################################################################


def validate_context_and_answer_and_hops(example, pred, trace=None) -> bool:
if not dspy.evaluate.answer_exact_match(example, pred):
return False
if not dspy.evaluate.answer_passage_match(example, pred):
return False

hops = [example.question] + [outputs.query for *_, outputs in trace if "query" in outputs]

if max([len(h) for h in hops]) > 100:
return False
if any(dspy.evaluate.answer_exact_match_str(hops[idx], hops[:idx], frac=0.8) for idx in range(2, len(hops))):
return False

return True


def gold_passages_retrieved(example, pred, _trace=None) -> bool:
gold_titles = set(map(dspy.evaluate.normalize_text, example["gold_titles"]))
found_titles = set(map(dspy.evaluate.normalize_text, [c.split(" | ")[0] for c in pred.context]))

return gold_titles.issubset(found_titles)


def hotpot() -> None:
import dspy.evaluate
from dsp.utils import deduplicate
from dspy.datasets import HotPotQA
from dspy.evaluate.evaluate import Evaluate
from dspy.teleprompt.bootstrap import BootstrapFewShot

print("Load the dataset.") # noqa: T201
dataset = HotPotQA(train_seed=1, train_size=20, eval_seed=2023, dev_size=50, test_size=0)
trainset = [x.with_inputs("question") for x in dataset.train]
devset = [x.with_inputs("question") for x in dataset.dev]
print("Done") # noqa: T201

class SimplifiedBaleen(FunctionalModule):
def __init__(self, passages_per_hop=3, max_hops=1):
super().__init__()
self.retrieve = dspy.Retrieve(k=passages_per_hop)
self.max_hops = max_hops

@cot
def generate_query(self, context: list[str], question) -> str:
"""Write a simple search query that will help answer a complex question."""
pass

@cot
def generate_answer(self, context: list[str], question) -> str:
"""Answer questions with short factoid answers."""
pass

def forward(self, question):
context = []

for _ in range(self.max_hops):
query = self.generate_query(context=context, question=question)
passages = self.retrieve(query).passages
context = deduplicate(context + passages)

answer = self.generate_answer(context=context, question=question)
return dspy.Prediction(context=context, answer=answer)

openai.api_key = os.getenv("OPENAI_API_KEY")
rm = dspy.ColBERTv2(url="http://20.102.90.50:2017/wiki17_abstracts")
lm = dspy.OpenAI(model="gpt-3.5-turbo", max_tokens=4000)
dspy.settings.configure(lm=lm, rm=rm, trace=[])

evaluate_on_hotpotqa = Evaluate(devset=devset, num_threads=10, display_progress=True, display_table=5)

# uncompiled (i.e., zero-shot) program
uncompiled_baleen = SimplifiedBaleen()
print( # noqa: T201
"Uncompiled Baleen retrieval score:",
evaluate_on_hotpotqa(uncompiled_baleen, metric=gold_passages_retrieved),
)

# compiled (i.e., few-shot) program
compiled_baleen = BootstrapFewShot(metric=validate_context_and_answer_and_hops).compile(
SimplifiedBaleen(),
teacher=SimplifiedBaleen(passages_per_hop=2),
trainset=trainset,
)
print( # noqa: T201
"Compiled Baleen retrieval score:",
evaluate_on_hotpotqa(compiled_baleen, metric=gold_passages_retrieved),
)
# lm.inspect_history(n=5)


if __name__ == "__main__":
# main()
hotpot()
Loading

0 comments on commit 69cfd9f

Please sign in to comment.