Skip to content

Commit

Permalink
[SPARK-18481][ML] ML 2.1 QA: Remove deprecated methods for ML
Browse files Browse the repository at this point in the history
## What changes were proposed in this pull request?
Remove deprecated methods for ML.

## How was this patch tested?
Existing tests.

Author: Yanbo Liang <[email protected]>

Closes apache#15913 from yanboliang/spark-18481.
  • Loading branch information
yanboliang committed Nov 26, 2016
1 parent a88329d commit c4a7eef
Show file tree
Hide file tree
Showing 16 changed files with 144 additions and 107 deletions.
4 changes: 4 additions & 0 deletions mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,10 @@ abstract class PipelineStage extends Params with Logging {
*
* Check transform validity and derive the output schema from the input schema.
*
* We check validity for interactions between parameters during `transformSchema` and
* raise an exception if any parameter value is invalid. Parameter value checks which
* do not depend on other parameters are handled by `Param.validate()`.
*
* Typical implementation should first conduct verification on schema change and parameter
* validity, including complex parameter interaction checks.
*/
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -203,6 +203,12 @@ class GBTClassificationModel private[ml](
@Since("1.4.0")
override def trees: Array[DecisionTreeRegressionModel] = _trees

/**
* Number of trees in ensemble
*/
@Since("2.0.0")
val getNumTrees: Int = trees.length

@Since("1.4.0")
override def treeWeights: Array[Double] = _treeWeights

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ import org.apache.spark.mllib.util.MLUtils
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.{DataFrame, Dataset, Row}
import org.apache.spark.sql.functions.{col, lit}
import org.apache.spark.sql.types.DoubleType
import org.apache.spark.sql.types.{DataType, DoubleType, StructType}
import org.apache.spark.storage.StorageLevel
import org.apache.spark.util.VersionUtils

Expand Down Expand Up @@ -176,8 +176,12 @@ private[classification] trait LogisticRegressionParams extends ProbabilisticClas
}
}

