Skip to content

Commit

Permalink
Use default axis labels and sizes
Browse files Browse the repository at this point in the history
  • Loading branch information
clarkfitzg committed Sep 4, 2015
1 parent 266b0bf commit 29db010
Show file tree
Hide file tree
Showing 2 changed files with 37 additions and 40 deletions.
27 changes: 13 additions & 14 deletions xray/plot/facetgrid.py
Original file line number Diff line number Diff line change
Expand Up @@ -169,8 +169,7 @@ def __init__(self, darray, col=None, row=None, col_wrap=None,
def __iter__(self):
return self.axes.flat

def map_dataarray(self, plotfunc, x, y, max_xticks=4, max_yticks=4,
fontsize=_FONTSIZE, **kwargs):
def map_dataarray(self, plotfunc, x, y, **kwargs):
"""
Apply a plotting function to a 2d facet's subset of the data.
Expand All @@ -183,12 +182,6 @@ def map_dataarray(self, plotfunc, x, y, max_xticks=4, max_yticks=4,
plotting method such as `xray.plot.imshow`
x, y : string
Names of the coordinates to plot on x, y axes
max_xticks, max_yticks : int, optional
Maximum number of labeled ticks to plot on x, y axes
max_yticks : int
Maximum number of tick marks to place on y axis
fontsize : string or int
Font size as used by matplotlib text
kwargs :
additional keyword arguments to plotfunc
Expand Down Expand Up @@ -271,10 +264,6 @@ def map_dataarray(self, plotfunc, x, y, max_xticks=4, max_yticks=4,
cbar.set_label(self.darray.name, rotation=270,
verticalalignment='bottom')

# This happens here rather than __init__ since FacetGrid.map should
# use default ticks
self.set_ticks(max_xticks, max_yticks, fontsize)

return self

def set_titles(self, template="{coord} = {value}", maxchar=30,
Expand Down Expand Up @@ -330,9 +319,19 @@ def set_titles(self, template="{coord} = {value}", maxchar=30,
def set_ticks(self, max_xticks=_NTICKS, max_yticks=_NTICKS,
fontsize=_FONTSIZE):
"""
Sets tick behavior.
Set and control tick behavior
Parameters
----------
max_xticks, max_yticks : int, optional
Maximum number of labeled ticks to plot on x, y axes
fontsize : string or int
Font size as used by matplotlib text
Returns
-------
self : FacetGrid object
Refer to documentation in :meth:`FacetGrid.map_dataarray`
"""
from matplotlib.ticker import MaxNLocator

Expand Down
50 changes: 24 additions & 26 deletions xray/test/test_plot.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,3 @@
import unittest

import numpy as np
import pandas as pd

Expand Down Expand Up @@ -207,6 +205,7 @@ def test_plot_nans(self):

@requires_matplotlib
class TestDetermineCmapParams(TestCase):

def setUp(self):
self.data = np.linspace(0, 1, num=100)

Expand Down Expand Up @@ -263,6 +262,7 @@ def test_list_levels(self):

@requires_matplotlib
class TestDiscreteColorMap(TestCase):

def setUp(self):
x = np.arange(start=0, stop=10, step=2)
y = np.arange(start=9, stop=-7, step=-3)
Expand Down Expand Up @@ -347,8 +347,10 @@ class Common2dMixin:
These tests assume that a staticmethod for `self.plotfunc` exists.
Should have the same name as the method.
"""

def setUp(self):
self.darray = DataArray(easy_array((10, 15), start=-1), dims=['y', 'x'])
self.darray = DataArray(easy_array(
(10, 15), start=-1), dims=['y', 'x'])
self.plotmethod = getattr(self.darray.plot, self.plotfunc.__name__)

def test_label_names(self):
Expand Down Expand Up @@ -412,7 +414,7 @@ def test_seaborn_palette_as_cmap(self):
try:
import seaborn
cmap_name = self.plotmethod(
levels=2, cmap='husl').get_cmap().name
levels=2, cmap='husl').get_cmap().name
self.assertEqual('husl', cmap_name)
except ImportError:
pass
Expand Down Expand Up @@ -452,12 +454,6 @@ def test_bad_x_string_exception(self):
with self.assertRaisesRegexp(KeyError, r'y'):
self.plotmethod('z')

