Skip to content

Commit

Permalink
feat(dspy): add a test
Browse files Browse the repository at this point in the history
  • Loading branch information
JONEMI19 committed May 29, 2024
1 parent c237252 commit 2edd705
Showing 1 changed file with 71 additions and 14 deletions.
85 changes: 71 additions & 14 deletions tests/functional/test_functional.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,13 @@
import datetime
import textwrap
import pydantic
from pydantic import AfterValidator, Field, BaseModel, field_validator, model_validator
from typing import Annotated, Generic, Literal, TypeVar
from typing import List
from typing import Annotated, Any, Generic, List, Literal, Optional, TypeVar

import pydantic
import pytest
from pydantic import AfterValidator, BaseModel, Field, ValidationError, field_validator, model_validator

import dspy
from dspy.functional import predictor, cot, FunctionalModule, TypedPredictor, TypedChainOfThought
from dspy.functional import FunctionalModule, TypedChainOfThought, TypedPredictor, cot, predictor
from dspy.predict.predict import Predict
from dspy.primitives.example import Example
from dspy.teleprompt.bootstrap import BootstrapFewShot
Expand Down Expand Up @@ -282,7 +281,7 @@ def flight_information(email: str) -> TravelInformation:

email = textwrap.dedent(
"""\
We're excited to welcome you aboard your upcoming flight from
We're excited to welcome you aboard your upcoming flight from
John F. Kennedy International Airport (JFK) to Los Angeles International Airport (LAX)
on December 25, 2022. Here's everything you need to know before you take off: ...
"""
Expand All @@ -304,6 +303,56 @@ def flight_information(email: str) -> TravelInformation:
)


def test_custom_model_validate_json():
class TravelInformation(BaseModel):
origin: str = Field(pattern=r"^[A-Z]{3}$")
destination: str = Field(pattern=r"^[A-Z]{3}$")
date: datetime.date

@classmethod
def model_validate_json(
cls, json_data: str, *, strict: Optional[bool] = None, context: Optional[dict[str, Any]] = None
) -> "TravelInformation":
try:
__tracebackhide__ = True
return cls.__pydantic_validator__.validate_json(json_data, strict=strict, context=context)
except ValidationError:
for substring_length in range(len(json_data), 1, -1):
for start in range(len(json_data) - substring_length + 1):
substring = json_data[start : start + substring_length]
try:
__tracebackhide__ = True
res = cls.__pydantic_validator__.validate_json(substring, strict=strict, context=context)
return res
except ValidationError as exc:
last_exc = exc
pass
raise ValueError("Could not find valid json") from last_exc

@predictor
def flight_information(email: str) -> TravelInformation:
pass

email = textwrap.dedent(
"""\
We're excited to welcome you aboard your upcoming flight from
John F. Kennedy International Airport (JFK) to Los Angeles International Airport (LAX)
on December 25, 2022. Here's everything you need to know before you take off: ...
"""
)
lm = DummyLM(
[
# Example with a bad origin code.
'Here is your json: {"origin": "JFK", "destination": "LAX", "date": "2022-12-25"}',
]
)
dspy.settings.configure(lm=lm)

assert flight_information(email=email) == TravelInformation(
origin="JFK", destination="LAX", date=datetime.date(2022, 12, 25)
)


def test_raises():
class TravelInformation(BaseModel):
origin: str = Field(pattern=r"^[A-Z]{3}$")
Expand Down Expand Up @@ -601,7 +650,8 @@ class ScoredSignature(dspy.Signature):

assert output == "Output"

assert lm.get_convo(-1) == textwrap.dedent("""\
assert lm.get_convo(-1) == textwrap.dedent(
"""\
Given the fields `attempted_signatures`, produce the fields `proposed_signature`.
---
Expand All @@ -616,7 +666,8 @@ class ScoredSignature(dspy.Signature):
Attempted Signatures: [{"string":"string 1","score":0.5},{"string":"string 2","score":0.4},{"string":"string 3","score":0.3}]
Reasoning: Let's think step by step in order to Thoughts
Proposed Signature: Output""")
Proposed Signature: Output"""
)


def test_custom_reasoning_field():
Expand All @@ -643,7 +694,8 @@ class QuestionSignature(dspy.Signature):
assert isinstance(output.question, Question)
assert output.question.value == expected

assert lm.get_convo(-1) == textwrap.dedent("""\
assert lm.get_convo(-1) == textwrap.dedent(
"""\
Given the fields `topic`, produce the fields `question`.
---
Expand All @@ -658,7 +710,8 @@ class QuestionSignature(dspy.Signature):
Topic: Physics
Custom Reasoning: Let's break this down. To generate a question about Thoughts
Question: {"value": "What is the speed of light?"}""")
Question: {"value": "What is the speed of light?"}"""
)


def test_generic_signature():
Expand Down Expand Up @@ -772,7 +825,8 @@ def test_demos():

assert program(input="What is the capital of France?").output == "Paris"

assert lm.get_convo(-1) == textwrap.dedent("""\
assert lm.get_convo(-1) == textwrap.dedent(
"""\
Given the fields `input`, produce the fields `output`.
---
Expand All @@ -790,7 +844,8 @@ def test_demos():
---
Input: What is the capital of France?
Output: Paris""")
Output: Paris"""
)


def _test_demos_missing_input():
Expand All @@ -802,7 +857,8 @@ def _test_demos_missing_input():
dspy.settings.configure(lm=DummyLM(["My thoughts", "Paris"]))
assert program(input="What is the capital of France?").output == "Paris"

assert dspy.settings.lm.get_convo(-1) == textwrap.dedent("""\
assert dspy.settings.lm.get_convo(-1) == textwrap.dedent(
"""\
Given the fields `input`, produce the fields `output`.
---
Expand All @@ -822,7 +878,8 @@ def _test_demos_missing_input():
Input: What is the capital of France?
Thoughts: My thoughts
Output: Paris""")
Output: Paris"""
)


def test_conlist():
Expand Down

0 comments on commit 2edd705

Please sign in to comment.