Skip to content

Commit

Permalink
fix: Fix scatter for null values (pola-rs#13578)
Browse files Browse the repository at this point in the history
  • Loading branch information
stinodego authored Jan 10, 2024
1 parent f4401fb commit 2b43fc1
Show file tree
Hide file tree
Showing 3 changed files with 26 additions and 2 deletions.
14 changes: 13 additions & 1 deletion crates/polars-core/src/chunked_array/ops/chunkops.rs
Original file line number Diff line number Diff line change
Expand Up @@ -60,12 +60,24 @@ impl<T: PolarsDataType> ChunkedArray<T> {
self.length as usize
}

/// Count the null values.
/// Return the number of null values in the ChunkedArray.
#[inline]
pub fn null_count(&self) -> usize {
self.null_count as usize
}

/// Set the null count directly.
///
/// This can be useful after mutably adjusting the validity of the
/// underlying arrays.
///
/// # Safety
/// The new null count must match the total null count of the underlying
/// arrays.
pub unsafe fn set_null_count(&mut self, null_count: IdxSize) {
self.null_count = null_count;
}

/// Check if ChunkedArray is empty.
pub fn is_empty(&self) -> bool {
self.len() == 0
Expand Down
5 changes: 5 additions & 0 deletions crates/polars-ops/src/chunked_array/scatter.rs
Original file line number Diff line number Diff line change
Expand Up @@ -125,6 +125,11 @@ where
arr.set_values(new_values.into());
},
};

// The null count may have changed - make sure to update the ChunkedArray
let new_null_count = arr.null_count();
unsafe { ca.set_null_count(new_null_count.try_into().unwrap()) };

Ok(ca.into_series())
}
}
Expand Down
9 changes: 8 additions & 1 deletion py-polars/tests/unit/series/test_scatter.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from datetime import datetime
from datetime import date, datetime

import numpy as np
import pytest
Expand Down Expand Up @@ -66,3 +66,10 @@ def test_scatter_datetime() -> None:
result = s.scatter(0, datetime(2022, 2, 2))
expected = pl.Series("dt", [datetime(2022, 2, 2), datetime(2024, 1, 31)])
assert_series_equal(result, expected)


def test_scatter_logical_all_null() -> None:
s = pl.Series("dt", [None, None], dtype=pl.Date)
result = s.scatter(0, date(2022, 2, 2))
expected = pl.Series("dt", [date(2022, 2, 2), None])
assert_series_equal(result, expected)

0 comments on commit 2b43fc1

Please sign in to comment.