Skip to content

Commit

Permalink
fix: Fix cum_count with regards to start value / null values (pola-…
Browse files Browse the repository at this point in the history
  • Loading branch information
stinodego authored Jan 10, 2024
1 parent 43a01a5 commit 683d34d
Show file tree
Hide file tree
Showing 7 changed files with 161 additions and 56 deletions.
49 changes: 39 additions & 10 deletions crates/polars-ops/src/series/ops/cum_agg.rs
Original file line number Diff line number Diff line change
Expand Up @@ -216,15 +216,44 @@ pub fn cum_max(s: &Series, reverse: bool) -> PolarsResult<Series> {
}

pub fn cum_count(s: &Series, reverse: bool) -> PolarsResult<Series> {
if reverse {
let ca: NoNull<UInt32Chunked> = (0u32..s.len() as u32).rev().collect();
let mut ca = ca.into_inner();
ca.rename(s.name());
Ok(ca.into_series())
} else {
let ca: NoNull<UInt32Chunked> = (0u32..s.len() as u32).collect();
let mut ca = ca.into_inner();
ca.rename(s.name());
Ok(ca.into_series())
// Fast paths for no nulls
if s.null_count() == 0 {
let out = cum_count_no_nulls(s.name(), s.len(), reverse);
return Ok(out);
}

let ca = s.is_not_null();
let out: IdxCa = if reverse {
let mut count = (s.len() - s.null_count()) as IdxSize;
let mut prev = false;
ca.apply_values_generic(|v: bool| {
if prev {
count -= 1;
}
prev = v;
count
})
} else {
let mut count = 0 as IdxSize;
ca.apply_values_generic(|v: bool| {
if v {
count += 1;
}
count
})
};
Ok(out.into())
}

/// Cumulative count for a column that contains no nulls.
///
/// Every row counts, so the result is simply the run `1, 2, ..., len`
/// (or `len, ..., 2, 1` when `reverse` is set), returned as a series
/// named `name`.
fn cum_count_no_nulls(name: &str, len: usize, reverse: bool) -> Series {
    let first: IdxSize = 1;
    // Exclusive upper bound, so the last emitted count equals `len`.
    let past_last = len as IdxSize + 1;
    let counts = first..past_last;

    let no_null: NoNull<IdxCa> = if reverse {
        counts.rev().collect()
    } else {
        counts.collect()
    };

    let mut out = no_null.into_inner();
    out.rename(name);
    out.into_series()
}
26 changes: 12 additions & 14 deletions py-polars/polars/expr/expr.py
Original file line number Diff line number Diff line change
Expand Up @@ -1650,9 +1650,7 @@ def cum_max(self, *, reverse: bool = False) -> Self:

def cum_count(self, *, reverse: bool = False) -> Self:
"""
Get an array with the cumulative count computed at every element.
Counting from 0 to len
Return the cumulative count of the non-null values in the column.
Parameters
----------
Expand All @@ -1661,22 +1659,22 @@ def cum_count(self, *, reverse: bool = False) -> Self:
Examples
--------
>>> df = pl.DataFrame({"a": [1, 2, 3, 4]})
>>> df = pl.DataFrame({"a": ["x", "k", None, "d"]})
>>> df.with_columns(
... pl.col("a").cum_count().alias("cum_count"),
... pl.col("a").cum_count(reverse=True).alias("cum_count_reverse"),
... )
shape: (4, 3)
┌─────┬───────────┬───────────────────┐
│ a   ┆ cum_count ┆ cum_count_reverse │
│ --- ┆ ---       ┆ ---               │
│ i64 ┆ u32       ┆ u32               │
╞═════╪═══════════╪═══════════════════╡
│ 1   ┆ 0         ┆ 3                 │
│ 2   ┆ 1         ┆ 2                 │
│ 3   ┆ 2         ┆ 1                 │
│ 4   ┆ 3         ┆ 0                 │
└─────┴───────────┴───────────────────┘
┌──────┬───────────┬───────────────────┐
│ a    ┆ cum_count ┆ cum_count_reverse │
│ ---  ┆ ---       ┆ ---               │
│ str  ┆ u32       ┆ u32               │
╞══════╪═══════════╪═══════════════════╡
│ x    ┆ 1         ┆ 3                 │
│ k    ┆ 2         ┆ 2                 │
│ null ┆ 2         ┆ 1                 │
│ d    ┆ 3         ┆ 1                 │
└──────┴───────────┴───────────────────┘
"""
return self._from_pyexpr(self._pyexpr.cum_count(reverse))

Expand Down
20 changes: 17 additions & 3 deletions py-polars/polars/functions/lazy.py
Original file line number Diff line number Diff line change
Expand Up @@ -181,7 +181,7 @@ def count(*columns: str) -> Expr:

def cum_count(*columns: str, reverse: bool = False) -> Expr:
"""
Return the cumulative count of the values in the column or of the context.
Return the cumulative count of the non-null values in the column or of the context.
If no arguments are passed, returns the cumulative count of a context.
Rows containing null values count towards the result.
Expand Down Expand Up @@ -213,7 +213,7 @@ def cum_count(*columns: str, reverse: bool = False) -> Expr:
│ 3 │
└───────────┘
Return the cumulative count of values in a column.
Return the cumulative count of non-null values in a column.
>>> df.select(pl.cum_count("a"))
shape: (3, 1)
Expand All @@ -222,10 +222,24 @@ def cum_count(*columns: str, reverse: bool = False) -> Expr:
│ --- │
│ u32 │
╞═════╡
│ 0 │
│ 1 │
│ 2 │
│ 2 │
└─────┘
Add row numbers to a DataFrame.
>>> df.select(pl.cum_count().alias("row_number"), pl.all())
shape: (3, 3)
┌────────────┬──────┬──────┐
│ row_number ┆ a ┆ b │
│ --- ┆ --- ┆ --- │
│ u32 ┆ i64 ┆ i64 │
╞════════════╪══════╪══════╡
│ 1 ┆ 1 ┆ 3 │
│ 2 ┆ 2 ┆ null │
│ 3 ┆ null ┆ null │
└────────────┴──────┴──────┘
"""
if not columns:
return wrap_expr(plr.cum_count(reverse=reverse))
Expand Down
23 changes: 23 additions & 0 deletions py-polars/polars/series/series.py
Original file line number Diff line number Diff line change
Expand Up @@ -2976,6 +2976,29 @@ def cum_sum(self, *, reverse: bool = False) -> Series:
]
"""

def cum_count(self, *, reverse: bool = False) -> Self:
    """
    Return the cumulative count of the non-null values in the column.

    A null entry does not increase the count, so the running total repeats
    at that position (see the third element in the example below).

    Parameters
    ----------
    reverse
        Reverse the operation.

    Examples
    --------
    >>> s = pl.Series(["x", "k", None, "d"])
    >>> s.cum_count()
    shape: (4,)
    Series: '' [u32]
    [
        1
        2
        2
        3
    ]
    """

def slice(self, offset: int, length: int | None = None) -> Series:
"""
Get a slice of this Series.
Expand Down
2 changes: 1 addition & 1 deletion py-polars/tests/benchmark/test_release.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,7 @@ def test_windows_not_cached() -> None:
)
.lazy()
.filter(
(pl.col("key").cum_count().over("key") == 0)
(pl.col("key").cum_count().over("key") == 1)
| (pl.col("val").shift(1).over("key").is_not_null())
| (pl.col("val") != pl.col("val").shift(1).over("key"))
)
Expand Down
74 changes: 69 additions & 5 deletions py-polars/tests/unit/functions/test_cum_count.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import pytest

