Skip to content

Commit

Permalink
attributes all set in bottom of __init__
Browse files Browse the repository at this point in the history
  • Loading branch information
clarkfitzg committed Sep 4, 2015
1 parent f6dc6a2 commit 15babf1
Showing 1 changed file with 44 additions and 35 deletions.
79 changes: 44 additions & 35 deletions xray/plot/facetgrid.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,72 +97,81 @@ def __init__(self, data, col=None, row=None, col_wrap=None,
raise ValueError('Coordinates used for faceting cannot '
'contain repeated (nonunique) values.')

# self._single_group is the grouping variable, if there is exactly one
# single_group is the grouping variable, if there is exactly one
if col and row:
self._single_group = False
self._nrow = len(data[row])
self._ncol = len(data[col])
self.nfacet = self._nrow * self._ncol
single_group = False
nrow = len(data[row])
ncol = len(data[col])
nfacet = nrow * ncol
if col_wrap is not None:
warnings.warn('Ignoring col_wrap since both col and row '
'were passed')
elif row and not col:
self._single_group = row
single_group = row
elif not row and col:
self._single_group = col
single_group = col
else:
raise ValueError(
'Pass a coordinate name as an argument for row or col')

# Compute grid shape
if self._single_group:
self.nfacet = len(data[self._single_group])
if single_group:
nfacet = len(data[single_group])
if col:
# idea - could add heuristic for nice shapes like 3x4
self._ncol = self.nfacet
ncol = nfacet
if row:
self._ncol = 1
ncol = 1
if col_wrap is not None:
# Overrides previous settings
self._ncol = col_wrap
self._nrow = int(np.ceil(self.nfacet / self._ncol))
ncol = col_wrap
nrow = int(np.ceil(nfacet / ncol))

# Calculate the base figure size with extra horizontal space for a
# colorbar
self._cbar_space = 1
figsize = (self._ncol * size * aspect +
self._cbar_space, self._nrow * size)
cbar_space = 1
figsize = (ncol * size * aspect +
cbar_space, nrow * size)

self.fig, self.axes = plt.subplots(self._nrow, self._ncol,
sharex=True, sharey=True,
squeeze=False, figsize=figsize)
fig, axes = plt.subplots(nrow, ncol,
sharex=True, sharey=True,
squeeze=False, figsize=figsize)

# Set up the lists of names for the row and column facet variables
col_names = list(data[col].values) if col else []
row_names = list(data[row].values) if row else []

if self._single_group:
full = [{self._single_group: x} for x in
data[self._single_group].values]
empty = [None for x in range(self._nrow * self._ncol - len(full))]
if single_group:
full = [{single_group: x} for x in
data[single_group].values]
empty = [None for x in range(nrow * ncol - len(full))]
name_dicts = full + empty
else:
rowcols = itertools.product(row_names, col_names)
name_dicts = [{row: r, col: c} for r, c in rowcols]

self.name_dicts = np.array(name_dicts).reshape(self._nrow, self._ncol)
name_dicts = np.array(name_dicts).reshape(nrow, ncol)

# Set up the class attributes
# ---------------------------

# First the public API
self.data = data
self.name_dicts = name_dicts
self.fig = fig
self.axes = axes
self.row_names = row_names
self.col_names = col_names
self.data = data
self.row = row
self.col = col
self.col_wrap = col_wrap

# Next the private variables
self._single_group = single_group
self._nrow = nrow
self._row_var = row
self._ncol = ncol
self._col_var = col
self._col_wrap = col_wrap
self._x_var = None
self._y_var = None

self.set_titles()

Expand Down Expand Up @@ -233,16 +242,13 @@ def map_dataarray(self, plotfunc, x, y, **kwargs):
subset = self.data.loc[d]
mappable = plotfunc(subset, x, y, ax=ax, **defaults)

self.x = x
self.y = y

# Left side labels
for ax in self.axes[:, 0]:
ax.set_ylabel(self.y)
ax.set_ylabel(y)

# Bottom labels
for ax in self.axes[-1, :]:
ax.set_xlabel(self.x)
ax.set_xlabel(x)

self.fig.tight_layout()

Expand All @@ -261,6 +267,9 @@ def map_dataarray(self, plotfunc, x, y, **kwargs):
cbar.set_label(self.data.name, rotation=270,
verticalalignment='bottom')

self._x_var = x
self._y_var = y

return self

def set_titles(self, template="{coord} = {value}", maxchar=30,
Expand Down Expand Up @@ -300,14 +309,14 @@ def set_titles(self, template="{coord} = {value}", maxchar=30,
else:
# The row titles on the right edge of the grid
for ax, row_name in zip(self.axes[:, -1], self.row_names):
title = nicetitle(coord=self.row, value=row_name,
title = nicetitle(coord=self._row_var, value=row_name,
maxchar=maxchar)
ax.annotate(title, xy=(1.02, .5), xycoords="axes fraction",
rotation=270, ha="left", va="center", **kwargs)

# The column titles on the top row
for ax, col_name in zip(self.axes[0, :], self.col_names):
title = nicetitle(coord=self.col, value=col_name,
title = nicetitle(coord=self._col_var, value=col_name,
maxchar=maxchar)
ax.set_title(title, **kwargs)

Expand Down

0 comments on commit 15babf1

Please sign in to comment.