Skip to content

Commit

Permalink
Added test using model_validator
Browse files Browse the repository at this point in the history
  • Loading branch information
thomasahle committed Mar 21, 2024
1 parent 36a72a7 commit d51d362
Showing 1 changed file with 21 additions and 1 deletion.
22 changes: 21 additions & 1 deletion tests/functional/test_functional.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import datetime
import textwrap
import pydantic
from pydantic import AfterValidator, Field, BaseModel, field_validator
from pydantic import AfterValidator, Field, BaseModel, field_validator, model_validator
from typing import Annotated, Generic, Literal, TypeVar
from typing import List

Expand Down Expand Up @@ -802,3 +802,23 @@ def test_conlist2():

make_numbers = TypedPredictor("input:str -> output:Annotated[List[int], Field(min_items=2)]")
assert make_numbers(input="What are the first two numbers?").output == [1, 2]


def test_model_validator():
class MySignature(dspy.Signature):
input_data: str = dspy.InputField()
allowed_categories: list[str] = dspy.InputField()
category: str = dspy.OutputField()

@model_validator(mode="after")
def check_cateogry(self):
if self.category not in self.allowed_categories:
raise ValueError(f"category not in {self.allowed_categories}")
return self

lm = DummyLM(["horse", "dog"])
dspy.settings.configure(lm=lm)
predictor = TypedPredictor(MySignature)

pred = predictor(input_data="What is the best animal?", allowed_categories=["cat", "dog"])
assert pred.category == "dog"

0 comments on commit d51d362

Please sign in to comment.