Skip to content

Commit

Permalink
Support RGB[A] arrays in plot.imshow() (pydata#1796)
Browse files Browse the repository at this point in the history
* Allow RGB plots from DataArray.plot.imshow

* Allow RGB[A] dim for imshow to be in any order

Includes new `rgb` keyword to tell imshow about that dimension, and much
error handling in inference.

* Use true RGB color for Rasterio gallery page

* Add whats-new entry
  • Loading branch information
Zac-HD authored and shoyer committed Jan 11, 2018
1 parent 049cbdd commit 289f95a
Show file tree
Hide file tree
Showing 6 changed files with 158 additions and 13 deletions.
7 changes: 2 additions & 5 deletions doc/gallery/plot_rasterio.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,13 +44,10 @@
da.coords['lon'] = (('y', 'x'), lon)
da.coords['lat'] = (('y', 'x'), lat)

# Compute a greyscale out of the rgb image
greyscale = da.mean(dim='band')

# Plot on a map
ax = plt.subplot(projection=ccrs.PlateCarree())
greyscale.plot(ax=ax, x='lon', y='lat', transform=ccrs.PlateCarree(),
cmap='Greys_r', add_colorbar=False)
da.plot.imshow(ax=ax, x='lon', y='lat', rgb='band',
transform=ccrs.PlateCarree())
ax.coastlines('10m', color='r')
plt.show()

Expand Down
2 changes: 2 additions & 0 deletions doc/whats-new.rst
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,8 @@ Enhancements
By `Joe Hamman <https://github.com/jhamman>`_.
- Support for using `Zarr`_ as storage layer for xarray.
By `Ryan Abernathey <https://github.com/rabernat>`_.
- :func:`xarray.plot.imshow` now handles RGB and RGBA images.
By `Zac Hatfield-Dodds <https://github.com/Zac-HD>`_.
- Experimental support for parsing ENVI metadata to coordinates and attributes
in :py:func:`xarray.open_rasterio`.
By `Matti Eskelinen <https://github.com/maaleske>`_.
Expand Down
5 changes: 3 additions & 2 deletions xarray/plot/facetgrid.py
Original file line number Diff line number Diff line change
Expand Up @@ -239,8 +239,9 @@ def map_dataarray(self, func, x, y, **kwargs):
func_kwargs.update({'add_colorbar': False, 'add_labels': False})

# Get x, y labels for the first subplot
x, y = _infer_xy_labels(darray=self.data.loc[self.name_dicts.flat[0]],
x=x, y=y)
x, y = _infer_xy_labels(
darray=self.data.loc[self.name_dicts.flat[0]], x=x, y=y,
imshow=func.__name__ == 'imshow', rgb=kwargs.get('rgb', None))

for d, ax in zip(self.name_dicts.flat, self.axes.flat):
# None is the sentinel value
Expand Down
47 changes: 43 additions & 4 deletions xarray/plot/plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -443,10 +443,17 @@ def newplotfunc(darray, x=None, y=None, figsize=None, size=None,
# Decide on a default for the colorbar before facetgrids
if add_colorbar is None:
add_colorbar = plotfunc.__name__ != 'contour'
imshow_rgb = (
plotfunc.__name__ == 'imshow' and
darray.ndim == (3 + (row is not None) + (col is not None)))
if imshow_rgb:
# Don't add a colorbar when showing an image with explicit colors
add_colorbar = False

# Handle facetgrids first
if row or col:
allargs = locals().copy()
allargs.pop('imshow_rgb')
allargs.update(allargs.pop('kwargs'))

# Need the decorated plotting function
Expand All @@ -470,12 +477,19 @@ def newplotfunc(darray, x=None, y=None, figsize=None, size=None,
"Use colors keyword instead.",
DeprecationWarning, stacklevel=3)

xlab, ylab = _infer_xy_labels(darray=darray, x=x, y=y)
rgb = kwargs.pop('rgb', None)
xlab, ylab = _infer_xy_labels(
darray=darray, x=x, y=y, imshow=imshow_rgb, rgb=rgb)

if rgb is not None and plotfunc.__name__ != 'imshow':
raise ValueError('The "rgb" keyword is only valid for imshow()')
elif rgb is not None and not imshow_rgb:
raise ValueError('The "rgb" keyword is only valid for imshow()'
'with a three-dimensional array (per facet)')

# better to pass the ndarrays directly to plotting functions
xval = darray[xlab].values
yval = darray[ylab].values
zval = darray.to_masked_array(copy=False)

# check if we need to broadcast one dimension
if xval.ndim < yval.ndim:
Expand All @@ -486,8 +500,19 @@ def newplotfunc(darray, x=None, y=None, figsize=None, size=None,

# May need to transpose for correct x, y labels
# xlab may be the name of a coord, we have to check for dim names
if darray[xlab].dims[-1] == darray.dims[0]:
zval = zval.T
if imshow_rgb:
# For RGB[A] images, matplotlib requires the color dimension
# to be last. In Xarray the order should be unimportant, so
# we transpose to (y, x, color) to make this work.
yx_dims = (ylab, xlab)
dims = yx_dims + tuple(d for d in darray.dims if d not in yx_dims)
if dims != darray.dims:
darray = darray.transpose(*dims)
elif darray[xlab].dims[-1] == darray.dims[0]:
darray = darray.transpose()

# Pass the data as a masked ndarray too
zval = darray.to_masked_array(copy=False)

_ensure_plottable(xval, yval)

Expand Down Expand Up @@ -595,6 +620,11 @@ def imshow(x, y, z, ax, **kwargs):
Wraps :func:`matplotlib:matplotlib.pyplot.imshow`
While other plot methods require the DataArray to be strictly
two-dimensional, ``imshow`` also accepts a 3D array where some
dimension can be interpreted as RGB or RGBA color channels and
allows this dimension to be specified via the kwarg ``rgb=``.
.. note::
This function needs uniformly spaced coordinates to
properly label the axes. Call DataArray.plot() to check.
Expand Down Expand Up @@ -632,6 +662,15 @@ def imshow(x, y, z, ax, **kwargs):
# Allow user to override these defaults
defaults.update(kwargs)

if z.ndim == 3:
# matplotlib imshow uses black for missing data, but Xarray makes
# missing data transparent. We therefore add an alpha channel if
# there isn't one, and set it to transparent where data is masked.
if z.shape[-1] == 3:
z = np.ma.concatenate((z, np.ma.ones(z.shape[:2] + (1,))), 2)
z = z.copy()
z[np.any(z.mask, axis=-1), -1] = 0

primitive = ax.imshow(z, **defaults)

return primitive
Expand Down
57 changes: 55 additions & 2 deletions xarray/plot/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -258,12 +258,65 @@ def _determine_cmap_params(plot_data, vmin=None, vmax=None, cmap=None,
levels=levels, norm=norm)


def _infer_xy_labels(darray, x, y):
def _infer_xy_labels_3d(darray, x, y, rgb):
"""
Determine x and y labels for showing RGB images.
Attempts to infer which dimension is RGB/RGBA by size and order of dims.
"""
assert rgb is None or rgb != x
assert rgb is None or rgb != y
# Start by detecting and reporting invalid combinations of arguments
assert darray.ndim == 3
not_none = [a for a in (x, y, rgb) if a is not None]
if len(set(not_none)) < len(not_none):
raise ValueError(
'Dimension names must be None or unique strings, but imshow was '
'passed x=%r, y=%r, and rgb=%r.' % (x, y, rgb))
for label in not_none:
if label not in darray.dims:
raise ValueError('%r is not a dimension' % (label,))

# Then calculate rgb dimension if certain and check validity
could_be_color = [label for label in darray.dims
if darray[label].size in (3, 4) and label not in (x, y)]
if rgb is None and not could_be_color:
raise ValueError(
'A 3-dimensional array was passed to imshow(), but there is no '
'dimension that could be color. At least one dimension must be '
'of size 3 (RGB) or 4 (RGBA), and not given as x or y.')
if rgb is None and len(could_be_color) == 1:
rgb = could_be_color[0]
if rgb is not None and darray[rgb].size not in (3, 4):
raise ValueError('Cannot interpret dim %r of size %s as RGB or RGBA.'
% (rgb, darray[rgb].size))

# If rgb dimension is still unknown, there must be two or three dimensions
# in could_be_color. We therefore warn, and use a heuristic to break ties.
if rgb is None:
assert len(could_be_color) in (2, 3)
rgb = could_be_color[-1]
warnings.warn(
'Several dimensions of this array could be colors. Xarray '
'will use the last possible dimension (%r) to match '
'matplotlib.pyplot.imshow. You can pass names of x, y, '
'and/or rgb dimensions to override this guess.' % rgb)
assert rgb is not None

# Finally, we pick out the red slice and delegate to the 2D version:
return _infer_xy_labels(darray.isel(**{rgb: 0}).squeeze(), x, y)


def _infer_xy_labels(darray, x, y, imshow=False, rgb=None):
"""
Determine x and y labels. For use in _plot2d
darray must be a 2 dimensional data array.
darray must be a 2 dimensional data array, or 3d for imshow only.
"""
assert x is None or x != y
if imshow and darray.ndim == 3:
return _infer_xy_labels_3d(darray, x, y, rgb)

if x is None and y is None:
if darray.ndim != 2:
Expand Down
53 changes: 53 additions & 0 deletions xarray/tests/test_plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -619,6 +619,8 @@ def test_1d_raises_valueerror(self):

def test_3d_raises_valueerror(self):
a = DataArray(easy_array((2, 3, 4)))
if self.plotfunc.__name__ == 'imshow':
pytest.skip()
with raises_regex(ValueError, r'DataArray must be 2d'):
self.plotfunc(a)

Expand Down Expand Up @@ -670,6 +672,11 @@ def test_can_plot_axis_size_one(self):
if self.plotfunc.__name__ not in ('contour', 'contourf'):
self.plotfunc(DataArray(np.ones((1, 1))))

def test_disallows_rgb_arg(self):
with pytest.raises(ValueError):
# Always invalid for most plots. Invalid for imshow with 2D data.
self.plotfunc(DataArray(np.ones((2, 2))), rgb='not None')

def test_viridis_cmap(self):
cmap_name = self.plotmethod(cmap='viridis').get_cmap().name
self.assertEqual('viridis', cmap_name)
Expand Down Expand Up @@ -1062,6 +1069,52 @@ def test_2d_coord_names(self):
with raises_regex(ValueError, 'requires 1D coordinates'):
self.plotmethod(x='x2d', y='y2d')

def test_plot_rgb_image(self):
DataArray(
easy_array((10, 15, 3), start=0),
dims=['y', 'x', 'band'],
).plot.imshow()
self.assertEqual(0, len(find_possible_colorbars()))

def test_plot_rgb_image_explicit(self):
DataArray(
easy_array((10, 15, 3), start=0),
dims=['y', 'x', 'band'],
).plot.imshow(y='y', x='x', rgb='band')
self.assertEqual(0, len(find_possible_colorbars()))

def test_plot_rgb_faceted(self):
DataArray(
easy_array((2, 2, 10, 15, 3), start=0),
dims=['a', 'b', 'y', 'x', 'band'],
).plot.imshow(row='a', col='b')
self.assertEqual(0, len(find_possible_colorbars()))

def test_plot_rgba_image_transposed(self):
# We can handle the color axis being in any position
DataArray(
easy_array((4, 10, 15), start=0),
dims=['band', 'y', 'x'],
).plot.imshow()

def test_warns_ambigious_dim(self):
arr = DataArray(easy_array((3, 3, 3)), dims=['y', 'x', 'band'])
with pytest.warns(UserWarning):
arr.plot.imshow()
# but doesn't warn if dimensions specified
arr.plot.imshow(rgb='band')
arr.plot.imshow(x='x', y='y')

def test_rgb_errors_too_many_dims(self):
arr = DataArray(easy_array((3, 3, 3, 3)), dims=['y', 'x', 'z', 'band'])
with pytest.raises(ValueError):
arr.plot.imshow(rgb='band')

def test_rgb_errors_bad_dim_sizes(self):
arr = DataArray(easy_array((5, 5, 5)), dims=['y', 'x', 'band'])
with pytest.raises(ValueError):
arr.plot.imshow(rgb='band')


class TestFacetGrid(PlotTestCase):

Expand Down

0 comments on commit 289f95a

Please sign in to comment.