Skip to content

Commit

Permalink
Fix TIMESTAMP / DATE scalars, add support for DATE column casting (dask-contrib#343)
Browse files Browse the repository at this point in the history

* Cast DATEs to pd.Timestamps instead of datetime.datetimes

* Add datetime casting / filtering tests

* Add explicit support for DATE casting

* We need to return None for cast_column_to_type

* Use strftime in place of date attribute

* Use floor('D') instead of strftime for date truncation
  • Loading branch information
charlesbluca authored Feb 14, 2022
1 parent 34fee74 commit f601325
Show file tree
Hide file tree
Showing 5 changed files with 125 additions and 16 deletions.
15 changes: 8 additions & 7 deletions dask_sql/mappings.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import logging
from datetime import datetime, timedelta, timezone
from datetime import timedelta
from typing import Any

import dask.array as da
Expand Down Expand Up @@ -77,7 +77,9 @@
"VARCHAR": pd.StringDtype(),
"CHAR": pd.StringDtype(),
"STRING": pd.StringDtype(), # Although not in the standard, makes compatibility easier
"DATE": np.dtype("<M8[ns]"),
"DATE": np.dtype(
"<M8[ns]"
), # TODO: ideally this would be np.dtype("<M8[D]") but that doesn't work for Pandas
"TIMESTAMP": np.dtype("<M8[ns]"),
"NULL": type(None),
}
Expand Down Expand Up @@ -160,12 +162,11 @@ def sql_to_python_value(sql_type: str, literal_value: Any) -> Any:
tz = literal_value.getTimeZone().getID()
assert str(tz) == "UTC", "The code can currently only handle UTC timezones"

dt = datetime.fromtimestamp(
int(literal_value.getTimeInMillis()) / 1000, timezone.utc
)

return dt
dt = np.datetime64(literal_value.getTimeInMillis(), "ms")

if sql_type == "DATE":
return dt.astype("<M8[D]")
return dt.astype("<M8[ns]")
elif sql_type.startswith("DECIMAL("):
# We use np.float64 always, even though we might
# be able to use a smaller type
Expand Down
16 changes: 11 additions & 5 deletions dask_sql/physical/rex/core/call.py
Original file line number Diff line number Diff line change
Expand Up @@ -224,14 +224,20 @@ def cast(self, operand, rex=None) -> SeriesOrScalar:
return operand

output_type = str(rex.getType())
output_type = sql_to_python_type(output_type.upper())
python_type = sql_to_python_type(output_type.upper())

return_column = cast_column_to_type(operand, output_type)
return_column = cast_column_to_type(operand, python_type)

if return_column is None:
return operand
else:
return return_column
return_column = operand

# TODO: ideally we don't want to directly access the datetimes,
# but Pandas can't truncate timezone datetimes and cuDF can't
# truncate datetimes
if output_type == "DATE":
return return_column.dt.floor("D").astype(python_type)

return return_column


class IsFalseOperation(Operation):
Expand Down
15 changes: 12 additions & 3 deletions tests/integration/fixtures.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,11 +86,13 @@ def datetime_table():
return pd.DataFrame(
{
"timezone": pd.date_range(
start="2014-08-01 09:00", freq="H", periods=3, tz="Europe/Berlin"
start="2014-08-01 09:00", freq="8H", periods=6, tz="Europe/Berlin"
),
"no_timezone": pd.date_range(
start="2014-08-01 09:00", freq="8H", periods=6
),
"no_timezone": pd.date_range(start="2014-08-01 09:00", freq="H", periods=3),
"utc_timezone": pd.date_range(
start="2014-08-01 09:00", freq="H", periods=3, tz="UTC"
start="2014-08-01 09:00", freq="8H", periods=6, tz="UTC"
),
}
)
Expand All @@ -116,6 +118,11 @@ def gpu_string_table(string_table):
return cudf.from_pandas(string_table) if cudf else None


@pytest.fixture()
def gpu_datetime_table(datetime_table):
    """GPU (cudf) copy of ``datetime_table``; None when cudf is unavailable."""
    if not cudf:
        return None
    return cudf.from_pandas(datetime_table)


