Skip to content

Commit

Permalink
Rewrote type parsing to be 3.9 compatible
Browse files Browse the repository at this point in the history
  • Loading branch information
thomasahle committed Mar 3, 2024
1 parent a021bd2 commit f223579
Showing 1 changed file with 29 additions and 20 deletions.
49 changes: 29 additions & 20 deletions dspy/signatures/signature.py
Original file line number Diff line number Diff line change
Expand Up @@ -285,29 +285,38 @@ def _parse_named_type_node(node, names=None) -> Any:
def _parse_type_node(node, names=None) -> Any:
"""Recursively parse an AST node representing a type annotation.
using structural pattern matching introduced in Python 3.10.
without using structural pattern matching introduced in Python 3.10.
"""
if names is None:
names = {}
match node:
case ast.Module(body=body):
if len(body) != 1:
raise ValueError(f"Code is not syntactically valid: {node}")
return _parse_type_node(body[0], names)
case ast.Expr(value=value):
return _parse_type_node(value, names)
case ast.Name(id=id):
if id in names:
return names[id]
for type_ in [int, str, float, bool, list, tuple, dict]:
if type_.__name__ == id:
return type_
case ast.Subscript(value=value, slice=slice):
base_type = _parse_type_node(value, names)
arg_type = _parse_type_node(slice, names)
return base_type[arg_type]
case ast.Tuple(elts=elts):
return tuple(_parse_type_node(elt, names) for elt in elts)

if isinstance(node, ast.Module):
body = node.body
if len(body) != 1:
raise ValueError(f"Code is not syntactically valid: {node}")
return _parse_type_node(body[0], names)

if isinstance(node, ast.Expr):
value = node.value
return _parse_type_node(value, names)

if isinstance(node, ast.Name):
id_ = node.id
if id_ in names:
return names[id_]
for type_ in [int, str, float, bool, list, tuple, dict]:
if type_.__name__ == id_:
return type_

elif isinstance(node, ast.Subscript):
base_type = _parse_type_node(node.value, names)
arg_type = _parse_type_node(node.slice, names)
return base_type[arg_type]

elif isinstance(node, ast.Tuple):
elts = node.elts
return tuple(_parse_type_node(elt, names) for elt in elts)

raise ValueError(f"Code is not syntactically valid: {node}")


Expand Down

0 comments on commit f223579

Please sign in to comment.