override def validateParams(): Unit = {
override protected def validateAndTransformSchema(
schema: StructType,
fitting: Boolean,
featuresDataType: DataType): StructType = {
checkThresholdConsistency()
super.validateAndTransformSchema(schema, fitting, featuresDataType)
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -158,7 +158,7 @@ class RandomForestClassificationModel private[ml] (
@Since("1.6.0") override val numFeatures: Int,
@Since("1.5.0") override val numClasses: Int)
extends ProbabilisticClassificationModel[Vector, RandomForestClassificationModel]
with RandomForestClassificationModelParams with TreeEnsembleModel[DecisionTreeClassificationModel]
with RandomForestClassifierParams with TreeEnsembleModel[DecisionTreeClassificationModel]
with MLWritable with Serializable {

require(_trees.nonEmpty, "RandomForestClassificationModel requires at least 1 tree.")
Expand Down Expand Up @@ -221,15 +221,6 @@ class RandomForestClassificationModel private[ml] (
}
}

/**
* Number of trees in ensemble
*
* @deprecated Use [[getNumTrees]] instead. This method will be removed in 2.1.0
*/
// TODO: Once this is removed, then this class can inherit from RandomForestClassifierParams
@deprecated("Use getNumTrees instead. This method will be removed in 2.1.0.", "2.0.0")
val numTrees: Int = trees.length

@Since("1.4.0")
override def copy(extra: ParamMap): RandomForestClassificationModel = {
copyValues(new RandomForestClassificationModel(uid, _trees, numFeatures, numClasses), extra)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -216,13 +216,6 @@ final class ChiSqSelectorModel private[ml] (
@Since("1.6.0")
def setOutputCol(value: String): this.type = set(outputCol, value)

/**
* @group setParam
*/
@Since("1.6.0")
@deprecated("labelCol is not used by ChiSqSelectorModel.", "2.0.0")
def setLabelCol(value: String): this.type = set(labelCol, value)

@Since("2.0.0")
override def transform(dataset: Dataset[_]): DataFrame = {
val transformedSchema = transformSchema(dataset.schema, logging = true)
Expand Down
15 changes: 0 additions & 15 deletions mllib/src/main/scala/org/apache/spark/ml/param/params.scala
Original file line number Diff line number Diff line change
Expand Up @@ -546,21 +546,6 @@ trait Params extends Identifiable with Serializable {
.map(m => m.invoke(this).asInstanceOf[Param[_]])
}

/**
* Validates parameter values stored internally.
* Raise an exception if any parameter value is invalid.
*
* This only needs to check for interactions between parameters.
* Parameter value checks which do not depend on other parameters are handled by
* `Param.validate()`. This method does not handle input/output column parameters;
* those are checked during schema validation.
* @deprecated Will be removed in 2.1.0. All the checks should be merged into transformSchema
*/
@deprecated("Will be removed in 2.1.0. Checks should be merged into transformSchema.", "2.0.0")
def validateParams(): Unit = {
// Do nothing by default. Override to handle Param interactions.
}

/**
* Explains a param.
* @param param input param, must belong to this instance.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -183,6 +183,12 @@ class GBTRegressionModel private[ml](
@Since("1.4.0")
override def trees: Array[DecisionTreeRegressionModel] = _trees

/**
* Number of trees in ensemble
*/
@Since("2.0.0")
val getNumTrees: Int = trees.length

@Since("1.4.0")
override def treeWeights: Array[Double] = _treeWeights

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -611,9 +611,6 @@ class LinearRegressionSummary private[regression] (
private val privateModel: LinearRegressionModel,
private val diagInvAtWA: Array[Double]) extends Serializable {

@deprecated("The model field is deprecated and will be removed in 2.1.0.", "2.0.0")
val model: LinearRegressionModel = privateModel

@transient private val metrics = new RegressionMetrics(
predictions
.select(col(predictionCol), col(labelCol).cast(DoubleType))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -145,7 +145,7 @@ class RandomForestRegressionModel private[ml] (
private val _trees: Array[DecisionTreeRegressionModel],
override val numFeatures: Int)
extends PredictionModel[Vector, RandomForestRegressionModel]
with RandomForestRegressionModelParams with TreeEnsembleModel[DecisionTreeRegressionModel]
with RandomForestRegressorParams with TreeEnsembleModel[DecisionTreeRegressionModel]
with MLWritable with Serializable {

require(_trees.nonEmpty, "RandomForestRegressionModel requires at least 1 tree.")
Expand Down Expand Up @@ -182,14 +182,6 @@ class RandomForestRegressionModel private[ml] (
_trees.map(_.rootNode.predictImpl(features).prediction).sum / getNumTrees
}

/**
* Number of trees in ensemble
* @deprecated Use [[getNumTrees]] instead. This method will be removed in 2.1.0
*/
// TODO: Once this is removed, then this class can inherit from RandomForestRegressorParams
@deprecated("Use getNumTrees instead. This method will be removed in 2.1.0.", "2.0.0")
val numTrees: Int = trees.length

@Since("1.4.0")
override def copy(extra: ParamMap): RandomForestRegressionModel = {
copyValues(new RandomForestRegressionModel(uid, _trees, numFeatures), extra).setParent(parent)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -95,11 +95,6 @@ private[ml] trait TreeEnsembleModel[M <: DecisionTreeModel] {
/** Trees in this ensemble. Warning: These have null parent Estimators. */
def trees: Array[M]

/**
* Number of trees in ensemble
*/
val getNumTrees: Int = trees.length

/** Weights for each tree, zippable with [[trees]] */
def treeWeights: Array[Double]

Expand Down
90 changes: 39 additions & 51 deletions mllib/src/main/scala/org/apache/spark/ml/tree/treeParams.scala
Original file line number Diff line number Diff line change
Expand Up @@ -319,8 +319,32 @@ private[ml] trait TreeEnsembleParams extends DecisionTreeParams {
}
}

/** Used for [[RandomForestParams]] */
private[ml] trait HasFeatureSubsetStrategy extends Params {
/**
* Parameters for Random Forest algorithms.
*/
private[ml] trait RandomForestParams extends TreeEnsembleParams {

/**
* Number of trees to train (>= 1).
* If 1, then no bootstrapping is used. If > 1, then bootstrapping is done.
* TODO: Change to always do bootstrapping (simpler). SPARK-7130
* (default = 20)
*
* Note: The reason that we cannot add this to both GBT and RF (i.e. in TreeEnsembleParams)
* is the param `maxIter` controls how many trees a GBT has. The semantics in the algorithms
* are a bit different.
* @group param
*/
final val numTrees: IntParam = new IntParam(this, "numTrees", "Number of trees to train (>= 1)",
ParamValidators.gtEq(1))

setDefault(numTrees -> 20)

/** @group setParam */
def setNumTrees(value: Int): this.type = set(numTrees, value)

/** @group getParam */
final def getNumTrees: Int = $(numTrees)

/**
* The number of features to consider for splits at each tree node.
Expand Down Expand Up @@ -366,38 +390,6 @@ private[ml] trait HasFeatureSubsetStrategy extends Params {
final def getFeatureSubsetStrategy: String = $(featureSubsetStrategy).toLowerCase
}

/**
* Used for [[RandomForestParams]].
* This is separated out from [[RandomForestParams]] because of an issue with the
* `numTrees` method conflicting with this Param in the Estimator.
*/
private[ml] trait HasNumTrees extends Params {

/**
* Number of trees to train (>= 1).
* If 1, then no bootstrapping is used. If > 1, then bootstrapping is done.
* TODO: Change to always do bootstrapping (simpler). SPARK-7130
* (default = 20)
* @group param
*/
final val numTrees: IntParam = new IntParam(this, "numTrees", "Number of trees to train (>= 1)",
ParamValidators.gtEq(1))

setDefault(numTrees -> 20)

/** @group setParam */
def setNumTrees(value: Int): this.type = set(numTrees, value)

/** @group getParam */
final def getNumTrees: Int = $(numTrees)
}

/**
* Parameters for Random Forest algorithms.
*/
private[ml] trait RandomForestParams extends TreeEnsembleParams
with HasFeatureSubsetStrategy with HasNumTrees

private[spark] object RandomForestParams {
// These options should be lowercase.
final val supportedFeatureSubsetStrategies: Array[String] =
Expand All @@ -407,21 +399,15 @@ private[spark] object RandomForestParams {
private[ml] trait RandomForestClassifierParams
extends RandomForestParams with TreeClassifierParams

private[ml] trait RandomForestClassificationModelParams extends TreeEnsembleParams
with HasFeatureSubsetStrategy with TreeClassifierParams

private[ml] trait RandomForestRegressorParams
extends RandomForestParams with TreeRegressorParams

private[ml] trait RandomForestRegressionModelParams extends TreeEnsembleParams
with HasFeatureSubsetStrategy with TreeRegressorParams

/**
* Parameters for Gradient-Boosted Tree algorithms.
*
* Note: Marked as private and DeveloperApi since this may be made public in the future.
*/
private[ml] trait GBTParams extends TreeEnsembleParams with HasMaxIter with HasStepSize {
private[ml] trait GBTParams extends TreeEnsembleParams with HasMaxIter {

/* TODO: Add this doc when we add this param. SPARK-7132
* Threshold for stopping early when runWithValidation is used.
Expand All @@ -434,24 +420,26 @@ private[ml] trait GBTParams extends TreeEnsembleParams with HasMaxIter with HasS
// final val validationTol: DoubleParam = new DoubleParam(this, "validationTol", "")
// validationTol -> 1e-5

setDefault(maxIter -> 20, stepSize -> 0.1)

/** @group setParam */
def setMaxIter(value: Int): this.type = set(maxIter, value)

/**
* Step size (a.k.a. learning rate) in interval (0, 1] for shrinking the contribution of each
* estimator.
* Param for Step size (a.k.a. learning rate) in interval (0, 1] for shrinking
* the contribution of each estimator.
* (default = 0.1)
* @group setParam
* @group param
*/
final val stepSize: DoubleParam = new DoubleParam(this, "stepSize", "Step size " +
"(a.k.a. learning rate) in interval (0, 1] for shrinking the contribution of each estimator.",
ParamValidators.inRange(0, 1, lowerInclusive = false, upperInclusive = true))

/** @group getParam */
final def getStepSize: Double = $(stepSize)

/** @group setParam */
def setStepSize(value: Double): this.type = set(stepSize, value)

override def validateParams(): Unit = {
require(ParamValidators.inRange(0, 1, lowerInclusive = false, upperInclusive = true)(
getStepSize), "GBT parameter stepSize should be in interval (0, 1], " +
s"but it given invalid value $getStepSize.")
}
setDefault(maxIter -> 20, stepSize -> 0.1)

/** (private[ml]) Create a BoostingStrategy instance to use with the old API. */
private[ml] def getOldBoostingStrategy(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ private[util] sealed trait BaseReadWrite {
* Sets the Spark SQLContext to use for saving/loading.
*/
@Since("1.6.0")
@deprecated("Use session instead", "2.0.0")
@deprecated("Use session instead, This method will be removed in 2.2.0.", "2.0.0")
def context(sqlContext: SQLContext): this.type = {
optionSparkSession = Option(sqlContext.sparkSession)
this
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,14 @@ class GBTClassifierSuite extends SparkFunSuite with MLlibTestSparkContext
ParamsSuite.checkParams(model)
}

test("GBT parameter stepSize should be in interval (0, 1]") {
withClue("GBT parameter stepSize should be in interval (0, 1]") {
intercept[IllegalArgumentException] {
new GBTClassifier().setStepSize(10)
}
}
}

test("Binary classification with continuous features: Log Loss") {
val categoricalFeatures = Map.empty[Int, Int]
testCombinations.foreach {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -192,6 +192,12 @@ class LogisticRegressionSuite
}
}
// thresholds and threshold must be consistent: values
withClue("fit with ParamMap should throw error if threshold, thresholds do not match.") {
intercept[IllegalArgumentException] {
lr2.fit(smallBinaryDataset,
lr2.thresholds -> Array(0.3, 0.7), lr2.threshold -> (expectedThreshold / 2.0))
}
}
withClue("fit with ParamMap should throw error if threshold, thresholds do not match.") {
intercept[IllegalArgumentException] {
val lr2model = lr2.fit(smallBinaryDataset,
Expand Down
30 changes: 30 additions & 0 deletions project/MimaExcludes.scala
Original file line number Diff line number Diff line change
Expand Up @@ -867,6 +867,36 @@ object MimaExcludes {
// [SPARK-12221] Add CPU time to metrics
ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.status.api.v1.TaskMetrics.this"),
ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.status.api.v1.TaskMetricDistributions.this")
) ++ Seq(
// [SPARK-18481] ML 2.1 QA: Remove deprecated methods for ML
ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.PipelineStage.validateParams"),
ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.param.JavaParams.validateParams"),
ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.param.Params.validateParams"),
ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.classification.GBTClassificationModel.validateParams"),
ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.classification.LogisticRegression.validateParams"),
ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.classification.GBTClassifier.validateParams"),
ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.classification.LogisticRegressionModel.validateParams"),
ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.classification.RandomForestClassificationModel.numTrees"),
ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.feature.ChiSqSelectorModel.setLabelCol"),
ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.evaluation.Evaluator.validateParams"),
ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.regression.GBTRegressor.validateParams"),
ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.regression.GBTRegressionModel.validateParams"),
ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.regression.LinearRegressionSummary.model"),
ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.regression.RandomForestRegressionModel.numTrees"),
ProblemFilters.exclude[MissingTypesProblem]("org.apache.spark.ml.classification.RandomForestClassifier"),
ProblemFilters.exclude[MissingTypesProblem]("org.apache.spark.ml.classification.RandomForestClassificationModel"),
ProblemFilters.exclude[MissingTypesProblem]("org.apache.spark.ml.classification.GBTClassifier"),
ProblemFilters.exclude[MissingTypesProblem]("org.apache.spark.ml.classification.GBTClassificationModel"),
ProblemFilters.exclude[MissingTypesProblem]("org.apache.spark.ml.regression.RandomForestRegressor"),
ProblemFilters.exclude[MissingTypesProblem]("org.apache.spark.ml.regression.RandomForestRegressionModel"),
ProblemFilters.exclude[MissingTypesProblem]("org.apache.spark.ml.regression.GBTRegressor"),
ProblemFilters.exclude[MissingTypesProblem]("org.apache.spark.ml.regression.GBTRegressionModel"),
ProblemFilters.exclude[FinalMethodProblem]("org.apache.spark.ml.classification.RandomForestClassificationModel.getNumTrees"),
ProblemFilters.exclude[FinalMethodProblem]("org.apache.spark.ml.regression.RandomForestRegressionModel.getNumTrees"),
ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.ml.classification.RandomForestClassificationModel.numTrees"),
ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.ml.classification.RandomForestClassificationModel.setFeatureSubsetStrategy"),
ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.ml.regression.RandomForestRegressionModel.numTrees"),
ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.ml.regression.RandomForestRegressionModel.setFeatureSubsetStrategy")
)
}

Expand Down
Loading

0 comments on commit c4a7eef

Please sign in to comment.