Skip to content

Commit

Permalink
Support dt.floor(), dt.ceil() and dt.round() accessors. (pydata#1827)
Browse files Browse the repository at this point in the history
* Support dt.floor(), dt.ceil() and dt.round() accessors.

* Address comments.

* Add dask test + dtype.

* Add docstrings.
  • Loading branch information
dcherian authored and shoyer committed Feb 11, 2018
1 parent 1d32399 commit cbf4921
Show file tree
Hide file tree
Showing 4 changed files with 123 additions and 1 deletion.
8 changes: 8 additions & 0 deletions doc/time-series.rst
Original file line number Diff line number Diff line change
Expand Up @@ -129,6 +129,14 @@ the first letters of the corresponding months.

You can use these shortcuts with both Datasets and DataArray coordinates.

In addition, xarray supports rounding operations ``floor``, ``ceil``, and ``round``. These operations require that you supply a `rounding frequency as a string argument.`__

__ http://pandas.pydata.org/pandas-docs/stable/timeseries.html#offset-aliases

.. ipython:: python
ds['time'].dt.floor('D')
.. _resampling:

Resampling and grouped operations
Expand Down
4 changes: 3 additions & 1 deletion doc/whats-new.rst
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,8 @@ Enhancements
(:pull:`1840`), and keeping float16 and float32 as float32 (:issue:`1842`).
Correspondingly, encoded variables may also be saved with a smaller dtype.
By `Zac Hatfield-Dodds <https://github.com/Zac-HD>`_.
- `.dt` accessor can now ceil, floor and round timestamps to specified frequency.
By `Deepak Cherian <https://github.com/dcherian>`_.

.. _Zarr: http://zarr.readthedocs.io/

Expand All @@ -94,7 +96,7 @@ Bug fixes
~~~~~~~~~
- Added warning in api.py of a netCDF4 bug that occurs when
the filepath has 88 characters (:issue:`1745`).
By `Liam Brannigan <https://github.com/braaannigan>` _.
By `Liam Brannigan <https://github.com/braaannigan>`_.
- Fixed encoding of multi-dimensional coordinates in
:py:meth:`~Dataset.to_netcdf` (:issue:`1763`).
By `Mike Neish <https://github.com/neishm>`_.
Expand Down
92 changes: 92 additions & 0 deletions xarray/core/accessors.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,43 @@ def _get_date_field(values, name, dtype):
return _access_through_series(values, name)


def _round_series(values, name, freq):
"""Coerce an array of datetime-like values to a pandas Series and
apply requested rounding
"""
values_as_series = pd.Series(values.ravel())
method = getattr(values_as_series.dt, name)
field_values = method(freq=freq).values

return field_values.reshape(values.shape)


def _round_field(values, name, freq):
"""Indirectly access pandas rounding functions by wrapping data
as a Series and calling through `.dt` attribute.
Parameters
----------
values : np.ndarray or dask.array-like
Array-like container of datetime-like values
name : str (ceil, floor, round)
Name of rounding function
freq : a freq string indicating the rounding resolution
Returns
-------
rounded timestamps : same type as values
Array-like of datetime fields accessed for each element in values
"""
if isinstance(values, dask_array_type):
from dask.array import map_blocks
return map_blocks(_round_series,
values, name, freq=freq, dtype=np.datetime64)
else:
return _round_series(values, name, freq)


class DatetimeAccessor(object):
"""Access datetime fields for DataArrays with datetime-like dtypes.
Expand Down Expand Up @@ -147,3 +184,58 @@ def f(self, dtype=dtype):
time = _tslib_field_accessor(
"time", "Timestamps corresponding to datetimes", object
)

def _tslib_round_accessor(self, name, freq):
obj_type = type(self._obj)
result = _round_field(self._obj.data, name, freq)
return obj_type(result, name=name,
coords=self._obj.coords, dims=self._obj.dims)

def floor(self, freq):
'''
Round timestamps downward to specified frequency resolution.
Parameters
----------
freq : a freq string indicating the rounding resolution
e.g. 'D' for daily resolution
Returns
-------
floor-ed timestamps : same type as values
Array-like of datetime fields accessed for each element in values
'''

return self._tslib_round_accessor("floor", freq)

def ceil(self, freq):
'''
Round timestamps upward to specified frequency resolution.
Parameters
----------
freq : a freq string indicating the rounding resolution
e.g. 'D' for daily resolution
Returns
-------
ceil-ed timestamps : same type as values
Array-like of datetime fields accessed for each element in values
'''
return self._tslib_round_accessor("ceil", freq)

def round(self, freq):
'''
Round timestamps to specified frequency resolution.
Parameters
----------
freq : a freq string indicating the rounding resolution
e.g. 'D' for daily resolution
Returns
-------
rounded timestamps : same type as values
Array-like of datetime fields accessed for each element in values
'''
return self._tslib_round_accessor("round", freq)
20 changes: 20 additions & 0 deletions xarray/tests/test_accessors.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,9 @@ def test_dask_field_access(self):
months = self.times_data.dt.month
hours = self.times_data.dt.hour
days = self.times_data.dt.day
floor = self.times_data.dt.floor('D')
ceil = self.times_data.dt.ceil('D')
round = self.times_data.dt.round('D')

dask_times_arr = da.from_array(self.times_arr, chunks=(5, 5, 50))
dask_times_2d = xr.DataArray(dask_times_arr,
Expand All @@ -67,6 +70,9 @@ def test_dask_field_access(self):
dask_month = dask_times_2d.dt.month
dask_day = dask_times_2d.dt.day
dask_hour = dask_times_2d.dt.hour
dask_floor = dask_times_2d.dt.floor('D')
dask_ceil = dask_times_2d.dt.ceil('D')
dask_round = dask_times_2d.dt.round('D')

# Test that the data isn't eagerly evaluated
assert isinstance(dask_year.data, da.Array)
Expand All @@ -86,6 +92,9 @@ def test_dask_field_access(self):
assert_equal(months, dask_month.compute())
assert_equal(days, dask_day.compute())
assert_equal(hours, dask_hour.compute())
assert_equal(floor, dask_floor.compute())
assert_equal(ceil, dask_ceil.compute())
assert_equal(round, dask_round.compute())

def test_seasons(self):
dates = pd.date_range(start="2000/01/01", freq="M", periods=12)
Expand All @@ -95,3 +104,14 @@ def test_seasons(self):
seasons = xr.DataArray(seasons)

assert_array_equal(seasons.values, dates.dt.season.values)

def test_rounders(self):
dates = pd.date_range("2014-01-01", "2014-05-01", freq='H')
xdates = xr.DataArray(np.arange(len(dates)),
dims=['time'], coords=[dates])
assert_array_equal(dates.floor('D').values,
xdates.time.dt.floor('D').values)
assert_array_equal(dates.ceil('D').values,
xdates.time.dt.ceil('D').values)
assert_array_equal(dates.round('D').values,
xdates.time.dt.round('D').values)

0 comments on commit cbf4921

Please sign in to comment.