Skip to content

Commit

Permalink
Fixed tests for python 3.9
Browse files Browse the repository at this point in the history
  • Loading branch information
thomasahle committed Mar 6, 2024
1 parent baecf8c commit f26c9a0
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 4 deletions.
2 changes: 1 addition & 1 deletion dspy/__init__.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
import dsp
from dsp.modules.hf_client import ChatModuleClient, HFClientSGLang, HFClientVLLM, HFServerTGI

from .functional import *
from .predict import *
from .primitives import *
from .retrieve import *
from .signatures import *
from .functional import *

settings = dsp.settings

Expand Down
10 changes: 7 additions & 3 deletions dspy/functional/functional.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,7 +131,11 @@ def _prepare_signature(self) -> dspy.Signature:
schema = json.dumps(type_.model_json_schema())
else:
# Anything else we wrap in a pydantic object
if not (inspect.isclass(type_) and issubclass(type_, pydantic.BaseModel)):
if not (
inspect.isclass(type_)
and typing.get_origin(type_) not in (list, tuple) # To support Python 3.9
and issubclass(type_, pydantic.BaseModel)
):
type_ = pydantic.create_model("Output", value=(type_, ...), __base__=pydantic.BaseModel)
to_json = lambda x, type_=type_: type_(value=x).model_dump_json()
from_json = lambda x, type_=type_: type_.model_validate_json(x).value
Expand All @@ -152,8 +156,6 @@ def _prepare_signature(self) -> dspy.Signature:
format_ = lambda x: x if isinstance(x, str) else str(x)
if type_ in (List[str], list[str], Tuple[str], tuple[str]):
format_ = passages2text
elif inspect.isclass(type_) and issubclass(type_, pydantic.BaseModel):
format_ = lambda x: x if isinstance(x, str) else x.model_dump_json()
# Special formatting for lists of known types. Maybe the output fields sohuld have this too?
elif typing.get_origin(type_) in (List, list, Tuple, tuple):
(inner_type,) = typing.get_args(type_)
Expand All @@ -163,6 +165,8 @@ def _prepare_signature(self) -> dspy.Signature:
)
else:
format_ = lambda x: x if isinstance(x, str) else json.dumps(x)
elif inspect.isclass(type_) and issubclass(type_, pydantic.BaseModel):
format_ = lambda x: x if isinstance(x, str) else x.model_dump_json()
signature = signature.with_updated_fields(name, format=format_)

return signature
Expand Down

0 comments on commit f26c9a0

Please sign in to comment.