def test_default_title(self):
a = DataArray(easy_array((4, 3, 2, 1)), dims=['a', 'b', 'c', 'd'])
self.plotfunc(a.isel(c=1))
title = plt.gca().get_title()
self.assertEqual('c = 1, d = 0', title)

def test_default_title(self):
a = DataArray(easy_array((4, 3, 2)), dims=['a', 'b', 'c'])
a.coords['d'] = 10
Expand Down Expand Up @@ -535,21 +531,22 @@ def _color_as_tuple(c):
return tuple(c[:3])
artist = self.plotmethod(colors='k')
self.assertEqual(
_color_as_tuple(artist.cmap.colors[0]),
(0.0,0.0,0.0))
_color_as_tuple(artist.cmap.colors[0]),
(0.0, 0.0, 0.0))

artist = self.plotmethod(colors=['k','b'])
artist = self.plotmethod(colors=['k', 'b'])
self.assertEqual(
_color_as_tuple(artist.cmap.colors[1]),
(0.0,0.0,1.0))
_color_as_tuple(artist.cmap.colors[1]),
(0.0, 0.0, 1.0))

def test_cmap_and_color_both(self):
with self.assertRaises(ValueError):
with self.assertRaises(ValueError):
self.plotmethod(colors='k', cmap='RdBu')

def list_of_colors_in_cmap_deprecated(self):
with self.assertRaises(DeprecationError):
self.plotmethod(cmap=['k','b'])
with self.assertRaises(Exception):
self.plotmethod(cmap=['k', 'b'])


class TestPcolormesh(Common2dMixin, PlotTestCase):

Expand Down Expand Up @@ -656,11 +653,11 @@ def test_colorbar(self):
# There's only one colorbar
cbar = plt.gcf().findobj(mpl.collections.QuadMesh)
self.assertEqual(1, len(cbar))

def test_empty_cell(self):
g = xplt.FacetGrid(self.darray, col='z', col_wrap=2)
g.map_dataarray(xplt.imshow, 'x', 'y')

bottomright = g.axes[-1, -1]
self.assertFalse(bottomright.has_data())
self.assertFalse(bottomright.get_visible())
Expand All @@ -685,7 +682,7 @@ def test_float_index(self):
def test_nonunique_index_error(self):
self.darray.coords['z'] = [0.1, 0.2, 0.2]
with self.assertRaisesRegexp(ValueError, r'[Uu]nique'):
g = xplt.FacetGrid(self.darray, col='z')
xplt.FacetGrid(self.darray, col='z')

def test_robust(self):
z = np.zeros((20, 20, 2))
Expand Down Expand Up @@ -730,8 +727,9 @@ def test_figure_size(self):
def test_num_ticks(self):
nticks = 100
maxticks = nticks + 1
self.g.map_dataarray(xplt.imshow, 'x', 'y', max_xticks=nticks,
max_yticks=nticks)
self.g.map_dataarray(xplt.imshow, 'x', 'y')
self.g.set_ticks(max_xticks=nticks, max_yticks=nticks)

for ax in self.g:
xticks = len(ax.get_xticks())
yticks = len(ax.get_yticks())
Expand All @@ -745,14 +743,14 @@ def test_map(self):


class TestFacetGrid4d(PlotTestCase):

def setUp(self):
a = easy_array((10, 15, 3, 2))
darray = DataArray(a, dims=['y', 'x', 'col', 'row'])
darray.coords['col'] = np.array(['col' + str(x) for x in
darray.coords['col'].values])
darray.coords['col'].values])
darray.coords['row'] = np.array(['row' + str(x) for x in
darray.coords['row'].values])
darray.coords['row'].values])

self.darray = darray

Expand Down

0 comments on commit 29db010

Please sign in to comment.