diff --git a/dspy/functional/functional.py b/dspy/functional/functional.py index a45fc4aeb..15725aacf 100644 --- a/dspy/functional/functional.py +++ b/dspy/functional/functional.py @@ -53,9 +53,9 @@ def __init__(self): self.__dict__[name] = attr.copy() -def TypedChainOfThought(signature, max_retries=3) -> dspy.Module: # noqa: N802 +def TypedChainOfThought(signature, instructions=None, *, max_retries=3) -> dspy.Module: # noqa: N802 """Just like TypedPredictor, but adds a ChainOfThought OutputField.""" - signature = ensure_signature(signature) + signature = ensure_signature(signature, instructions) output_keys = ", ".join(signature.output_fields.keys()) return TypedPredictor( signature.prepend( @@ -70,7 +70,7 @@ def TypedChainOfThought(signature, max_retries=3) -> dspy.Module: # noqa: N802 class TypedPredictor(dspy.Module): - def __init__(self, signature, max_retries=3, wrap_json=False, explain_errors=False): + def __init__(self, signature, instructions=None, *, max_retries=3, wrap_json=False, explain_errors=False): """Like dspy.Predict, but enforces type annotations in the signature. Args: @@ -79,14 +79,14 @@ def __init__(self, signature, max_retries=3, wrap_json=False, explain_errors=Fal wrap_json: If True, json objects in the input will be wrapped in ```json ... ``` """ super().__init__() - self.signature = ensure_signature(signature) + self.signature = ensure_signature(signature, instructions) self.predictor = dspy.Predict(signature) self.max_retries = max_retries self.wrap_json = wrap_json self.explain_errors = explain_errors def copy(self) -> "TypedPredictor": - return TypedPredictor(self.signature, self.max_retries, self.wrap_json) + return TypedPredictor(self.signature, max_retries=self.max_retries, wrap_json=self.wrap_json) def __repr__(self): """Return a string representation of the TypedPredictor object.""" diff --git a/dspy/signatures/signature.py b/dspy/signatures/signature.py index 74fa3ee73..d10f71718 100644 --- a/dspy/signatures/signature.py +++ b/dspy/signatures/signature.py @@ -206,11 +206,13 @@ class Signature(BaseModel, metaclass=SignatureMeta): pass -def ensure_signature(signature: Union[str, Type[Signature]]) -> Signature: +def ensure_signature(signature: Union[str, Type[Signature]], instructions=None) -> Signature: if signature is None: return None if isinstance(signature, str): - return Signature(signature) + return Signature(signature, instructions) + if instructions is not None: + raise ValueError("Don't specify instructions when initializing with a Signature") return signature @@ -277,27 +279,22 @@ def _parse_signature(signature: str) -> Tuple[Type, Field]: if signature.count("->") != 1: raise ValueError(f"Invalid signature format: '{signature}', must contain exactly one '->'.") + inputs_str, outputs_str = signature.split("->") + fields = {} - inputs_str, outputs_str = map(str.strip, signature.split("->")) - inputs = [v.strip() for v in inputs_str.split(",") if v.strip()] - outputs = [v.strip() for v in outputs_str.split(",") if v.strip()] - for name_type in inputs: - name, type_ = _parse_named_type_node(name_type) + for name, type_ in _parse_arg_string(inputs_str): fields[name] = (type_, InputField()) - for name_type in outputs: - name, type_ = _parse_named_type_node(name_type) + for name, type_ in _parse_arg_string(outputs_str): fields[name] = (type_, OutputField()) return fields -def _parse_named_type_node(node, names=None) -> Any: - parts = node.split(":") - if len(parts) == 1: - return parts[0], str - name, type_str = parts - type_ = _parse_type_node(ast.parse(type_str), names) - return name, type_ +def _parse_arg_string(string: str, names=None) -> Dict[str, str]: + args = ast.parse("def f(" + string + "): pass").body[0].args.args + names = [arg.arg for arg in args] + types = [str if arg.annotation is None else _parse_type_node(arg.annotation) for arg in args] + return zip(names, types) def _parse_type_node(node, names=None) -> Any: @@ -306,7 +303,7 @@ def _parse_type_node(node, names=None) -> Any: without using structural pattern matching introduced in Python 3.10. """ if names is None: - names = {} + names = typing.__dict__ if isinstance(node, ast.Module): body = node.body @@ -325,16 +322,23 @@ def _parse_type_node(node, names=None) -> Any: for type_ in [int, str, float, bool, list, tuple, dict]: if type_.__name__ == id_: return type_ + raise ValueError(f"Unknown name: {id_}") - elif isinstance(node, ast.Subscript): + if isinstance(node, ast.Subscript): base_type = _parse_type_node(node.value, names) arg_type = _parse_type_node(node.slice, names) return base_type[arg_type] - elif isinstance(node, ast.Tuple): + if isinstance(node, ast.Tuple): elts = node.elts return tuple(_parse_type_node(elt, names) for elt in elts) + if isinstance(node, ast.Call): + if node.func.id == "Field": + keys = [kw.arg for kw in node.keywords] + values = [kw.value.value for kw in node.keywords] + return Field(**dict(zip(keys, values))) + raise ValueError(f"Code is not syntactically valid: {node}") diff --git a/tests/functional/test_functional.py b/tests/functional/test_functional.py index 2bedfa356..c106a13c4 100644 --- a/tests/functional/test_functional.py +++ b/tests/functional/test_functional.py @@ -781,3 +781,24 @@ def _test_demos_missing_input(): Input: What is the capital of France? Thoughts: My thoughts Output: Paris""") + + +def test_conlist(): + dspy.settings.configure( + lm=DummyLM(['{"value": []}', '{"value": [1]}', '{"value": [1, 2]}', '{"value": [1, 2, 3]}']) + ) + + @predictor + def make_numbers(input: str) -> Annotated[list[int], Field(min_items=2)]: + pass + + assert make_numbers(input="What are the first two numbers?") == [1, 2] + + +def test_conlist2(): + dspy.settings.configure( + lm=DummyLM(['{"value": []}', '{"value": [1]}', '{"value": [1, 2]}', '{"value": [1, 2, 3]}']) + ) + + make_numbers = TypedPredictor("input:str -> output:Annotated[List[int], Field(min_items=2)]") + assert make_numbers(input="What are the first two numbers?").output == [1, 2]