Skip to content

Commit

Permalink
[SPARK-18401][SPARKR][ML] SparkR random forest should support output …
Browse files Browse the repository at this point in the history
…original label.

## What changes were proposed in this pull request?
SparkR ```spark.randomForest``` classification prediction should output the original label rather than the indexed label. This issue is very similar to [SPARK-18291](https://issues.apache.org/jira/browse/SPARK-18291).

## How was this patch tested?
Added unit tests.

Author: Yanbo Liang <[email protected]>

Closes apache#15842 from yanboliang/spark-18401.
  • Loading branch information
yanboliang committed Nov 11, 2016
1 parent a335634 commit 5ddf694
Show file tree
Hide file tree
Showing 2 changed files with 48 additions and 4 deletions.
24 changes: 24 additions & 0 deletions R/pkg/inst/tests/testthat/test_mllib.R
Original file line number Diff line number Diff line change
Expand Up @@ -935,6 +935,10 @@ test_that("spark.randomForest Classification", {
expect_equal(stats$numTrees, 20)
expect_error(capture.output(stats), NA)
expect_true(length(capture.output(stats)) > 6)
# Test string prediction values
predictions <- collect(predict(model, data))$prediction
expect_equal(length(grep("setosa", predictions)), 50)
expect_equal(length(grep("versicolor", predictions)), 50)

modelPath <- tempfile(pattern = "spark-randomForestClassification", fileext = ".tmp")
write.ml(model, modelPath)
Expand All @@ -947,6 +951,26 @@ test_that("spark.randomForest Classification", {
expect_equal(stats$numClasses, stats2$numClasses)

unlink(modelPath)

# Test numeric response variable
# Map a species name to a numeric class index:
# setosa -> 0.0, versicolor -> 1.0, virginica -> 2.0.
# Used below to build a numeric response column (NumericSpecies).
labelToIndex <- function(species) {
# Convert to character so switch() dispatches on the species name.
# NOTE(review): there is no default branch, so an unlisted name yields NULL.
switch(as.character(species),
setosa = 0.0,
versicolor = 1.0,
virginica = 2.0
)
}
iris$NumericSpecies <- lapply(iris$Species, labelToIndex)
data <- suppressWarnings(createDataFrame(iris[-5]))
model <- spark.randomForest(data, NumericSpecies ~ Petal_Length + Petal_Width, "classification",
maxDepth = 5, maxBins = 16)
stats <- summary(model)
expect_equal(stats$numFeatures, 2)
expect_equal(stats$numTrees, 20)
# Test numeric prediction values
predictions <- collect(predict(model, data))$prediction
expect_equal(length(grep("1.0", predictions)), 50)
expect_equal(length(grep("2.0", predictions)), 50)
})

test_that("spark.gbt", {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,9 +23,9 @@ import org.json4s.JsonDSL._
import org.json4s.jackson.JsonMethods._

import org.apache.spark.ml.{Pipeline, PipelineModel}
import org.apache.spark.ml.attribute.AttributeGroup
import org.apache.spark.ml.attribute.{Attribute, AttributeGroup, NominalAttribute}
import org.apache.spark.ml.classification.{RandomForestClassificationModel, RandomForestClassifier}
import org.apache.spark.ml.feature.RFormula
import org.apache.spark.ml.feature.{IndexToString, RFormula}
import org.apache.spark.ml.linalg.Vector
import org.apache.spark.ml.util._
import org.apache.spark.sql.{DataFrame, Dataset}
Expand All @@ -35,6 +35,8 @@ private[r] class RandomForestClassifierWrapper private (
val formula: String,
val features: Array[String]) extends MLWritable {

import RandomForestClassifierWrapper._

private val rfcModel: RandomForestClassificationModel =
pipeline.stages(1).asInstanceOf[RandomForestClassificationModel]

Expand All @@ -46,14 +48,20 @@ private[r] class RandomForestClassifierWrapper private (
def summary: String = rfcModel.toDebugString

def transform(dataset: Dataset[_]): DataFrame = {
pipeline.transform(dataset).drop(rfcModel.getFeaturesCol)
pipeline.transform(dataset)
.drop(PREDICTED_LABEL_INDEX_COL)
.drop(rfcModel.getFeaturesCol)
}

override def write: MLWriter = new
RandomForestClassifierWrapper.RandomForestClassifierWrapperWriter(this)
}

private[r] object RandomForestClassifierWrapper extends MLReadable[RandomForestClassifierWrapper] {

val PREDICTED_LABEL_INDEX_COL = "pred_label_idx"
val PREDICTED_LABEL_COL = "prediction"

def fit( // scalastyle:ignore
data: DataFrame,
formula: String,
Expand All @@ -73,6 +81,7 @@ private[r] object RandomForestClassifierWrapper extends MLReadable[RandomForestC

val rFormula = new RFormula()
.setFormula(formula)
.setForceIndexLabel(true)
RWrapperUtils.checkDataColumns(rFormula, data)
val rFormulaModel = rFormula.fit(data)

Expand All @@ -82,6 +91,11 @@ private[r] object RandomForestClassifierWrapper extends MLReadable[RandomForestC
.attributes.get
val features = featureAttrs.map(_.name.get)

// get label names from output schema
val labelAttr = Attribute.fromStructField(schema(rFormulaModel.getLabelCol))
.asInstanceOf[NominalAttribute]
val labels = labelAttr.values.get

// assemble and fit the pipeline
val rfc = new RandomForestClassifier()
.setMaxDepth(maxDepth)
Expand All @@ -97,10 +111,16 @@ private[r] object RandomForestClassifierWrapper extends MLReadable[RandomForestC
.setCacheNodeIds(cacheNodeIds)
.setProbabilityCol(probabilityCol)
.setFeaturesCol(rFormula.getFeaturesCol)
.setPredictionCol(PREDICTED_LABEL_INDEX_COL)
if (seed != null && seed.length > 0) rfc.setSeed(seed.toLong)

val idxToStr = new IndexToString()
.setInputCol(PREDICTED_LABEL_INDEX_COL)
.setOutputCol(PREDICTED_LABEL_COL)
.setLabels(labels)

val pipeline = new Pipeline()
.setStages(Array(rFormulaModel, rfc))
.setStages(Array(rFormulaModel, rfc, idxToStr))
.fit(data)

new RandomForestClassifierWrapper(pipeline, formula, features)
Expand Down

0 comments on commit 5ddf694

Please sign in to comment.