Skip to content

Commit

Permalink
perf: make truncate 1.5x faster when every is just a single duration (and not an expression) (pola-rs#16666)
Browse files Browse the repository at this point in the history
  • Loading branch information
MarcoGorelli authored Jun 3, 2024
1 parent 534b655 commit 57a5046
Show file tree
Hide file tree
Showing 2 changed files with 129 additions and 41 deletions.
90 changes: 54 additions & 36 deletions crates/polars-time/src/truncate.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
use arrow::legacy::time_zone::Tz;
use arrow::temporal_conversions::{MILLISECONDS, SECONDS_IN_DAY};
use arrow::temporal_conversions::MILLISECONDS_IN_DAY;
use polars_core::prelude::arity::broadcast_try_binary_elementwise;
use polars_core::prelude::*;
use polars_utils::cache::FastFixedCache;
Expand All @@ -16,7 +16,6 @@ impl PolarsTruncate for DatetimeChunked {
fn truncate(&self, tz: Option<&Tz>, every: &StringChunked, offset: &str) -> PolarsResult<Self> {
let offset: Duration = Duration::parse(offset);
let time_zone = self.time_zone();
let mut duration_cache_opt: Option<FastFixedCache<String, Duration>> = None;

// Let's check if we can use a fastpath...
if every.len() == 1 {
Expand All @@ -42,34 +41,35 @@ impl PolarsTruncate for DatetimeChunked {
})
.into_datetime(self.time_unit(), time_zone.clone()));
} else {
// A sqrt(n) cache is not too small, not too large.
duration_cache_opt =
Some(FastFixedCache::new((every.len() as f64).sqrt() as usize));
duration_cache_opt
.as_mut()
.map(|cache| *cache.insert(every.to_string(), every_parsed));
let w = Window::new(every_parsed, every_parsed, offset);
let out = match self.time_unit() {
TimeUnit::Milliseconds => {
self.try_apply_nonnull_values_generic(|t| w.truncate_ms(t, tz))
},
TimeUnit::Microseconds => {
self.try_apply_nonnull_values_generic(|t| w.truncate_us(t, tz))
},
TimeUnit::Nanoseconds => {
self.try_apply_nonnull_values_generic(|t| w.truncate_ns(t, tz))
},
};
return Ok(out?.into_datetime(self.time_unit(), self.time_zone().clone()));
}
} else {
return Ok(Int64Chunked::full_null(self.name(), self.len())
.into_datetime(self.time_unit(), self.time_zone().clone()));
}
}
let mut duration_cache = match duration_cache_opt {
Some(cache) => cache,
None => FastFixedCache::new((every.len() as f64).sqrt() as usize),
};

// A sqrt(n) cache is not too small, not too large.
let mut duration_cache = FastFixedCache::new((every.len() as f64).sqrt() as usize);

let func = match self.time_unit() {
TimeUnit::Nanoseconds => Window::truncate_ns,
TimeUnit::Microseconds => Window::truncate_us,
TimeUnit::Milliseconds => Window::truncate_ms,
};

// TODO: optimize the code below, so it does the following:
// - convert to naive
// - truncate all naively
// - localize, preserving the fold of the original datetime.
// The last step is the non-trivial one. But it should be worth it,
// and faster than the current approach of truncating everything
// as tz-aware.

