From d51d3629a3c94bfa7735bda8def82b7f77aa8dbb Mon Sep 17 00:00:00 2001 From: Thomas D Ahle Date: Thu, 21 Mar 2024 12:43:26 -0700 Subject: [PATCH] Added test using model_validator --- tests/functional/test_functional.py | 22 +++++++++++++++++++++- 1 file changed, 21 insertions(+), 1 deletion(-) diff --git a/tests/functional/test_functional.py b/tests/functional/test_functional.py index c106a13c4..db39e069f 100644 --- a/tests/functional/test_functional.py +++ b/tests/functional/test_functional.py @@ -1,7 +1,7 @@ import datetime import textwrap import pydantic -from pydantic import AfterValidator, Field, BaseModel, field_validator +from pydantic import AfterValidator, Field, BaseModel, field_validator, model_validator from typing import Annotated, Generic, Literal, TypeVar from typing import List @@ -802,3 +802,23 @@ def test_conlist2(): make_numbers = TypedPredictor("input:str -> output:Annotated[List[int], Field(min_items=2)]") assert make_numbers(input="What are the first two numbers?").output == [1, 2] + + +def test_model_validator(): + class MySignature(dspy.Signature): + input_data: str = dspy.InputField() + allowed_categories: list[str] = dspy.InputField() + category: str = dspy.OutputField() + + @model_validator(mode="after") + def check_cateogry(self): + if self.category not in self.allowed_categories: + raise ValueError(f"category not in {self.allowed_categories}") + return self + + lm = DummyLM(["horse", "dog"]) + dspy.settings.configure(lm=lm) + predictor = TypedPredictor(MySignature) + + pred = predictor(input_data="What is the best animal?", allowed_categories=["cat", "dog"]) + assert pred.category == "dog"