@pytest.fixture()
def c(
df_simple,
Expand All @@ -131,6 +138,7 @@ def c(
gpu_df,
gpu_long_table,
gpu_string_table,
gpu_datetime_table,
):
dfs = {
"df_simple": df_simple,
Expand All @@ -146,6 +154,7 @@ def c(
"gpu_df": gpu_df,
"gpu_long_table": gpu_long_table,
"gpu_string_table": gpu_string_table,
"gpu_datetime_table": gpu_datetime_table,
}

# Lazy import, otherwise the pytest framework has problems
Expand Down
44 changes: 43 additions & 1 deletion tests/integration/test_filter.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
import dask.dataframe as dd
import pandas as pd
import pytest
from pandas.testing import assert_frame_equal

from dask_sql._compat import INT_NAN_IMPLEMENTED
Expand Down Expand Up @@ -70,7 +72,47 @@ def test_string_filter(c, string_table):
)


def test_filter_datetime(c):
@pytest.mark.parametrize(
    "input_table",
    ["datetime_table", pytest.param("gpu_datetime_table", marks=pytest.mark.gpu)],
)
def test_filter_cast_date(c, input_table, request):
    """A CAST(... AS DATE) in a WHERE clause truncates to day before comparing."""
    datetime_table = request.getfixturevalue(input_table)
    return_df = c.sql(
        f"""
    SELECT * FROM {input_table} WHERE
        CAST(timezone AS DATE) > DATE '2014-08-01'
    """
    )

    # Mirror the SQL cast on the pandas/cudf side: drop the timezone,
    # truncate to midnight, then compare against the literal date.
    truncated = (
        datetime_table["timezone"].astype("<M8[ns]").dt.floor("D").astype("<M8[ns]")
    )
    expected_df = datetime_table[truncated > pd.Timestamp("2014-08-01")]
    dd.assert_eq(return_df, expected_df)


@pytest.mark.parametrize(
    "input_table",
    ["datetime_table", pytest.param("gpu_datetime_table", marks=pytest.mark.gpu)],
)
def test_filter_cast_timestamp(c, input_table, request):
    """A CAST(... AS TIMESTAMP) filter compares at full ns precision."""
    datetime_table = request.getfixturevalue(input_table)
    return_df = c.sql(
        f"""
    SELECT * FROM {input_table} WHERE
        CAST(timezone AS TIMESTAMP) >= TIMESTAMP '2014-08-01 23:00:00'
    """
    )

    # TIMESTAMP casting is a plain datetime64[ns] astype (tz info dropped),
    # so the expected rows come from an equivalent pandas comparison.
    as_ns = datetime_table["timezone"].astype("<M8[ns]")
    expected_df = datetime_table[as_ns >= pd.Timestamp("2014-08-01 23:00:00")]
    dd.assert_eq(return_df, expected_df)


def test_filter_year(c):
df = pd.DataFrame({"year": [2015, 2016], "month": [2, 3], "day": [4, 5]})

df["dt"] = pd.to_datetime(df)
Expand Down
51 changes: 51 additions & 0 deletions tests/integration/test_select.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import dask.dataframe as dd
import numpy as np
import pandas as pd
import pytest
Expand Down Expand Up @@ -118,6 +119,56 @@ def test_timezones(c, datetime_table):
assert_frame_equal(result_df, datetime_table)


@pytest.mark.parametrize(
    "input_table",
    ["datetime_table", pytest.param("gpu_datetime_table", marks=pytest.mark.gpu)],
)
def test_date_casting(c, input_table, request):
    """CAST(... AS DATE) truncates every datetime column to midnight.

    Runs the cast through SQL for all three fixture columns (tz-aware,
    naive, and UTC) and checks it against the equivalent pandas/cudf
    floor-to-day operation.
    """
    datetime_table = request.getfixturevalue(input_table)
    result_df = c.sql(
        f"""
    SELECT
        CAST(timezone AS DATE) AS timezone,
        CAST(no_timezone AS DATE) AS no_timezone,
        CAST(utc_timezone AS DATE) AS utc_timezone
    FROM {input_table}
    """
    )

    # Work on a copy: assigning into datetime_table directly would mutate
    # the fixture's DataFrame in place and could leak into other tests.
    expected_df = datetime_table.copy()
    for col in ("timezone", "no_timezone", "utc_timezone"):
        # DATE casting drops any timezone info and truncates to day precision
        expected_df[col] = (
            expected_df[col].astype("<M8[ns]").dt.floor("D").astype("<M8[ns]")
        )

    dd.assert_eq(result_df, expected_df)


@pytest.mark.parametrize(
    "input_table",
    ["datetime_table", pytest.param("gpu_datetime_table", marks=pytest.mark.gpu)],
)
def test_timestamp_casting(c, input_table, request):
    """CAST(... AS TIMESTAMP) strips timezone info but keeps ns precision."""
    datetime_table = request.getfixturevalue(input_table)
    query = f"""
    SELECT
        CAST(timezone AS TIMESTAMP) AS timezone,
        CAST(no_timezone AS TIMESTAMP) AS no_timezone,
        CAST(utc_timezone AS TIMESTAMP) AS utc_timezone
    FROM {input_table}
    """
    result_df = c.sql(query)

    # On the pandas/cudf side the cast is a straight datetime64[ns] astype
    expected_df = datetime_table.astype("<M8[ns]")
    dd.assert_eq(result_df, expected_df)


def test_multi_case_when(c):
df = pd.DataFrame({"a": [1, 6, 7, 8, 9]})
c.create_table("df", df)
Expand Down

0 comments on commit f601325

Please sign in to comment.