Skip to content

Commit

Permalink
feat(python): convenience support for parsing a list of SQL strings w…
Browse files Browse the repository at this point in the history
…ith `sql_expr` (pola-rs#9881)
  • Loading branch information
alexander-beedie authored Jul 14, 2023
1 parent cde0be2 commit 147944c
Show file tree
Hide file tree
Showing 2 changed files with 41 additions and 9 deletions.
38 changes: 34 additions & 4 deletions py-polars/polars/functions/lazy.py
Original file line number Diff line number Diff line change
Expand Up @@ -2093,17 +2093,29 @@ def rolling_corr(
)


def sql_expr(sql: str) -> Expr:
@overload
def sql_expr(sql: str) -> Expr: # type: ignore[misc]
...


@overload
def sql_expr(sql: Sequence[str]) -> list[Expr]:
...


def sql_expr(sql: str | Sequence[str]) -> Expr | list[Expr]:
"""
Parse a SQL expression to a polars expression.
Parse one or more SQL expressions to polars expression(s).
Parameters
----------
sql
SQL expression
One or more SQL expressions.
Examples
--------
Parse a single SQL expression:
>>> df = pl.DataFrame({"a": [2, 1]})
>>> expr = pl.sql_expr("MAX(a)")
>>> df.select(expr)
Expand All @@ -2115,5 +2127,23 @@ def sql_expr(sql: str) -> Expr:
╞═════╡
│ 2 │
└─────┘
Parse multiple SQL expressions:
>>> df.with_columns(
... *pl.sql_expr(["POWER(a,a) AS a_a", "CAST(a AS TEXT) AS a_txt"]),
... )
shape: (2, 3)
┌─────┬─────┬───────┐
│ a ┆ a_a ┆ a_txt │
│ --- ┆ --- ┆ --- │
│ i64 ┆ f64 ┆ str │
╞═════╪═════╪═══════╡
│ 2 ┆ 4.0 ┆ 2 │
│ 1 ┆ 1.0 ┆ 1 │
└─────┴─────┴───────┘
"""
return wrap_expr(plr.sql_expr(sql))
if isinstance(sql, str):
return wrap_expr(plr.sql_expr(sql))
else:
return [wrap_expr(plr.sql_expr(q)) for q in sql]
12 changes: 7 additions & 5 deletions py-polars/tests/unit/test_sql.py
Original file line number Diff line number Diff line change
Expand Up @@ -771,15 +771,17 @@ def test_register_context() -> None:

def test_sql_expr() -> None:
df = pl.DataFrame({"a": [1, 2, 3], "b": ["xyz", "abcde", None]})
sql_exprs = (
pl.sql_expr("MIN(a)"),
pl.sql_expr("POWER(a,a) AS aa"),
pl.sql_expr("SUBSTR(b,1,2) AS b2"),
sql_exprs = pl.sql_expr(
[
"MIN(a)",
"POWER(a,a) AS aa",
"SUBSTR(b,1,2) AS b2",
]
)
expected = pl.DataFrame(
{"a": [1, 1, 1], "aa": [1, 4, 27], "b2": ["yz", "bc", None]}
)
assert df.select(sql_exprs).frame_equal(expected)
assert df.select(*sql_exprs).frame_equal(expected)

# expect expressions that can't reasonably be parsed as expressions to raise
# (for example: those that explicitly reference tables and/or use wildcards)
Expand Down

0 comments on commit 147944c

Please sign in to comment.