Skip to content

Commit

Permalink
Better support for string types
Browse files Browse the repository at this point in the history
  • Loading branch information
thomasahle committed Mar 18, 2024
1 parent a9ee8ed commit ba550d0
Show file tree
Hide file tree
Showing 3 changed files with 49 additions and 24 deletions.
10 changes: 5 additions & 5 deletions dspy/functional/functional.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,9 +53,9 @@ def __init__(self):
self.__dict__[name] = attr.copy()


def TypedChainOfThought(signature, max_retries=3) -> dspy.Module: # noqa: N802
def TypedChainOfThought(signature, instructions=None, *, max_retries=3) -> dspy.Module: # noqa: N802
"""Just like TypedPredictor, but adds a ChainOfThought OutputField."""
signature = ensure_signature(signature)
signature = ensure_signature(signature, instructions)
output_keys = ", ".join(signature.output_fields.keys())
return TypedPredictor(
signature.prepend(
Expand All @@ -70,7 +70,7 @@ def TypedChainOfThought(signature, max_retries=3) -> dspy.Module: # noqa: N802


class TypedPredictor(dspy.Module):
def __init__(self, signature, max_retries=3, wrap_json=False, explain_errors=False):
def __init__(self, signature, instructions=None, *, max_retries=3, wrap_json=False, explain_errors=False):
"""Like dspy.Predict, but enforces type annotations in the signature.
Args:
Expand All @@ -79,14 +79,14 @@ def __init__(self, signature, max_retries=3, wrap_json=False, explain_errors=Fal
wrap_json: If True, json objects in the input will be wrapped in ```json ... ```
"""
super().__init__()
self.signature = ensure_signature(signature)
self.signature = ensure_signature(signature, instructions)
self.predictor = dspy.Predict(signature)
self.max_retries = max_retries
self.wrap_json = wrap_json
self.explain_errors = explain_errors

def copy(self) -> "TypedPredictor":
return TypedPredictor(self.signature, self.max_retries, self.wrap_json)
return TypedPredictor(self.signature, max_retries=self.max_retries, wrap_json=self.wrap_json)

def __repr__(self):
"""Return a string representation of the TypedPredictor object."""
Expand Down
42 changes: 23 additions & 19 deletions dspy/signatures/signature.py
Original file line number Diff line number Diff line change
Expand Up @@ -206,11 +206,13 @@ class Signature(BaseModel, metaclass=SignatureMeta):
pass


def ensure_signature(signature: Union[str, Type[Signature]]) -> Signature:
def ensure_signature(signature: Union[str, Type[Signature]], instructions=None) -> Signature:
if signature is None:
return None
if isinstance(signature, str):
return Signature(signature)
return Signature(signature, instructions)
if instructions is not None:
raise ValueError("Don't specify instructions when initializing with a Signature")
return signature


Expand Down Expand Up @@ -277,27 +279,22 @@ def _parse_signature(signature: str) -> Tuple[Type, Field]:
if signature.count("->") != 1:
raise ValueError(f"Invalid signature format: '{signature}', must contain exactly one '->'.")

inputs_str, outputs_str = signature.split("->")

fields = {}
inputs_str, outputs_str = map(str.strip, signature.split("->"))
inputs = [v.strip() for v in inputs_str.split(",") if v.strip()]
outputs = [v.strip() for v in outputs_str.split(",") if v.strip()]
for name_type in inputs:
name, type_ = _parse_named_type_node(name_type)
for name, type_ in _parse_arg_string(inputs_str):
fields[name] = (type_, InputField())
for name_type in outputs:
name, type_ = _parse_named_type_node(name_type)
for name, type_ in _parse_arg_string(outputs_str):
fields[name] = (type_, OutputField())

return fields


def _parse_named_type_node(node, names=None) -> Any:
parts = node.split(":")
if len(parts) == 1:
return parts[0], str
name, type_str = parts
type_ = _parse_type_node(ast.parse(type_str), names)
return name, type_
def _parse_arg_string(string: str, names=None) -> Dict[str, str]:
args = ast.parse("def f(" + string + "): pass").body[0].args.args
names = [arg.arg for arg in args]
types = [str if arg.annotation is None else _parse_type_node(arg.annotation) for arg in args]
return zip(names, types)


def _parse_type_node(node, names=None) -> Any:
Expand All @@ -306,7 +303,7 @@ def _parse_type_node(node, names=None) -> Any:
without using structural pattern matching introduced in Python 3.10.
"""
if names is None:
names = {}
names = typing.__dict__

if isinstance(node, ast.Module):
body = node.body
Expand All @@ -325,16 +322,23 @@ def _parse_type_node(node, names=None) -> Any:
for type_ in [int, str, float, bool, list, tuple, dict]:
if type_.__name__ == id_:
return type_
raise ValueError(f"Unknown name: {id_}")

elif isinstance(node, ast.Subscript):
if isinstance(node, ast.Subscript):
base_type = _parse_type_node(node.value, names)
arg_type = _parse_type_node(node.slice, names)
return base_type[arg_type]

elif isinstance(node, ast.Tuple):
if isinstance(node, ast.Tuple):
elts = node.elts
return tuple(_parse_type_node(elt, names) for elt in elts)

if isinstance(node, ast.Call):
if node.func.id == "Field":
keys = [kw.arg for kw in node.keywords]
values = [kw.value.value for kw in node.keywords]
return Field(**dict(zip(keys, values)))

raise ValueError(f"Code is not syntactically valid: {node}")


Expand Down
21 changes: 21 additions & 0 deletions tests/functional/test_functional.py
Original file line number Diff line number Diff line change
Expand Up @@ -781,3 +781,24 @@ def _test_demos_missing_input():
Input: What is the capital of France?
Thoughts: My thoughts
Output: Paris""")


def test_conlist():
dspy.settings.configure(
lm=DummyLM(['{"value": []}', '{"value": [1]}', '{"value": [1, 2]}', '{"value": [1, 2, 3]}'])
)

@predictor
def make_numbers(input: str) -> Annotated[list[int], Field(min_items=2)]:
pass

assert make_numbers(input="What are the first two numbers?") == [1, 2]


def test_conlist2():
dspy.settings.configure(
lm=DummyLM(['{"value": []}', '{"value": [1]}', '{"value": [1, 2]}', '{"value": [1, 2, 3]}'])
)

make_numbers = TypedPredictor("input:str -> output:Annotated[List[int], Field(min_items=2)]")
assert make_numbers(input="What are the first two numbers?").output == [1, 2]

0 comments on commit ba550d0

Please sign in to comment.