Skip to content

Commit

Permalink
ENH: Have pivot and pivot_table take similar arguments
Browse files Browse the repository at this point in the history
  • Loading branch information
jsexauer committed Mar 14, 2014
1 parent 403f778 commit 6e7054a
Show file tree
Hide file tree
Showing 4 changed files with 152 additions and 77 deletions.
5 changes: 5 additions & 0 deletions doc/source/release.rst
Original file line number Diff line number Diff line change
Expand Up @@ -124,6 +124,11 @@ API Changes
DataFrame returned by ``GroupBy.apply`` (:issue:`6124`). This facilitates
``DataFrame.stack`` operations where the name of the column index is used as
the name of the inserted column containing the pivoted data.

- The :func:`pivot_table`/:meth:`DataFrame.pivot_table` and :func:`crosstab` functions
now take arguments ``index`` and ``columns`` instead of ``rows`` and ``cols``. A
``FutureWarning`` is issued when the old ``rows`` and ``cols`` arguments are used,
to alert that they will not be supported in a future release (:issue:`5505`).

Experimental Features
~~~~~~~~~~~~~~~~~~~~~
Expand Down
5 changes: 5 additions & 0 deletions doc/source/v0.14.0.txt
Original file line number Diff line number Diff line change
Expand Up @@ -165,6 +165,11 @@ These are out-of-bounds selections
# New output, 4-level MultiIndex
df_multi.set_index([df_multi.index, df_multi.index])

- The :func:`pivot_table`/:meth:`DataFrame.pivot_table` and :func:`crosstab` functions
now take arguments ``index`` and ``columns`` instead of ``rows`` and ``cols``. A
``FutureWarning`` is issued when the old ``rows`` and ``cols`` arguments are used,
to alert that they will not be supported in a future release (:issue:`5505`).


MultiIndexing Using Slicers
~~~~~~~~~~~~~~~~~~~~~~~~~~~
Expand Down
102 changes: 76 additions & 26 deletions pandas/tools/pivot.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
# pylint: disable=E1103

import warnings

from pandas import Series, DataFrame
from pandas.core.index import MultiIndex
from pandas.tools.merge import concat
Expand All @@ -10,8 +12,8 @@
import numpy as np


def pivot_table(data, values=None, rows=None, cols=None, aggfunc='mean',
fill_value=None, margins=False, dropna=True):
def pivot_table(data, values=None, index=None, columns=None, aggfunc='mean',
fill_value=None, margins=False, dropna=True, **kwarg):
"""
Create a spreadsheet-style pivot table as a DataFrame. The levels in the
pivot table will be stored in MultiIndex objects (hierarchical indexes) on
Expand All @@ -21,9 +23,9 @@ def pivot_table(data, values=None, rows=None, cols=None, aggfunc='mean',
----------
data : DataFrame
values : column to aggregate, optional
rows : list of column names or arrays to group on
index : list of column names or arrays to group on
Keys to group on the x-axis of the pivot table
cols : list of column names or arrays to group on
columns : list of column names or arrays to group on
Keys to group on the y-axis of the pivot table
aggfunc : function, default numpy.mean, or list of functions
If list of functions passed, the resulting pivot table will have
Expand All @@ -35,6 +37,8 @@ def pivot_table(data, values=None, rows=None, cols=None, aggfunc='mean',
Add all row / columns (e.g. for subtotal / grand totals)
dropna : boolean, default True
Do not include columns whose entries are all NaN
rows : kwarg only alias of index [deprecated]
cols : kwarg only alias of columns [deprecated]
Examples
--------
Expand All @@ -50,8 +54,8 @@ def pivot_table(data, values=None, rows=None, cols=None, aggfunc='mean',
7 bar two small 6
8 bar two large 7
>>> table = pivot_table(df, values='D', rows=['A', 'B'],
... cols=['C'], aggfunc=np.sum)
>>> table = pivot_table(df, values='D', index=['A', 'B'],
... columns=['C'], aggfunc=np.sum)
>>> table
small large
foo one 1 4
Expand All @@ -63,21 +67,43 @@ def pivot_table(data, values=None, rows=None, cols=None, aggfunc='mean',
-------
table : DataFrame
"""
rows = _convert_by(rows)
cols = _convert_by(cols)
# Parse old-style keyword arguments
rows = kwarg.pop('rows', None)
if rows is not None:
warnings.warn("rows is deprecated, use index", FutureWarning)
if index is None:
index = rows
else:
msg = "Can only specify either 'rows' or 'index'"
raise TypeError(msg)

cols = kwarg.pop('cols', None)
if cols is not None:
warnings.warn("cols is deprecated, use columns", FutureWarning)
if columns is None:
columns = cols
else:
msg = "Can only specify either 'cols' or 'columns'"
raise TypeError(msg)

