Skip to content

Commit

Permalink
[jvm-packages] Objectives starting with rank: are never classification (
Browse files Browse the repository at this point in the history
dmlc#2837)

Training a model with the experimental rank:ndcg objective incorrectly
returns a Classification model. Adjust the classification check to
not recognize rank:* objectives as classification.

While writing tests for isClassificationTask also turned up that
obj_type -> regression was incorrectly identified as a classification
task so the function was slightly adjusted to pass the new tests.
  • Loading branch information
ebernhardson authored and superbobry committed Oct 30, 2017
1 parent 91af8f7 commit 46f2b82
Show file tree
Hide file tree
Showing 2 changed files with 26 additions and 2 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -194,8 +194,8 @@ object XGBoost extends Serializable {
val objective = params.getOrElse("objective", params.getOrElse("obj_type", null))
objective != null && {
val objStr = objective.toString
objStr == "classification" || (!objStr.startsWith("reg:") && objStr != "count:poisson" &&
objStr != "rank:pairwise")
objStr != "regression" && !objStr.startsWith("reg:") && objStr != "count:poisson" &&
!objStr.startsWith("rank:")
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -389,4 +389,28 @@ class XGBoostGeneralSuite extends FunSuite with PerTest {
assert(eval.eval(xgBoostModel.booster.predict(testSetDMatrix, outPutMargin = true),
testSetDMatrix) < 0.1)
}

test("isClassificationTask correctly classifies supported objectives") {
import org.scalatest.prop.TableDrivenPropertyChecks._

val objectives = Table(
("isClassificationTask", "params"),
(true, Map("obj_type" -> "classification")),
(false, Map("obj_type" -> "regression")),
(false, Map("objective" -> "rank:ndcg")),
(false, Map("objective" -> "rank:pairwise")),
(false, Map("objective" -> "rank:map")),
(false, Map("objective" -> "count:poisson")),
(true, Map("objective" -> "binary:logistic")),
(true, Map("objective" -> "binary:logitraw")),
(true, Map("objective" -> "multi:softmax")),
(true, Map("objective" -> "multi:softprob")),
(false, Map("objective" -> "reg:linear")),
(false, Map("objective" -> "reg:logistic")),
(false, Map("objective" -> "reg:gamma")),
(false, Map("objective" -> "reg:tweedie")))
forAll (objectives) { (isClassificationTask: Boolean, params: Map[String, String]) =>
assert(XGBoost.isClassificationTask(params) == isClassificationTask)
}
}
}

0 comments on commit 46f2b82

Please sign in to comment.