ENH: Plotting support for interval coordinates: groupby_bins (pydata#2152)

* ENH: Plotting for groupby_bins

DataArrays created with e.g. groupby_bins have coords consisting of pd._libs.interval.Interval. For plotting, the pd._libs.interval.Interval is replaced with the interval's center point. '_center' is appended to the label.

* changed pd._libs.interval.Interval to pd.Interval

* Assign new variable with _interval_to_mid_points instead of mutating original variable.

Note that this changes the type of xplt from DataArray to np.array in the line function.

* '_center' added to label only for 1d plot

* added tests

* missing whitespace

* Simplified test

* simplified tests once more

* 1d plots now default to step plot

New bool keyword `interval_step_plot` to turn it off.

* non-uniform bin spacing for pcolormesh

* Added step plot function

* bugfix: linestyle == '' results in no line plotted

* Adapted to upstream changes

* Added _resolve_intervals_2dplot function, simplified code

* Added documentation

* typo in documentation

* Fixed bug introduced by upstream change

* Refactor out utility functions.

* Fix test.

* Add whats-new.

* Remove duplicate whats new entry. :/

* Make things neater.
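
The linestyle handling mentioned in the list above (including the `linestyle == ''` bugfix) can be illustrated with a small, self-contained sketch. The helper names `encode_step_linestyle` and `decode_step_linestyle` are hypothetical and not part of xarray; they only mirror, under that assumption, how `step` prefixes the linestyle with `steps-<where>` and how `line` strips the prefix again for interval coordinates, dropping the key when nothing is left.

```python
def encode_step_linestyle(kwargs):
    """Hypothetical helper: fold ``where`` (and ``ls``) into ``linestyle``."""
    kwargs = dict(kwargs)  # work on a copy, leave the caller's dict alone
    if 'ls' in kwargs and 'linestyle' not in kwargs:
        kwargs['linestyle'] = kwargs.pop('ls')
    where = kwargs.pop('where', 'pre')
    if where not in ('pre', 'post', 'mid'):
        raise ValueError("'where' argument to step must be "
                         "'pre', 'post' or 'mid'")
    kwargs['linestyle'] = 'steps-' + where + kwargs.get('linestyle', '')
    return kwargs


def decode_step_linestyle(kwargs):
    """Hypothetical helper: strip the ``steps-*`` prefix again."""
    kwargs = dict(kwargs)
    linestyle = kwargs.get('linestyle', '')
    for prefix in ('steps-pre', 'steps-post', 'steps-mid'):
        linestyle = linestyle.replace(prefix, '')
    if linestyle == '':
        # An empty linestyle would make matplotlib draw no line at all,
        # so the key is removed instead of being passed on.
        kwargs.pop('linestyle', None)
    else:
        kwargs['linestyle'] = linestyle
    return kwargs


encoded = encode_step_linestyle({'where': 'mid', 'ls': ':'})
decoded = decode_step_linestyle(encoded)
```

With these inputs the dotted style survives the round trip, while a call without any linestyle ends up with no `linestyle` key.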
Maximilian Maahn authored and dcherian committed Oct 23, 2018
1 parent 6008dc4 commit 5ebed79
Showing 6 changed files with 224 additions and 24 deletions.
32 changes: 32 additions & 0 deletions doc/plotting.rst
@@ -222,9 +222,41 @@ It is also possible to make line plots such that the data are on the x-axis and
     @savefig plotting_example_xy_kwarg.png
     air.isel(time=10, lon=[10, 11]).plot(y='lat', hue='lon')

+Step plots
+~~~~~~~~~~
+
+As an alternative, also a step plot similar to matplotlib's ``plt.step`` can be
+made using 1D data.
+
+.. ipython:: python
+
+    @savefig plotting_example_step.png width=4in
+    air1d[:20].plot.step(where='mid')
+
+The argument ``where`` defines where the steps should be placed, options are
+``'pre'`` (default), ``'post'``, and ``'mid'``. This is particularly handy
+when plotting data grouped with :py:func:`xarray.Dataset.groupby_bins`.
+
+.. ipython:: python
+
+    air_grp = air.mean(['time','lon']).groupby_bins('lat',[0,23.5,66.5,90])
+    air_mean = air_grp.mean()
+    air_std = air_grp.std()
+    air_mean.plot.step()
+    (air_mean + air_std).plot.step(ls=':')
+    (air_mean - air_std).plot.step(ls=':')
+    plt.ylim(-20,30)
+    @savefig plotting_example_step_groupby.png width=4in
+    plt.title('Zonal mean temperature')
+
+In this case, the actual boundaries of the bins are used and the ``where`` argument
+is ignored.
+
+
 Other axes kwargs
 -----------------

+
 The keyword arguments ``xincrease`` and ``yincrease`` let you control the axes direction.

 .. ipython:: python
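
To make concrete what "the actual boundaries of the bins are used" means in the documentation change above, here is a minimal sketch of the doubling idea. It uses a `namedtuple` stand-in for `pandas.Interval` (an assumption made so the snippet runs without pandas; only `.left` and `.right` are needed), and the function name `double_bound_points` and the sample values are illustrative, not taken from the commit.

```python
from collections import namedtuple
from itertools import chain

# Stand-in for pandas.Interval; only the two edges are used here.
Interval = namedtuple('Interval', ['left', 'right'])


def double_bound_points(intervals, values):
    """Repeat every value at the left and the right edge of its interval."""
    x = list(chain.from_iterable((iv.left, iv.right) for iv in intervals))
    y = list(chain.from_iterable((v, v) for v in values))
    return x, y


bins = [Interval(0, 23.5), Interval(23.5, 66.5), Interval(66.5, 90)]
means = [25.0, 10.0, -12.0]  # made-up values, one per bin

x, y = double_bound_points(bins, means)
```

Drawing `x` against `y` as an ordinary line already yields steps whose edges sit on the bin boundaries, which is why `where` no longer has any influence in that case.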
3 changes: 3 additions & 0 deletions doc/whats-new.rst
@@ -63,6 +63,9 @@ Enhancements
   By `Deepak Cherian <https://github.com/dcherian>`_.
 - Added support for Python 3.7. (:issue:`2271`).
   By `Joe Hamman <https://github.com/jhamman>`_.
+- Added support for plotting data with `pandas.Interval` coordinates, such as those
+  created by :py:meth:`~xarray.DataArray.groupby_bins`
+  By `Maximilian Maahn <https://github.com/maahn>`_.
 - Added :py:meth:`~xarray.CFTimeIndex.shift` for shifting the values of a
   CFTimeIndex by a specified frequency. (:issue:`2244`).
   By `Spencer Clark <https://github.com/spencerkclark>`_.
3 changes: 2 additions & 1 deletion xarray/plot/__init__.py
@@ -1,14 +1,15 @@
 from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
-from .plot import (plot, line, contourf, contour,
+from .plot import (plot, line, step, contourf, contour,
                    hist, imshow, pcolormesh)

 from .facetgrid import FacetGrid

 __all__ = [
     'plot',
     'line',
+    'step',
     'contour',
     'contourf',
     'hist',
114 changes: 93 additions & 21 deletions xarray/plot/plot.py
@@ -20,7 +20,9 @@

 from .facetgrid import FacetGrid
 from .utils import (
-    ROBUST_PERCENTILE, _determine_cmap_params, _infer_xy_labels, get_axis,
+    ROBUST_PERCENTILE, _determine_cmap_params, _infer_xy_labels,
+    _interval_to_double_bound_points, _interval_to_mid_points,
+    _resolve_intervals_2dplot, _valid_other_type, get_axis,
     import_matplotlib_pyplot, label_from_attrs)


@@ -36,27 +38,20 @@ def _valid_numpy_subdtype(x, numpy_types):
     return any(np.issubdtype(x.dtype, t) for t in numpy_types)


-def _valid_other_type(x, types):
-    """
-    Do all elements of x have a type from types?
-    """
-    return all(any(isinstance(el, t) for t in types) for el in np.ravel(x))
-
-
 def _ensure_plottable(*args):
     """
     Raise exception if there is anything in args that can't be plotted on an
-    axis.
+    axis by matplotlib.
     """
     numpy_types = [np.floating, np.integer, np.timedelta64, np.datetime64]
     other_types = [datetime]

     for x in args:
-        if not (_valid_numpy_subdtype(np.array(x), numpy_types) or
-                _valid_other_type(np.array(x), other_types)):
+        if not (_valid_numpy_subdtype(np.array(x), numpy_types)
+                or _valid_other_type(np.array(x), other_types)):
             raise TypeError('Plotting requires coordinates to be numeric '
                             'or dates of type np.datetime64 or '
-                            'datetime.datetime.')
+                            'datetime.datetime or pd.Interval.')


 def _easy_facetgrid(darray, plotfunc, x, y, row=None, col=None,
@@ -350,9 +345,30 @@ def line(darray, *args, **kwargs):
     xplt, yplt, hueplt, xlabel, ylabel, huelabel = \
         _infer_line_data(darray, x, y, hue)

-    _ensure_plottable(xplt)
+    # Remove pd.Intervals if contained in xplt.values.
+    if _valid_other_type(xplt.values, [pd.Interval]):
+        # Is it a step plot? (see matplotlib.Axes.step)
+        if kwargs.get('linestyle', '').startswith('steps-'):
+            xplt_val, yplt_val = _interval_to_double_bound_points(xplt.values,
+                                                                  yplt.values)
+            # Remove steps-* to be sure that matplotlib is not confused
+            kwargs['linestyle'] = (kwargs['linestyle']
+                                   .replace('steps-pre', '')
+                                   .replace('steps-post', '')
+                                   .replace('steps-mid', ''))
+            if kwargs['linestyle'] == '':
+                kwargs.pop('linestyle')
+        else:
+            xplt_val = _interval_to_mid_points(xplt.values)
+            yplt_val = yplt.values
+            xlabel += '_center'
+    else:
+        xplt_val = xplt.values
+        yplt_val = yplt.values

-    primitive = ax.plot(xplt, yplt, *args, **kwargs)
+    _ensure_plottable(xplt_val, yplt_val)
+
+    primitive = ax.plot(xplt_val, yplt_val, *args, **kwargs)

     if _labels:
         if xlabel is not None:
@@ -383,6 +399,46 @@ def line(darray, *args, **kwargs):
     return primitive


+def step(darray, *args, **kwargs):
+    """
+    Step plot of DataArray index against values
+
+    Similar to :func:`matplotlib:matplotlib.pyplot.step`
+
+    Parameters
+    ----------
+    where : {'pre', 'post', 'mid'}, optional, default 'pre'
+        Define where the steps should be placed:
+        - 'pre': The y value is continued constantly to the left from
+          every *x* position, i.e. the interval ``(x[i-1], x[i]]`` has the
+          value ``y[i]``.
+        - 'post': The y value is continued constantly to the right from
+          every *x* position, i.e. the interval ``[x[i], x[i+1])`` has the
+          value ``y[i]``.
+        - 'mid': Steps occur half-way between the *x* positions.
+        Note that this parameter is ignored if the x coordinate consists of
+        :py:func:`pandas.Interval` values, e.g. as a result of
+        :py:func:`xarray.Dataset.groupby_bins`. In this case, the actual
+        boundaries of the interval are used.
+
+    *args, **kwargs : optional
+        Additional arguments following :py:func:`xarray.plot.line`
+
+    """
+    if ('ls' in kwargs.keys()) and ('linestyle' not in kwargs.keys()):
+        kwargs['linestyle'] = kwargs.pop('ls')
+
+    where = kwargs.pop('where', 'pre')
+
+    if where not in ('pre', 'post', 'mid'):
+        raise ValueError("'where' argument to step must be "
+                         "'pre', 'post' or 'mid'")
+
+    kwargs['linestyle'] = 'steps-' + where + kwargs.get('linestyle', '')
+
+    return line(darray, *args, **kwargs)
+
+
 def hist(darray, figsize=None, size=None, aspect=None, ax=None, **kwargs):
     """
     Histogram of DataArray
@@ -500,6 +556,10 @@ def hist(self, ax=None, **kwargs):
     def line(self, *args, **kwargs):
         return line(self._da, *args, **kwargs)

+    @functools.wraps(step)
+    def step(self, *args, **kwargs):
+        return step(self._da, *args, **kwargs)
+

 def _rescale_imshow_rgb(darray, vmin, vmax, robust):
     assert robust or vmin is not None or vmax is not None
@@ -740,7 +800,11 @@ def newplotfunc(darray, x=None, y=None, figsize=None, size=None,
         # Pass the data as a masked ndarray too
         zval = darray.to_masked_array(copy=False)

-        _ensure_plottable(xval, yval)
+        # Replace pd.Intervals if contained in xval or yval.
+        xplt, xlab_extra = _resolve_intervals_2dplot(xval, plotfunc.__name__)
+        yplt, ylab_extra = _resolve_intervals_2dplot(yval, plotfunc.__name__)
+
+        _ensure_plottable(xplt, yplt)

         if 'contour' in plotfunc.__name__ and levels is None:
             levels = 7  # this is the matplotlib default
@@ -780,16 +844,16 @@ def newplotfunc(darray, x=None, y=None, figsize=None, size=None,
                              "in xarray")

         ax = get_axis(figsize, size, aspect, ax)
-        primitive = plotfunc(xval, yval, zval, ax=ax, cmap=cmap_params['cmap'],
+        primitive = plotfunc(xplt, yplt, zval, ax=ax, cmap=cmap_params['cmap'],
                              vmin=cmap_params['vmin'],
                              vmax=cmap_params['vmax'],
                              norm=cmap_params['norm'],
                              **kwargs)

         # Label the plot with metadata
         if add_labels:
-            ax.set_xlabel(label_from_attrs(darray[xlab]))
-            ax.set_ylabel(label_from_attrs(darray[ylab]))
+            ax.set_xlabel(label_from_attrs(darray[xlab], xlab_extra))
+            ax.set_ylabel(label_from_attrs(darray[ylab], ylab_extra))
             ax.set_title(darray._title_for_slice())

         if add_colorbar:
@@ -818,7 +882,7 @@ def newplotfunc(darray, x=None, y=None, figsize=None, size=None,
         # Do this without calling autofmt_xdate so that x-axes ticks
         # on other subplots (if any) are not deleted.
         # https://stackoverflow.com/questions/17430105/autofmt-xdate-deletes-x-axis-labels-of-all-subplots
-        if np.issubdtype(xval.dtype, np.datetime64):
+        if np.issubdtype(xplt.dtype, np.datetime64):
             for xlabels in ax.get_xticklabels():
                 xlabels.set_rotation(30)
                 xlabels.set_ha('right')
@@ -1019,14 +1083,22 @@ def pcolormesh(x, y, z, ax, infer_intervals=None, **kwargs):
         else:
             infer_intervals = True

-    if infer_intervals:
+    if (infer_intervals and
+            ((np.shape(x)[0] == np.shape(z)[1]) or
+             ((x.ndim > 1) and (np.shape(x)[1] == np.shape(z)[1])))):
         if len(x.shape) == 1:
             x = _infer_interval_breaks(x, check_monotonic=True)
-            y = _infer_interval_breaks(y, check_monotonic=True)
         else:
             # we have to infer the intervals on both axes
             x = _infer_interval_breaks(x, axis=1)
             x = _infer_interval_breaks(x, axis=0)
+
+    if (infer_intervals and
+            (np.shape(y)[0] == np.shape(z)[0])):
+        if len(y.shape) == 1:
+            y = _infer_interval_breaks(y, check_monotonic=True)
+        else:
+            # we have to infer the intervals on both axes
             y = _infer_interval_breaks(y, axis=1)
             y = _infer_interval_breaks(y, axis=0)

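
The shape comparison added to `pcolormesh` above can be read as: infer cell edges only when a coordinate still has one entry per cell. When interval coordinates have already been resolved to their boundaries, the coordinate has one entry more than the data and is passed through unchanged. The following 1D-only sketch illustrates that reading. `infer_breaks` and `resolve_edges` are hypothetical names, and `infer_breaks` is a pure-Python approximation of what `_infer_interval_breaks` is assumed to do for monotonic 1D input (midpoints between neighbours plus extrapolated end points), not a copy of it.

```python
def infer_breaks(centers):
    """Approximate cell edges from cell centers (1D, monotonic input)."""
    deltas = [0.5 * (b - a) for a, b in zip(centers[:-1], centers[1:])]
    first = centers[0] - deltas[0]    # extrapolate below the first center
    last = centers[-1] + deltas[-1]   # extrapolate above the last center
    inner = [c + d for c, d in zip(centers[:-1], deltas)]
    return [first] + inner + [last]


def resolve_edges(coord, n_cells):
    """Infer edges only if ``coord`` has exactly one entry per cell."""
    if len(coord) == n_cells:
        return infer_breaks(coord)
    # Otherwise the coordinate is taken to hold edges already
    # (e.g. n_cells + 1 interval boundaries) and is left untouched.
    return list(coord)


from_centers = resolve_edges([1.0, 2.0, 3.0], n_cells=3)
from_bounds = resolve_edges([0, 23.5, 66.5, 90], n_cells=3)
```

Treating the x and y axes independently, as the diff does, lets one axis carry interval boundaries while the other still goes through inference.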
68 changes: 66 additions & 2 deletions xarray/plot/utils.py
@@ -1,9 +1,11 @@
 from __future__ import absolute_import, division, print_function

+import itertools
 import textwrap
 import warnings

 import numpy as np
+import pandas as pd

 from ..core.options import OPTIONS
 from ..core.pycompat import basestring
@@ -367,7 +369,7 @@ def get_axis(figsize, size, aspect, ax):
     return ax


-def label_from_attrs(da):
+def label_from_attrs(da, extra=''):
     ''' Makes informative labels if variable metadata (attrs) follows
         CF conventions. '''

@@ -385,4 +387,66 @@ def label_from_attrs(da):
     else:
         units = ''

-    return '\n'.join(textwrap.wrap(name + units, 30))
+    return '\n'.join(textwrap.wrap(name + extra + units, 30))
+
+
+def _interval_to_mid_points(array):
+    """
+    Helper function which returns an array
+    with the Intervals' mid points.
+    """
+
+    return np.array([x.mid for x in array])
+
+
+def _interval_to_bound_points(array):
+    """
+    Helper function which returns an array
+    with the Intervals' boundaries.
+    """
+
+    array_boundaries = np.array([x.left for x in array])
+    array_boundaries = np.concatenate(
+        (array_boundaries, np.array([array[-1].right])))
+
+    return array_boundaries
+
+
+def _interval_to_double_bound_points(xarray, yarray):
+    """
+    Helper function to deal with a xarray consisting of pd.Intervals. Each
+    interval is replaced with both boundaries. I.e. the length of xarray
+    doubles. yarray is modified so it matches the new shape of xarray.
+    """
+
+    xarray1 = np.array([x.left for x in xarray])
+    xarray2 = np.array([x.right for x in xarray])
+
+    xarray = list(itertools.chain.from_iterable(zip(xarray1, xarray2)))
+    yarray = list(itertools.chain.from_iterable(zip(yarray, yarray)))
+
+    return xarray, yarray
+
+
+def _resolve_intervals_2dplot(val, func_name):
+    """
+    Helper function to replace the values of a coordinate array containing
+    pd.Interval with their mid-points or - for pcolormesh - boundaries which
+    increases length by 1.
+    """
+    label_extra = ''
+    if _valid_other_type(val, [pd.Interval]):
+        if func_name == 'pcolormesh':
+            val = _interval_to_bound_points(val)
+        else:
+            val = _interval_to_mid_points(val)
+            label_extra = '_center'
+
+    return val, label_extra
+
+
+def _valid_other_type(x, types):
+    """
+    Do all elements of x have a type from types?
+    """
+    return all(any(isinstance(el, t) for t in types) for el in np.ravel(x))
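
The difference between the mid-point and boundary helpers above can be sketched without pandas. The snippet uses a `namedtuple` stand-in for `pandas.Interval` (an assumption; the real class exposes `left`, `right` and `mid`), and `is_contiguous` is a hypothetical check that is not part of the commit. It is included because the boundary variant only keeps the left edges plus the last right edge, so it appears to presuppose adjacent bins.

```python
from collections import namedtuple

# Stand-in for pandas.Interval; only the two edges are modelled.
Interval = namedtuple('Interval', ['left', 'right'])


def mid_points(intervals):
    """One value per interval: its center (used for non-pcolormesh plots)."""
    return [(iv.left + iv.right) / 2 for iv in intervals]


def bound_points(intervals):
    """All left edges plus the final right edge: n intervals -> n + 1 points."""
    return [iv.left for iv in intervals] + [intervals[-1].right]


def is_contiguous(intervals):
    """Hypothetical check: does every bin start where the previous one ends?"""
    return all(a.right == b.left
               for a, b in zip(intervals[:-1], intervals[1:]))


bins = [Interval(0, 23.5), Interval(23.5, 66.5), Interval(66.5, 90)]
gappy = [Interval(0, 1), Interval(2, 3)]

centers = mid_points(bins)
edges = bound_points(bins)
gappy_edges = bound_points(gappy)
```

For the contiguous bins the edges reproduce the original break points. For the bins with a gap, the right edge `1` of the first bin is dropped, which is what the contiguity check is meant to flag.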