Skip to content

Commit

Permalink
[SPARK-21741][ML][PYSPARK] Python API for DataFrame-based multivariat…
Browse files Browse the repository at this point in the history
…e summarizer

## What changes were proposed in this pull request?

Python API for DataFrame-based multivariate summarizer.

## How was this patch tested?

doctest added.

Author: WeichenXu <[email protected]>

Closes apache#20695 from WeichenXu123/py_summarizer.
  • Loading branch information
WeichenXu123 authored and jkbradley committed Apr 17, 2018
1 parent f39e82c commit 1ca3c50
Showing 1 changed file with 192 additions and 1 deletion.
193 changes: 192 additions & 1 deletion python/pyspark/ml/stat.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,9 @@

from pyspark import since, SparkContext
from pyspark.ml.common import _java2py, _py2java
from pyspark.ml.wrapper import _jvm
from pyspark.ml.wrapper import JavaWrapper, _jvm
from pyspark.sql.column import Column, _to_seq
from pyspark.sql.functions import lit


class ChiSquareTest(object):
Expand Down Expand Up @@ -195,6 +197,195 @@ def test(dataset, sampleCol, distName, *params):
_jvm().PythonUtils.toSeq(params)))


class Summarizer(object):
"""
.. note:: Experimental
Tools for vectorized statistics on MLlib Vectors.
The methods in this package provide various statistics for Vectors contained inside DataFrames.
This class lets users pick the statistics they would like to extract for a given column.
>>> from pyspark.ml.stat import Summarizer
>>> from pyspark.sql import Row
>>> from pyspark.ml.linalg import Vectors
>>> summarizer = Summarizer.metrics("mean", "count")
>>> df = sc.parallelize([Row(weight=1.0, features=Vectors.dense(1.0, 1.0, 1.0)),
... Row(weight=0.0, features=Vectors.dense(1.0, 2.0, 3.0))]).toDF()
>>> df.select(summarizer.summary(df.features, df.weight)).show(truncate=False)
+-----------------------------------+
|aggregate_metrics(features, weight)|
+-----------------------------------+
|[[1.0,1.0,1.0], 1] |
+-----------------------------------+
<BLANKLINE>
>>> df.select(summarizer.summary(df.features)).show(truncate=False)
+--------------------------------+
|aggregate_metrics(features, 1.0)|
+--------------------------------+
|[[1.0,1.5,2.0], 2] |
+--------------------------------+
<BLANKLINE>
>>> df.select(Summarizer.mean(df.features, df.weight)).show(truncate=False)
+--------------+
|mean(features)|
+--------------+
|[1.0,1.0,1.0] |
+--------------+
<BLANKLINE>
>>> df.select(Summarizer.mean(df.features)).show(truncate=False)
+--------------+
|mean(features)|
+--------------+
|[1.0,1.5,2.0] |
+--------------+
<BLANKLINE>
.. versionadded:: 2.4.0
"""
@staticmethod
@since("2.4.0")
def mean(col, weightCol=None):
"""
return a column of mean summary
"""
return Summarizer._get_single_metric(col, weightCol, "mean")

@staticmethod
@since("2.4.0")
def variance(col, weightCol=None):
"""
return a column of variance summary
"""
return Summarizer._get_single_metric(col, weightCol, "variance")

@staticmethod
@since("2.4.0")
def count(col, weightCol=None):
"""
return a column of count summary
"""
return Summarizer._get_single_metric(col, weightCol, "count")

@staticmethod
@since("2.4.0")
def numNonZeros(col, weightCol=None):
"""
return a column of numNonZero summary
"""
return Summarizer._get_single_metric(col, weightCol, "numNonZeros")

@staticmethod
@since("2.4.0")
def max(col, weightCol=None):
"""
return a column of max summary
"""
return Summarizer._get_single_metric(col, weightCol, "max")

@staticmethod
@since("2.4.0")
def min(col, weightCol=None):
"""
return a column of min summary
"""
return Summarizer._get_single_metric(col, weightCol, "min")

@staticmethod
@since("2.4.0")
def normL1(col, weightCol=None):
"""
return a column of normL1 summary
"""
return Summarizer._get_single_metric(col, weightCol, "normL1")

@staticmethod
@since("2.4.0")
def normL2(col, weightCol=None):
"""
return a column of normL2 summary
"""
return Summarizer._get_single_metric(col, weightCol, "normL2")

@staticmethod
def _check_param(featuresCol, weightCol):
if weightCol is None:
weightCol = lit(1.0)
if not isinstance(featuresCol, Column) or not isinstance(weightCol, Column):
raise TypeError("featureCol and weightCol should be a Column")
return featuresCol, weightCol

@staticmethod
def _get_single_metric(col, weightCol, metric):
col, weightCol = Summarizer._check_param(col, weightCol)
return Column(JavaWrapper._new_java_obj("org.apache.spark.ml.stat.Summarizer." + metric,
col._jc, weightCol._jc))

@staticmethod
@since("2.4.0")
def metrics(*metrics):
"""
Given a list of metrics, provides a builder that it turns computes metrics from a column.
See the documentation of [[Summarizer]] for an example.
The following metrics are accepted (case sensitive):
- mean: a vector that contains the coefficient-wise mean.
- variance: a vector tha contains the coefficient-wise variance.
- count: the count of all vectors seen.
- numNonzeros: a vector with the number of non-zeros for each coefficients
- max: the maximum for each coefficient.
- min: the minimum for each coefficient.
- normL2: the Euclidian norm for each coefficient.
- normL1: the L1 norm of each coefficient (sum of the absolute values).
:param metrics:
metrics that can be provided.
:return:
an object of :py:class:`pyspark.ml.stat.SummaryBuilder`
Note: Currently, the performance of this interface is about 2x~3x slower then using the RDD
interface.
"""
sc = SparkContext._active_spark_context
js = JavaWrapper._new_java_obj("org.apache.spark.ml.stat.Summarizer.metrics",
_to_seq(sc, metrics))
return SummaryBuilder(js)


class SummaryBuilder(JavaWrapper):
"""
.. note:: Experimental
A builder object that provides summary statistics about a given column.
Users should not directly create such builders, but instead use one of the methods in
:py:class:`pyspark.ml.stat.Summarizer`
.. versionadded:: 2.4.0
"""
def __init__(self, jSummaryBuilder):
super(SummaryBuilder, self).__init__(jSummaryBuilder)

@since("2.4.0")
def summary(self, featuresCol, weightCol=None):
"""
Returns an aggregate object that contains the summary of the column with the requested
metrics.
:param featuresCol:
a column that contains features Vector object.
:param weightCol:
a column that contains weight value. Default weight is 1.0.
:return:
an aggregate column that contains the statistics. The exact content of this
structure is determined during the creation of the builder.
"""
featuresCol, weightCol = Summarizer._check_param(featuresCol, weightCol)
return Column(self._java_obj.summary(featuresCol._jc, weightCol._jc))


if __name__ == "__main__":
import doctest
import pyspark.ml.stat
Expand Down

0 comments on commit 1ca3c50

Please sign in to comment.