diff --git a/dspy/functional/functional.py b/dspy/functional/functional.py index d24dd4d9d..048cd9c1a 100644 --- a/dspy/functional/functional.py +++ b/dspy/functional/functional.py @@ -1,10 +1,8 @@ import inspect import json -import os import typing from typing import Annotated, List, Tuple # noqa: UP035 -import openai import pydantic import ujson @@ -13,9 +11,6 @@ from dspy.primitives.prediction import Prediction from dspy.signatures.signature import ensure_signature, make_signature -# Some improvement ideas: -# - Increase the temperature on error - def predictor(func) -> dspy.Module: """Decorator that creates a predictor module based on the provided function.""" @@ -73,27 +68,38 @@ def TypedChainOfThought(signature, max_retries=3) -> dspy.Module: # noqa: N802 class TypedPredictor(dspy.Module): - def __init__(self, signature, max_retries=3): + def __init__(self, signature, max_retries=3, wrap_json=False): + """Like dspy.Predict, but enforces type annotations in the signature. + + Args: + signature: The signature of the module. Can use type annotations. + max_retries: The number of times to retry the prediction if the output is invalid. + wrap_json: If True, json objects in the input will be wrapped in ```json ... ``` + """ super().__init__() self.signature = ensure_signature(signature) self.predictor = dspy.Predict(signature) self.max_retries = max_retries + self.wrap_json = wrap_json def copy(self) -> "TypedPredictor": - return TypedPredictor(self.signature, self.max_retries) + return TypedPredictor(self.signature, self.max_retries, self.wrap_json) def __repr__(self): + """Return a string representation of the TypedPredictor object.""" return f"TypedPredictor({self.signature})" - @staticmethod - def _make_example(type_) -> str: + def _make_example(self, type_) -> str: # Note: DSPy will cache this call so we only pay the first time TypedPredictor is called. + schema = json.dumps(type_.model_json_schema()) + if self.wrap_json: + schema = "```json\n" + schema + "\n```\n" json_object = dspy.Predict( make_signature( "json_schema -> json_object", "Make a very succinct json object that validates with the following schema", ), - )(json_schema=json.dumps(type_.model_json_schema())).json_object + )(json_schema=schema).json_object # We use the model_validate_json method to make sure the example is valid try: type_.model_validate_json(_unwrap_json(json_object)) @@ -147,6 +153,9 @@ def _prepare_signature(self) -> dspy.Signature: to_json = lambda x: x.model_dump_json() from_json = lambda x, type_=type_: type_.model_validate_json(x) schema = json.dumps(type_.model_json_schema()) + if self.wrap_json: + to_json = lambda x, inner=to_json: "```json\n" + inner(x) + "\n```\n" + schema = "```json\n" + schema + "\n```" signature = signature.with_updated_fields( name, desc=field.json_schema_extra.get("desc", "") @@ -156,6 +165,7 @@ def _prepare_signature(self) -> dspy.Signature: type_=type_, ) else: # If input field + is_json = False format_ = lambda x: x if isinstance(x, str) else str(x) if type_ in (List[str], list[str], Tuple[str], tuple[str]): format_ = passages2text @@ -168,8 +178,12 @@ def _prepare_signature(self) -> dspy.Signature: ) else: format_ = lambda x: x if isinstance(x, str) else json.dumps(x) + is_json = True elif inspect.isclass(type_) and issubclass(type_, pydantic.BaseModel): format_ = lambda x: x if isinstance(x, str) else x.model_dump_json() + is_json = True + if self.wrap_json and is_json: + format_ = lambda x, inner=format_: x if isinstance(x, str) else "```json\n" + inner(x) + "\n```\n" signature = signature.with_updated_fields(name, format=format_) return signature @@ -215,7 +229,7 @@ def forward(self, **kwargs) -> dspy.Prediction: # Instantiate the actual signature with the parsed values. # This allow pydantic to validate the fields defined in the signature. try: - _dummy = self.signature(**kwargs, **parsed) + _ = self.signature(**kwargs, **parsed) parsed_results.append(parsed) except pydantic.ValidationError as e: errors["general"] = _format_error(e) @@ -223,10 +237,15 @@ def forward(self, **kwargs) -> dspy.Prediction: # Add new fields for each error for name, error in errors.items(): modified_kwargs[f"error_{name}_{try_i}"] = error + if name == "general": + error_prefix = "General:" + else: + error_prefix = signature.output_fields[name].json_schema_extra["prefix"] + number = "" if try_i == 0 else f" ({try_i+1})" signature = signature.append( f"error_{name}_{try_i}", dspy.InputField( - prefix="Past Error " + (f"({name}):" if try_i == 0 else f"({name}, {try_i+1}):"), + prefix=f"Past Error{number} in {error_prefix}", desc="An error to avoid in the future", ), ) @@ -248,14 +267,13 @@ def _format_error(error: Exception): fields = ", ".join(map(str, e["loc"])) errors.append(f"{e['msg']}: {fields} (error type: {e['type']})") return "; ".join(errors) - else: - return repr(error) + return repr(error) def _func_to_signature(func): """Make a dspy.Signature based on a function definition.""" sig = inspect.signature(func) - annotations = typing.get_type_hints(func) + annotations = typing.get_type_hints(func, include_extras=True) output_key = func.__name__ instructions = func.__doc__ fields = {} @@ -268,14 +286,18 @@ def _func_to_signature(func): annotation = annotations.get(param.name, str) kwargs = {} if typing.get_origin(annotation) is Annotated: - annotation, kwargs["desc"] = typing.get_args(annotation) + desc = next((arg for arg in typing.get_args(annotation) if isinstance(arg, str)), None) + if desc is not None: + kwargs["desc"] = desc fields[param.name] = (annotation, dspy.InputField(**kwargs)) # Output field kwargs = {} annotation = annotations.get("return", str) if typing.get_origin(annotation) is Annotated: - annotation, kwargs["desc"] = typing.get_args(annotation) + desc = next((arg for arg in typing.get_args(annotation) if isinstance(arg, str)), None) + if desc is not None: + kwargs["desc"] = desc fields[output_key] = (annotation, dspy.OutputField(**kwargs)) return dspy.Signature(fields, instructions) @@ -287,145 +309,8 @@ def _unwrap_json(output): if not output.startswith("```json"): raise ValueError("json output should start with ```json") if not output.endswith("```"): - raise ValueError("json output should end with ```") + raise ValueError("Don't write anything after the final json ```") output = output[7:-3].strip() if not output.startswith("{") or not output.endswith("}"): raise ValueError("json output should start and end with { and }") return ujson.dumps(ujson.loads(output)) # ujson is a bit more robust than the standard json - - -################################################################################ -# Example usage -################################################################################ - - -def main() -> None: - class Answer(pydantic.BaseModel): - value: float - certainty: float - comments: list[str] = pydantic.Field(description="At least two comments about the answer") - - class QA(dspy.Module): - @predictor - def hard_question(self, topic: str) -> str: - """Think of a hard factual question about a topic. It should be answerable with a number.""" - - @cot - def answer(self, question: Annotated[str, "Question to answer"]) -> Answer: - pass - - def forward(self, **kwargs): - question = self.hard_question(**kwargs) - return (question, self.answer(question=question)) - - openai.api_key = os.getenv("OPENAI_API_KEY") - lm = dspy.OpenAI(model="gpt-3.5-turbo", max_tokens=4000) - # lm = dspy.OpenAI(model="gpt-4", max_tokens=4000) - # lm = dspy.OpenAI(model="gpt-4-preview-1106", max_tokens=4000) - with dspy.context(lm=lm): - qa = QA() - question, answer = qa(topic="Physics") - # lm.inspect_history(n=5) - - print("Question:", question) # noqa: T201 - print("Answer:", answer) # noqa: T201 - - -################################################################################ -# HotpotQA example with SimpleBaleen -################################################################################ - - -def validate_context_and_answer_and_hops(example, pred, trace=None) -> bool: - if not dspy.evaluate.answer_exact_match(example, pred): - return False - if not dspy.evaluate.answer_passage_match(example, pred): - return False - - hops = [example.question] + [outputs.query for *_, outputs in trace if "query" in outputs] - - if max([len(h) for h in hops]) > 100: - return False - if any(dspy.evaluate.answer_exact_match_str(hops[idx], hops[:idx], frac=0.8) for idx in range(2, len(hops))): - return False - - return True - - -def gold_passages_retrieved(example, pred, _trace=None) -> bool: - gold_titles = set(map(dspy.evaluate.normalize_text, example["gold_titles"])) - found_titles = set(map(dspy.evaluate.normalize_text, [c.split(" | ")[0] for c in pred.context])) - - return gold_titles.issubset(found_titles) - - -def hotpot() -> None: - import dspy.evaluate - from dsp.utils import deduplicate - from dspy.datasets import HotPotQA - from dspy.evaluate.evaluate import Evaluate - from dspy.teleprompt.bootstrap import BootstrapFewShot - - print("Load the dataset.") # noqa: T201 - dataset = HotPotQA(train_seed=1, train_size=20, eval_seed=2023, dev_size=50, test_size=0) - trainset = [x.with_inputs("question") for x in dataset.train] - devset = [x.with_inputs("question") for x in dataset.dev] - print("Done") # noqa: T201 - - class SimplifiedBaleen(FunctionalModule): - def __init__(self, passages_per_hop=3, max_hops=1): - super().__init__() - self.retrieve = dspy.Retrieve(k=passages_per_hop) - self.max_hops = max_hops - - @cot - def generate_query(self, context: list[str], question) -> str: - """Write a simple search query that will help answer a complex question.""" - pass - - @cot - def generate_answer(self, context: list[str], question) -> str: - """Answer questions with short factoid answers.""" - pass - - def forward(self, question): - context = [] - - for _ in range(self.max_hops): - query = self.generate_query(context=context, question=question) - passages = self.retrieve(query).passages - context = deduplicate(context + passages) - - answer = self.generate_answer(context=context, question=question) - return dspy.Prediction(context=context, answer=answer) - - openai.api_key = os.getenv("OPENAI_API_KEY") - rm = dspy.ColBERTv2(url="http://20.102.90.50:2017/wiki17_abstracts") - lm = dspy.OpenAI(model="gpt-3.5-turbo", max_tokens=4000) - dspy.settings.configure(lm=lm, rm=rm, trace=[]) - - evaluate_on_hotpotqa = Evaluate(devset=devset, num_threads=10, display_progress=True, display_table=5) - - # uncompiled (i.e., zero-shot) program - uncompiled_baleen = SimplifiedBaleen() - print( # noqa: T201 - "Uncompiled Baleen retrieval score:", - evaluate_on_hotpotqa(uncompiled_baleen, metric=gold_passages_retrieved), - ) - - # compiled (i.e., few-shot) program - compiled_baleen = BootstrapFewShot(metric=validate_context_and_answer_and_hops).compile( - SimplifiedBaleen(), - teacher=SimplifiedBaleen(passages_per_hop=2), - trainset=trainset, - ) - print( # noqa: T201 - "Compiled Baleen retrieval score:", - evaluate_on_hotpotqa(compiled_baleen, metric=gold_passages_retrieved), - ) - # lm.inspect_history(n=5) - - -if __name__ == "__main__": - # main() - hotpot() diff --git a/tests/functional/test_functional.py b/tests/functional/test_functional.py index 7f8f9237e..f0e22e4be 100644 --- a/tests/functional/test_functional.py +++ b/tests/functional/test_functional.py @@ -1,7 +1,7 @@ import datetime import textwrap import pydantic -from pydantic import Field, BaseModel, field_validator +from pydantic import AfterValidator, Field, BaseModel, field_validator from typing import Annotated, Generic, Literal, TypeVar from typing import List @@ -362,9 +362,9 @@ def flight_information(email: str) -> TravelInformation: Email: ${email} - Past Error (flight_information): An error to avoid in the future + Past Error in Flight Information: An error to avoid in the future - Past Error (flight_information, 2): An error to avoid in the future + Past Error (2) in Flight Information: An error to avoid in the future Flight Information: ${flight_information}. Respond with a single JSON object. JSON Schema: {"properties": {"origin": {"pattern": "^[A-Z]{3}$", "title": "Origin", "type": "string"}, "destination": {"pattern": "^[A-Z]{3}$", "title": "Destination", "type": "string"}, "date": {"format": "date", "title": "Date", "type": "string"}}, "required": ["origin", "destination", "date"], "title": "TravelInformation", "type": "object"} @@ -372,9 +372,9 @@ def flight_information(email: str) -> TravelInformation: Email: Some email - Past Error (flight_information): String should match pattern '^[A-Z]{3}$': origin (error type: string_pattern_mismatch) + Past Error in Flight Information: String should match pattern '^[A-Z]{3}$': origin (error type: string_pattern_mismatch) - Past Error (flight_information, 2): String should match pattern '^[A-Z]{3}$': destination (error type: string_pattern_mismatch) + Past Error (2) in Flight Information: String should match pattern '^[A-Z]{3}$': destination (error type: string_pattern_mismatch) Flight Information: {"origin": "JFK", "destination": "LAX", "date": "2022-12-25"}""" ) @@ -418,28 +418,25 @@ def get_user_details() -> UserDetails: Follow the following format. - Past Error (get_user_details): An error to avoid in the future - Past Error (get_user_details, 2): An error to avoid in the future + Past Error in Get User Details: An error to avoid in the future + Past Error (2) in Get User Details: An error to avoid in the future Get User Details: ${get_user_details}. Respond with a single JSON object. JSON Schema: {"properties": {"name": {"title": "Name", "type": "string"}, "age": {"title": "Age", "type": "integer"}}, "required": ["name", "age"], "title": "UserDetails", "type": "object"} --- - Past Error (get_user_details): Value error, Name must be in uppercase.: name (error type: value_error) - Past Error (get_user_details, 2): Value error, Name must be in uppercase.: name (error type: value_error) + Past Error in Get User Details: Value error, Name must be in uppercase.: name (error type: value_error) + Past Error (2) in Get User Details: Value error, Name must be in uppercase.: name (error type: value_error) Get User Details: {"name": "lower case name", "age": 25}""" ) def test_annotated_field(): - # Since we don't currently validate fields on the main signature, - # the annotated fields are also not validated. - # But at least it should not crash. - @predictor def test(input: Annotated[str, Field(description="description")]) -> Annotated[float, Field(gt=0, lt=1)]: pass - lm = DummyLM(["0.5"]) + # First try 0, which fails, then try 0.5, which passes + lm = DummyLM(["0", "0.5"]) dspy.settings.configure(lm=lm) output = test(input="input") @@ -654,3 +651,44 @@ def space_in_a(cls, a: str) -> str: _ = ValidatedSignature(a="no-space") _ = ValidatedSignature(a="with space") + + +def test_annotated_validator(): + def is_square(n: int) -> int: + root = n**0.5 + if not root.is_integer(): + raise ValueError(f"{n} is not a square") + return n + + class MySignature(dspy.Signature): + """What is the next square number after n?""" + + n: int = dspy.InputField() + next_square: Annotated[int, AfterValidator(is_square)] = dspy.OutputField() + + lm = DummyLM(["3", "4"]) + dspy.settings.configure(lm=lm) + + m = TypedPredictor(MySignature)(n=2).next_square + lm.inspect_history(n=2) + + assert m == 4 + + +def test_annotated_validator_functional(): + def is_square(n: int) -> int: + if not (n**0.5).is_integer(): + raise ValueError(f"{n} is not a square") + return n + + @predictor + def next_square(n: int) -> Annotated[int, AfterValidator(is_square)]: + """What is the next square number after n?""" + + lm = DummyLM(["3", "4"]) + dspy.settings.configure(lm=lm) + + m = next_square(n=2) + lm.inspect_history(n=2) + + assert m == 4