Skip to content

Commit

Permalink
Changes to plotting scatter matrix diagonals
Browse files Browse the repository at this point in the history
  • Loading branch information
orbitfold authored and wesm committed May 19, 2012
1 parent 3e496ed commit 3d5990f
Showing 1 changed file with 68 additions and 54 deletions.
122 changes: 68 additions & 54 deletions pandas/tools/plotting.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from itertools import izip

import numpy as np
from scipy import stats

from pandas.util.decorators import cache_readonly
import pandas.core.common as com
Expand All @@ -12,7 +13,7 @@
from pandas.tseries.offsets import DateOffset

def scatter_matrix(frame, alpha=0.5, figsize=None, ax=None, grid=False,
**kwds):
diagonal='hist', **kwds):
"""
Draw a matrix of scatter plots.
Expand All @@ -36,64 +37,77 @@ def scatter_matrix(frame, alpha=0.5, figsize=None, ax=None, grid=False,

for i, a in zip(range(n), df.columns):
for j, b in zip(range(n), df.columns):
axes[i, j].scatter(df[b], df[a], alpha=alpha, **kwds)
axes[i, j].set_xlabel('')
axes[i, j].set_ylabel('')
axes[i, j].set_xticklabels([])
axes[i, j].set_yticklabels([])
ticks = df.index

is_datetype = ticks.inferred_type in ('datetime', 'date',
if i == j:
# Deal with the diagonal by drawing a histogram there.
if diagonal == 'hist':
axes[i, j].hist(df[a])
elif diagonal == 'kde':
y = df[a]
gkde = stats.gaussian_kde(y)
ind = np.linspace(min(y), max(y), 1000)
axes[i, j].plot(ind, gkde.evaluate(ind), **kwds)
axes[i, j].yaxis.set_visible(False)
axes[i, j].xaxis.set_visible(False)
if i == 0 and j == 0:
axes[i, j].yaxis.set_ticks_position('left')
axes[i, j].yaxis.set_label_position('left')
axes[i, j].yaxis.set_visible(True)
if i == n - 1 and j == n - 1:
axes[i, j].yaxis.set_ticks_position('right')
axes[i, j].yaxis.set_label_position('right')
axes[i, j].yaxis.set_visible(True)
else:
axes[i, j].scatter(df[b], df[a], alpha=alpha, **kwds)
axes[i, j].set_xlabel('')
axes[i, j].set_ylabel('')
axes[i, j].set_xticklabels([])
axes[i, j].set_yticklabels([])
ticks = df.index

is_datetype = ticks.inferred_type in ('datetime', 'date',
'datetime64')

if ticks.is_numeric() or is_datetype:
"""
Matplotlib supports numeric values or datetime objects as
xaxis values. Taking LBYL approach here, by the time
matplotlib raises exception when using non numeric/datetime
values for xaxis, several actions are already taken by plt.
"""
ticks = ticks._mpl_repr()

# setup labels
if i == 0 and j % 2 == 1:
axes[i, j].set_xlabel(b, visible=True)
#axes[i, j].xaxis.set_visible(True)
axes[i, j].set_xlabel(b)
axes[i, j].set_xticklabels(ticks)
axes[i, j].xaxis.set_ticks_position('top')
axes[i, j].xaxis.set_label_position('top')
if i == n - 1 and j % 2 == 0:
axes[i, j].set_xlabel(b, visible=True)
#axes[i, j].xaxis.set_visible(True)
axes[i, j].set_xlabel(b)
axes[i, j].set_xticklabels(ticks)
axes[i, j].xaxis.set_ticks_position('bottom')
axes[i, j].xaxis.set_label_position('bottom')
if j == 0 and i % 2 == 0:
axes[i, j].set_ylabel(a, visible=True)
#axes[i, j].yaxis.set_visible(True)
axes[i, j].set_ylabel(a)
axes[i, j].set_yticklabels(ticks)
axes[i, j].yaxis.set_ticks_position('left')
axes[i, j].yaxis.set_label_position('left')
if j == n - 1 and i % 2 == 1:
axes[i, j].set_ylabel(a, visible=True)
#axes[i, j].yaxis.set_visible(True)
axes[i, j].set_ylabel(a)
axes[i, j].set_yticklabels(ticks)
axes[i, j].yaxis.set_ticks_position('right')
axes[i, j].yaxis.set_label_position('right')
if ticks.is_numeric() or is_datetype:
"""
Matplotlib supports numeric values or datetime objects as
xaxis values. Taking LBYL approach here, by the time
matplotlib raises exception when using non numeric/datetime
values for xaxis, several actions are already taken by plt.
"""
ticks = ticks._mpl_repr()

# setup labels
if i == 0 and j % 2 == 1:
axes[i, j].set_xlabel(b, visible=True)
#axes[i, j].xaxis.set_visible(True)
axes[i, j].set_xlabel(b)
axes[i, j].set_xticklabels(ticks)
axes[i, j].xaxis.set_ticks_position('top')
axes[i, j].xaxis.set_label_position('top')
if i == n - 1 and j % 2 == 0:
axes[i, j].set_xlabel(b, visible=True)
#axes[i, j].xaxis.set_visible(True)
axes[i, j].set_xlabel(b)
axes[i, j].set_xticklabels(ticks)
axes[i, j].xaxis.set_ticks_position('bottom')
axes[i, j].xaxis.set_label_position('bottom')
if j == 0 and i % 2 == 0:
axes[i, j].set_ylabel(a, visible=True)
#axes[i, j].yaxis.set_visible(True)
axes[i, j].set_ylabel(a)
axes[i, j].set_yticklabels(ticks)
axes[i, j].yaxis.set_ticks_position('left')
axes[i, j].yaxis.set_label_position('left')
if j == n - 1 and i % 2 == 1:
axes[i, j].set_ylabel(a, visible=True)
#axes[i, j].yaxis.set_visible(True)
axes[i, j].set_ylabel(a)
axes[i, j].set_yticklabels(ticks)
axes[i, j].yaxis.set_ticks_position('right')
axes[i, j].yaxis.set_label_position('right')

axes[i, j].grid(b=grid)

# ensure {x,y}lim off diagonal are the same as diagonal
for i in range(n):
for j in range(n):
if i != j:
axes[i, j].set_xlim(axes[j, j].get_xlim())
axes[i, j].set_ylim(axes[i, i].get_ylim())

return axes

def _gca():
Expand Down

0 comments on commit 3d5990f

Please sign in to comment.