let out = broadcast_try_binary_elementwise(self, every, |opt_timestamp, opt_every| match (
opt_timestamp,
opt_every,
Expand Down Expand Up @@ -99,26 +99,44 @@ impl PolarsTruncate for DateChunked {
offset: &str,
) -> PolarsResult<Self> {
let offset = Duration::parse(offset);
// A sqrt(n) cache is not too small, not too large.
let mut duration_cache = FastFixedCache::new((every.len() as f64).sqrt() as usize);
let out = broadcast_try_binary_elementwise(&self.0, every, |opt_t, opt_every| {
match (opt_t, opt_every) {
(Some(t), Some(every)) => {
const MSECS_IN_DAY: i64 = MILLISECONDS * SECONDS_IN_DAY;
let every =
*duration_cache.get_or_insert_with(every, |every| Duration::parse(every));
let out = match every.len() {
1 => {
if let Some(every) = every.get(0) {
let every = Duration::parse(every);
if every.negative {
polars_bail!(ComputeError: "cannot truncate a Date to a negative duration")
}

let w = Window::new(every, every, offset);
Ok(Some(
(w.truncate_ms(MSECS_IN_DAY * t as i64, None)? / MSECS_IN_DAY) as i32,
))
},
_ => Ok(None),
}
});
self.try_apply_nonnull_values_generic(|t| {
Ok((w.truncate_ms(MILLISECONDS_IN_DAY * t as i64, None)?
/ MILLISECONDS_IN_DAY) as i32)
})
} else {
Ok(Int32Chunked::full_null(self.name(), self.len()))
}
},
_ => broadcast_try_binary_elementwise(self, every, |opt_t, opt_every| {
// A sqrt(n) cache is not too small, not too large.
let mut duration_cache = FastFixedCache::new((every.len() as f64).sqrt() as usize);
match (opt_t, opt_every) {
(Some(t), Some(every)) => {
let every = *duration_cache
.get_or_insert_with(every, |every| Duration::parse(every));

if every.negative {
polars_bail!(ComputeError: "cannot truncate a Date to a negative duration")
}

let w = Window::new(every, every, offset);
Ok(Some(
(w.truncate_ms(MILLISECONDS_IN_DAY * t as i64, None)?
/ MILLISECONDS_IN_DAY) as i32,
))
},
_ => Ok(None),
}
}),
};
Ok(out?.into_date())
}
}
Original file line number Diff line number Diff line change
@@ -1,24 +1,94 @@
import datetime as dt
from __future__ import annotations

from datetime import date, datetime
from typing import TYPE_CHECKING

import hypothesis.strategies as st
import pytest
from hypothesis import given

import polars as pl
from polars.testing import assert_series_equal

if TYPE_CHECKING:
from polars.type_aliases import TimeUnit


@given(
    value=st.datetimes(
        min_value=datetime(1000, 1, 1),
        max_value=datetime(3000, 1, 1),
    ),
    n=st.integers(min_value=1, max_value=100),
)
def test_truncate_monthly(value: date, n: int) -> None:
    """Check ``dt.truncate(f"{n}mo")`` against a hand-computed month boundary.

    Hypothesis supplies ``value`` (a datetime between years 1000 and 3000)
    and ``n`` (a month count from 1 to 100). The truncated result must equal
    the first day of the month obtained by rounding the zero-based month
    index ``year * 12 + month - 1`` down to a multiple of ``n``.
    """
    result = pl.Series([value]).dt.truncate(f"{n}mo").item()
    # manual calculation
    # Zero-based count of months, so that January of year 0 maps to 0.
    total = value.year * 12 + value.month - 1
    # Round down to the nearest multiple of n months.
    remainder = total % n
    total -= remainder
    # Convert the month index back to a (year, 1-based month) pair.
    year, month = (total // 12), ((total % 12) + 1)
    expected = datetime(year, month, 1)
    assert result == expected


def test_truncate_date() -> None:
    """Truncate Date values to "1mo" for each column/literal/null pairing.

    Covers a column against a column, a column against a single string,
    a column against a null literal, a date literal against a column, and
    a null Date literal against a column. The expected series show a null
    on either side producing a null output row.
    """
    # Every scenario runs against the same two-column frame.
    frame = pl.DataFrame(
        {"a": [date(2020, 1, 1), None, date(2020, 1, 3)], "b": [None, "1mo", "1mo"]}
    )
    month_start = date(2020, 1, 1)
    all_null = pl.Series("a", [None, None, None], dtype=pl.Date)

    # n vs n
    truncated = frame.select(pl.col("a").dt.truncate(pl.col("b"))).get_column("a")
    assert_series_equal(truncated, pl.Series("a", [None, None, month_start]))

    # n vs 1
    truncated = frame.select(pl.col("a").dt.truncate("1mo")).get_column("a")
    assert_series_equal(truncated, pl.Series("a", [month_start, None, month_start]))

    # n vs missing
    null_every = pl.lit(None, dtype=pl.String)
    truncated = frame.select(pl.col("a").dt.truncate(null_every)).get_column("a")
    assert_series_equal(truncated, all_null)

    # 1 vs n
    literal_date = pl.date(2020, 1, 1)
    truncated = frame.select(a=literal_date.dt.truncate(pl.col("b"))).get_column("a")
    assert_series_equal(truncated, pl.Series("a", [None, month_start, month_start]))

    # missing vs n
    null_date = pl.lit(None, dtype=pl.Date)
    truncated = frame.select(a=null_date.dt.truncate(pl.col("b"))).get_column("a")
    assert_series_equal(truncated, all_null)


@pytest.mark.parametrize("time_unit", ["ms", "us", "ns"])
def test_truncate_datetime_simple(time_unit: TimeUnit) -> None:
    """Truncate a one-element Datetime series with a plain string ``every``.

    Runs once per time unit ("ms", "us", "ns") and checks that
    2020-01-02 06:00 truncates to the start of its month for "1mo" and to
    the start of its day for "1d".
    """
    series = pl.Series([datetime(2020, 1, 2, 6)], dtype=pl.Datetime(time_unit))
    # (every, expected truncation), checked in the same order as before.
    cases = [
        ("1mo", datetime(2020, 1, 1)),
        ("1d", datetime(2020, 1, 2)),
    ]
    for every, expected in cases:
        assert series.dt.truncate(every).item() == expected


@pytest.mark.parametrize("time_unit", ["ms", "us", "ns"])
def test_truncate_datetime_w_expression(time_unit: TimeUnit) -> None:
    """Truncate a Datetime column where ``every`` comes from another column.

    Runs once per time unit ("ms", "us", "ns"). Row 0 is truncated by "1mo"
    and row 1 by "1d", so each row must land on its own boundary.
    """
    frame = pl.DataFrame(
        {"a": [datetime(2020, 1, 2, 6), datetime(2020, 1, 3, 7)], "b": ["1mo", "1d"]},
        schema_overrides={"a": pl.Datetime(time_unit)},
    )
    truncated = frame.select(pl.col("a").dt.truncate(pl.col("b")))["a"]
    # Expected value per row, in row order.
    wanted = (datetime(2020, 1, 1), datetime(2020, 1, 3))
    for row, expected in enumerate(wanted):
        assert truncated[row] == expected

0 comments on commit 57a5046

Please sign in to comment.