if kwarg:
raise TypeError("Unexpected argument(s): %s" % kwarg.keys())

index = _convert_by(index)
columns = _convert_by(columns)

if isinstance(aggfunc, list):
pieces = []
keys = []
for func in aggfunc:
table = pivot_table(data, values=values, rows=rows, cols=cols,
table = pivot_table(data, values=values, index=index, columns=columns,
fill_value=fill_value, aggfunc=func,
margins=margins)
pieces.append(table)
keys.append(func.__name__)
return concat(pieces, keys=keys, axis=1)

keys = rows + cols
keys = index + columns

values_passed = values is not None
if values_passed:
Expand Down Expand Up @@ -106,7 +132,7 @@ def pivot_table(data, values=None, rows=None, cols=None, aggfunc='mean',
table = agged
if table.index.nlevels > 1:
to_unstack = [agged.index.names[i]
for i in range(len(rows), len(keys))]
for i in range(len(index), len(keys))]
table = agged.unstack(to_unstack)

if not dropna:
Expand All @@ -132,14 +158,14 @@ def pivot_table(data, values=None, rows=None, cols=None, aggfunc='mean',
table = table.fillna(value=fill_value, downcast='infer')

if margins:
table = _add_margins(table, data, values, rows=rows,
cols=cols, aggfunc=aggfunc)
table = _add_margins(table, data, values, rows=index,
cols=columns, aggfunc=aggfunc)

# discard the top level
if values_passed and not values_multi:
table = table[values[0]]

if len(rows) == 0 and len(cols) > 0:
if len(index) == 0 and len(columns) > 0:
table = table.T

return table
Expand Down Expand Up @@ -299,18 +325,18 @@ def _convert_by(by):
return by


def crosstab(rows, cols, values=None, rownames=None, colnames=None,
aggfunc=None, margins=False, dropna=True):
def crosstab(index, columns, values=None, rownames=None, colnames=None,
aggfunc=None, margins=False, dropna=True, **kwarg):
"""
Compute a simple cross-tabulation of two (or more) factors. By default
computes a frequency table of the factors unless an array of values and an
aggregation function are passed
Parameters
----------
rows : array-like, Series, or list of arrays/Series
index : array-like, Series, or list of arrays/Series
Values to group by in the rows
cols : array-like, Series, or list of arrays/Series
columns : array-like, Series, or list of arrays/Series
Values to group by in the columns
values : array-like, optional
Array of values to aggregate according to the factors
Expand All @@ -324,6 +350,8 @@ def crosstab(rows, cols, values=None, rownames=None, colnames=None,
Add row/column margins (subtotals)
dropna : boolean, default True
Do not include columns whose entries are all NaN
rows : kwarg only alias of index [deprecated]
cols : kwarg only alias of columns [deprecated]
Notes
-----
Expand Down Expand Up @@ -353,26 +381,48 @@ def crosstab(rows, cols, values=None, rownames=None, colnames=None,
-------
crosstab : DataFrame
"""
rows = com._maybe_make_list(rows)
cols = com._maybe_make_list(cols)
# Parse old-style keyword arguments
rows = kwarg.pop('rows', None)
if rows is not None:
warnings.warn("rows is deprecated, use index", FutureWarning)
if index is None:
index = rows
else:
msg = "Can only specify either 'rows' or 'index'"
raise TypeError(msg)

cols = kwarg.pop('cols', None)
if cols is not None:
warnings.warn("cols is deprecated, use columns", FutureWarning)
if columns is None:
columns = cols
else:
msg = "Can only specify either 'cols' or 'columns'"
raise TypeError(msg)

if kwarg:
raise TypeError("Unexpected argument(s): %s" % kwarg.keys())

index = com._maybe_make_list(index)
columns = com._maybe_make_list(columns)

rownames = _get_names(rows, rownames, prefix='row')
colnames = _get_names(cols, colnames, prefix='col')
rownames = _get_names(index, rownames, prefix='row')
colnames = _get_names(columns, colnames, prefix='col')

data = {}
data.update(zip(rownames, rows))
data.update(zip(colnames, cols))
data.update(zip(rownames, index))
data.update(zip(colnames, columns))

if values is None:
df = DataFrame(data)
df['__dummy__'] = 0
table = df.pivot_table('__dummy__', rows=rownames, cols=colnames,
table = df.pivot_table('__dummy__', index=rownames, columns=colnames,
aggfunc=len, margins=margins, dropna=dropna)
return table.fillna(0).astype(np.int64)
else:
data['__dummy__'] = values
df = DataFrame(data)
table = df.pivot_table('__dummy__', rows=rownames, cols=colnames,
table = df.pivot_table('__dummy__', index=rownames, columns=colnames,
aggfunc=aggfunc, margins=margins, dropna=dropna)
return table

Expand Down
Loading

0 comments on commit 6e7054a

Please sign in to comment.