import polars as pl
from polars.testing import assert_frame_equal
from polars.testing import assert_frame_equal, assert_series_equal


@pytest.mark.parametrize(("reverse", "output"), [(False, [1, 2, 3]), (True, [3, 2, 1])])
Expand All @@ -17,17 +17,81 @@ def test_cum_count_no_args(reverse: bool, output: list[int]) -> None:
def test_cum_count_single_arg() -> None:
df = pl.DataFrame({"a": [5, 5, None]})
result = df.select(pl.cum_count("a"))
expected = pl.Series("a", [0, 1, 2], dtype=pl.UInt32).to_frame()
expected = pl.Series("a", [1, 2, 2], dtype=pl.UInt32).to_frame()
assert_frame_equal(result, expected)


def test_cum_count_multi_arg() -> None:
df = pl.DataFrame({"a": [5, 5, None], "b": [5, None, None], "c": [1, 2, 3]})
result = df.select(pl.cum_count("a", "b"))
df = pl.DataFrame(
{
"a": [5, 5, 5],
"b": [None, 5, 5],
"c": [5, None, 5],
"d": [5, 5, None],
"e": [None, None, None],
}
)
result = df.select(pl.cum_count("a", "b", "c", "d", "e"))
expected = pl.DataFrame(
[
pl.Series("a", [0, 1, 2], dtype=pl.UInt32),
pl.Series("a", [1, 2, 3], dtype=pl.UInt32),
pl.Series("b", [0, 1, 2], dtype=pl.UInt32),
pl.Series("c", [1, 1, 2], dtype=pl.UInt32),
pl.Series("d", [1, 2, 2], dtype=pl.UInt32),
pl.Series("e", [0, 0, 0], dtype=pl.UInt32),
]
)
assert_frame_equal(result, expected)


