[SPARK-9716][ML] BinaryClassificationEvaluator should accept Double prediction column

This PR aims to allow the prediction column of `BinaryClassificationEvaluator` to be of double type.
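
For context, a minimal usage sketch of the new behavior (not part of the commit; assumes a SQLContext named sqlContext, as in the test suite below):

import org.apache.spark.ml.evaluation.BinaryClassificationEvaluator

// Hypothetical data: a label column and a plain double score
// (probability of label 1) instead of a length-2 vector.
val df = sqlContext.createDataFrame(Seq(
  (0.0, 0.1),
  (1.0, 0.9),
  (0.0, 0.2)
)).toDF("label", "rawPrediction")

val evaluator = new BinaryClassificationEvaluator()
  .setRawPredictionCol("rawPrediction")
  .setLabelCol("label")
  .setMetricName("areaUnderROC")

// Before this change, a non-vector rawPrediction column failed schema validation.
val auc = evaluator.evaluate(df)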

Author: BenFradet <[email protected]>

Closes apache#10472 from BenFradet/SPARK-9716.
BenFradet authored and jkbradley committed Jan 19, 2016
1 parent 43f1d59 commit f6f7ca9
Showing 4 changed files with 58 additions and 5 deletions.
mllib/src/main/scala/org/apache/spark/ml/evaluation/BinaryClassificationEvaluator.scala
@@ -29,6 +29,8 @@ import org.apache.spark.sql.types.DoubleType
 /**
  * :: Experimental ::
  * Evaluator for binary classification, which expects two input columns: rawPrediction and label.
+ * The rawPrediction column can be of type double (binary 0/1 prediction, or probability of label 1)
+ * or of type vector (length-2 vector of raw predictions, scores, or label probabilities).
  */
 @Since("1.2.0")
 @Experimental
@@ -78,13 +80,14 @@ class BinaryClassificationEvaluator @Since("1.4.0") (@Since("1.4.0") override va
   @Since("1.2.0")
   override def evaluate(dataset: DataFrame): Double = {
     val schema = dataset.schema
-    SchemaUtils.checkColumnType(schema, $(rawPredictionCol), new VectorUDT)
+    SchemaUtils.checkColumnTypes(schema, $(rawPredictionCol), Seq(DoubleType, new VectorUDT))
     SchemaUtils.checkColumnType(schema, $(labelCol), DoubleType)

     // TODO: When dataset metadata has been implemented, check rawPredictionCol vector length = 2.
     val scoreAndLabels = dataset.select($(rawPredictionCol), $(labelCol))
-      .map { case Row(rawPrediction: Vector, label: Double) =>
-        (rawPrediction(1), label)
+      .map {
+        case Row(rawPrediction: Vector, label: Double) => (rawPrediction(1), label)
+        case Row(rawPrediction: Double, label: Double) => (rawPrediction, label)
       }
     val metrics = new BinaryClassificationMetrics(scoreAndLabels)
     val metric = $(metricName) match {
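The two match arms yield the same (score, label) shape for either column type. A standalone sketch with hypothetical values (not from the commit) of what each arm extracts:

import org.apache.spark.mllib.linalg.{Vector, Vectors}
import org.apache.spark.sql.Row

// Vector column: element 1 is taken as the positive-class score.
val fromVector = Row(Vectors.dense(12.0, 2.5), 0.0) match {
  case Row(rawPrediction: Vector, label: Double) => (rawPrediction(1), label)
}
// fromVector == (2.5, 0.0)

// Double column: the value itself is the score.
val fromDouble = Row(0.9, 1.0) match {
  case Row(rawPrediction: Double, label: Double) => (rawPrediction, label)
}
// fromDouble == (0.9, 1.0)
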
17 changes: 17 additions & 0 deletions mllib/src/main/scala/org/apache/spark/ml/util/SchemaUtils.scala
@@ -43,6 +43,23 @@ private[spark] object SchemaUtils {
       s"Column $colName must be of type $dataType but was actually $actualDataType.$message")
   }

+  /**
+   * Check whether the given schema contains a column of one of the required data types.
+   * @param colName  column name
+   * @param dataTypes  required column data types
+   */
+  def checkColumnTypes(
+      schema: StructType,
+      colName: String,
+      dataTypes: Seq[DataType],
+      msg: String = ""): Unit = {
+    val actualDataType = schema(colName).dataType
+    val message = if (msg != null && msg.trim.length > 0) " " + msg else ""
+    require(dataTypes.exists(actualDataType.equals),
+      s"Column $colName must be of type equal to one of the following types: " +
+        s"${dataTypes.mkString("[", ", ", "]")} but was actually of type $actualDataType.$message")
+  }
+
   /**
    * Appends a new column to the input schema. This fails if the given output column already exists.
    * @param schema input schema
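A quick sketch of how the new helper behaves on a mismatched column (hypothetical schema; SchemaUtils is private[spark], so this only compiles inside Spark's own source tree):

import org.apache.spark.ml.util.SchemaUtils
import org.apache.spark.mllib.linalg.VectorUDT
import org.apache.spark.sql.types.{DoubleType, StringType, StructField, StructType}

val schema = StructType(Seq(StructField("rawPrediction", StringType)))

// require(...) fails, throwing IllegalArgumentException:
//   Column rawPrediction must be of type equal to one of the following types:
//   [DoubleType, ...] but was actually of type StringType.
SchemaUtils.checkColumnTypes(schema, "rawPrediction", Seq(DoubleType, new VectorUDT))
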
mllib/src/test/scala/org/apache/spark/ml/evaluation/BinaryClassificationEvaluatorSuite.scala
@@ -20,6 +20,7 @@ package org.apache.spark.ml.evaluation
 import org.apache.spark.SparkFunSuite
 import org.apache.spark.ml.param.ParamsSuite
 import org.apache.spark.ml.util.DefaultReadWriteTest
+import org.apache.spark.mllib.linalg.Vectors
 import org.apache.spark.mllib.util.MLlibTestSparkContext

 class BinaryClassificationEvaluatorSuite
@@ -36,4 +37,35 @@ class BinaryClassificationEvaluatorSuite
       .setMetricName("areaUnderPR")
     testDefaultReadWrite(evaluator)
   }
+
+  test("should accept both vector and double raw prediction col") {
+    val evaluator = new BinaryClassificationEvaluator()
+      .setMetricName("areaUnderPR")
+
+    val vectorDF = sqlContext.createDataFrame(Seq(
+      (0d, Vectors.dense(12, 2.5)),
+      (1d, Vectors.dense(1, 3)),
+      (0d, Vectors.dense(10, 2))
+    )).toDF("label", "rawPrediction")
+    assert(evaluator.evaluate(vectorDF) === 1.0)
+
+    val doubleDF = sqlContext.createDataFrame(Seq(
+      (0d, 0d),
+      (1d, 1d),
+      (0d, 0d)
+    )).toDF("label", "rawPrediction")
+    assert(evaluator.evaluate(doubleDF) === 1.0)
+
+    val stringDF = sqlContext.createDataFrame(Seq(
+      (0d, "0d"),
+      (1d, "1d"),
+      (0d, "0d")
+    )).toDF("label", "rawPrediction")
+    val thrown = intercept[IllegalArgumentException] {
+      evaluator.evaluate(stringDF)
+    }
+    assert(thrown.getMessage.replace("\n", "") contains "Column rawPrediction must be of type " +
+      "equal to one of the following types: [DoubleType, ")
+    assert(thrown.getMessage.replace("\n", "") contains "but was actually of type StringType.")
+  }
 }
5 changes: 3 additions & 2 deletions python/pyspark/ml/evaluation.py
@@ -106,8 +106,9 @@ def isLargerBetter(self):
 @inherit_doc
 class BinaryClassificationEvaluator(JavaEvaluator, HasLabelCol, HasRawPredictionCol):
     """
-    Evaluator for binary classification, which expects two input
-    columns: rawPrediction and label.
+    Evaluator for binary classification, which expects two input columns: rawPrediction and label.
+    The rawPrediction column can be of type double (binary 0/1 prediction, or probability of label
+    1) or of type vector (length-2 vector of raw predictions, scores, or label probabilities).

     >>> from pyspark.mllib.linalg import Vectors
     >>> scoreAndLabels = map(lambda x: (Vectors.dense([1.0 - x[0], x[0]]), x[1]),