def test_cum_count_multi_arg_reverse() -> None:
    """Reverse cumulative counts are computed per column and skip nulls."""
    columns = {
        "a": [5, 5, 5],
        "b": [None, 5, 5],
        "c": [5, None, 5],
        "d": [5, 5, None],
        "e": [None, None, None],
    }
    # Non-null values remaining from each row to the end of the column.
    reverse_counts = {
        "a": [3, 2, 1],
        "b": [2, 2, 1],
        "c": [2, 1, 1],
        "d": [2, 1, 0],
        "e": [0, 0, 0],
    }

    df = pl.DataFrame(columns)
    result = df.select(pl.cum_count(*columns, reverse=True))

    expected = pl.DataFrame(
        [
            pl.Series(name, counts, dtype=pl.UInt32)
            for name, counts in reverse_counts.items()
        ]
    )
    assert_frame_equal(result, expected)


def test_cum_count() -> None:
    """Inside a group-by aggregation the count restarts at 1 for each group."""
    rows = [[key] for key in "aaabba"]
    df = pl.DataFrame(rows, schema=["A"])

    foo_expr = pl.col("A").cum_count().alias("foo")
    out = df.group_by("A", maintain_order=True).agg(foo_expr)

    foo = out["foo"]
    assert foo[0].to_list() == [1, 2, 3, 4]
    assert foo[1].to_list() == [1, 2]


def test_cumcount_deprecated() -> None:
    """The `cumcount` spelling must raise a deprecation warning and still count from 1."""
    rows = [[key] for key in "aaabba"]
    df = pl.DataFrame(rows, schema=["A"])

    # The whole query stays inside the context manager, as the warning is
    # expected to be raised somewhere while it is built and executed.
    with pytest.deprecated_call():
        foo_expr = pl.col("A").cumcount().alias("foo")
        out = df.group_by("A", maintain_order=True).agg(foo_expr)

    foo = out["foo"]
    assert foo[0].to_list() == [1, 2, 3, 4]
    assert foo[1].to_list() == [1, 2]


def test_series_cum_count() -> None:
    """`Series.cum_count` keeps the running total unchanged on a null entry."""
    values = ["x", "k", None, "d"]
    counts = pl.Series(values).cum_count()
    assert_series_equal(counts, pl.Series([1, 2, 2, 3], dtype=pl.UInt32))
23 changes: 0 additions & 23 deletions py-polars/tests/unit/test_exprs.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,29 +61,6 @@ def test_prefix(fruits_cars: pl.DataFrame) -> None:
assert out.columns == ["reverse_A", "reverse_fruits", "reverse_B", "reverse_cars"]


def test_cum_count() -> None:
df = pl.DataFrame([["a"], ["a"], ["a"], ["b"], ["b"], ["a"]], schema=["A"])

out = df.group_by("A", maintain_order=True).agg(
pl.col("A").cum_count().alias("foo")
)

assert out["foo"][0].to_list() == [0, 1, 2, 3]
assert out["foo"][1].to_list() == [0, 1]


def test_cumcount_deprecated() -> None:
df = pl.DataFrame([["a"], ["a"], ["a"], ["b"], ["b"], ["a"]], schema=["A"])

with pytest.deprecated_call():
out = df.group_by("A", maintain_order=True).agg(
pl.col("A").cumcount().alias("foo")
)

assert out["foo"][0].to_list() == [0, 1, 2, 3]
assert out["foo"][1].to_list() == [0, 1]


def test_filter_where() -> None:
df = pl.DataFrame({"a": [1, 2, 3, 1, 2, 3], "b": [4, 5, 6, 7, 8, 9]})
result_filter = df.group_by("a", maintain_order=True).agg(
Expand Down

0 comments on commit 683d34d

Please sign in to comment.