[jvm-packages] Add some documentation to xgboost4j-spark plus minor style edits (dmlc#2823)

* add scala docs to several methods

* indentation

* license formatting

* clarify distributed boosters

* address some review comments

* reduce doc lengths

* change method name, clarify doc

* reset make config

* delete most comments

* more review feedback
sethah authored and CodingCat committed Nov 2, 2017
1 parent 46f2b82 commit a8f670d
Showing 2 changed files with 74 additions and 59 deletions.
@@ -30,28 +30,30 @@ import org.apache.spark.sql.Dataset
import org.apache.spark.ml.feature.{LabeledPoint => MLLabeledPoint}
import org.apache.spark.{SparkContext, SparkParallelismTracker, TaskContext}

-object TrackerConf {
-  def apply(): TrackerConf = TrackerConf(0L, "python")
-}

/**
- * Rabit tracker configurations.
- * @param workerConnectionTimeout The timeout for all workers to connect to the tracker.
- *                                Set timeout length to zero to disable timeout.
- *                                Use a finite, non-zero timeout value to prevent tracker from
- *                                hanging indefinitely (in milliseconds)
+ * Rabit tracker configurations.
+ *
+ * @param workerConnectionTimeout The timeout for all workers to connect to the tracker.
+ *                                Set timeout length to zero to disable timeout.
+ *                                Use a finite, non-zero timeout value to prevent tracker from
+ *                                hanging indefinitely (in milliseconds)
+ *                                (supported by "scala" implementation only.)
- * @param trackerImpl Choice between "python" or "scala". The former utilizes the Java wrapper of
- *                    the Python Rabit tracker (in dmlc_core), whereas the latter is implemented
- *                    in Scala without Python components, and with full support of timeouts.
- *                    The Scala implementation is currently experimental, use at your own risk.
- */
+ * @param trackerImpl Choice between "python" or "scala". The former utilizes the Java wrapper of
+ *                    the Python Rabit tracker (in dmlc_core), whereas the latter is implemented
+ *                    in Scala without Python components, and with full support of timeouts.
+ *                    The Scala implementation is currently experimental, use at your own risk.
+ */
case class TrackerConf(workerConnectionTimeout: Long, trackerImpl: String)

+object TrackerConf {
+  def apply(): TrackerConf = TrackerConf(0L, "python")
+}

object XGBoost extends Serializable {
private val logger = LogFactory.getLog("XGBoostSpark")

-  private def fromDenseToSparseLabeledPoints(
+  private def removeMissingValues(
denseLabeledPoints: Iterator[XGBLabeledPoint],
missing: Float): Iterator[XGBLabeledPoint] = {
if (!missing.isNaN) {
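
The `tracker_conf` entry consumed by `trainWithRDD` (see the pattern match further down) carries the TrackerConf defined above. A minimal sketch of selecting the experimental Scala tracker with a one-minute connection timeout; the booster parameters are illustrative:

    import ml.dmlc.xgboost4j.scala.spark.TrackerConf

    // One minute, in milliseconds; the timeout is honored by the "scala" tracker only.
    val trackerConf = TrackerConf(workerConnectionTimeout = 60000L, trackerImpl = "scala")

    val params = Map(
      "eta" -> 0.1,                  // illustrative booster parameters
      "max_depth" -> 6,
      "objective" -> "binary:logistic",
      "tracker_conf" -> trackerConf  // matched by the tracker_conf case below
    )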
@@ -89,7 +91,7 @@ object XGBoost extends Serializable {
} else {
throw new IllegalArgumentException(
s"Encountered a partition with $nUndefined NaN base margin values. " +
"If you want to specify base margin, ensure all values are non-NaN.")
s"If you want to specify base margin, ensure all values are non-NaN.")
}
}

@@ -118,23 +120,23 @@ object XGBoost extends Serializable {
if (labeledPoints.isEmpty) {
throw new XGBoostError(
s"detected an empty partition in the training data, partition ID:" +
s" ${TaskContext.getPartitionId()}")
s" ${TaskContext.getPartitionId()}")
}
val cacheFileName = if (useExternalMemory) {
s"$appName-${TaskContext.get().stageId()}-" +
s"dtrain_cache-${TaskContext.getPartitionId()}"
s"dtrain_cache-${TaskContext.getPartitionId()}"
} else {
null
}
rabitEnv.put("DMLC_TASK_ID", TaskContext.getPartitionId().toString)
Rabit.init(rabitEnv)
val watches = Watches(params,
-      fromDenseToSparseLabeledPoints(labeledPoints, missing),
+      removeMissingValues(labeledPoints, missing),
fromBaseMarginsToArray(baseMargins), cacheFileName)

try {
val numEarlyStoppingRounds = params.get("numEarlyStoppingRounds")
-          .map(_.toString.toInt).getOrElse(0)
+        .map(_.toString.toInt).getOrElse(0)
val booster = SXGBoost.train(watches.train, params, round,
watches = watches.toMap, obj = obj, eval = eval,
earlyStoppingRound = numEarlyStoppingRounds)
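
For orientation, a hedged sketch of what a helper like `removeMissingValues` does: when `missing` is not NaN, entries equal to it are dropped from each point so that XGBoost only sees defined values (the real implementation may differ; `XGBLabeledPoint` is the `ml.dmlc.xgboost4j.LabeledPoint` shown in the second changed file):

    import ml.dmlc.xgboost4j.{LabeledPoint => XGBLabeledPoint}

    def dropMissing(
        points: Iterator[XGBLabeledPoint],
        missing: Float): Iterator[XGBLabeledPoint] = {
      if (missing.isNaN) {
        points // NaN is already XGBoost's native missing-value marker; nothing to strip
      } else {
        points.map { p =>
          // Keep only positions whose value differs from the missing marker.
          val kept = p.values.indices.filter(i => p.values(i) != missing)
          p.copy(
            indices = if (p.indices == null) kept.toArray else kept.map(i => p.indices(i)).toArray,
            values = kept.map(i => p.values(i)).toArray)
        }
      }
    }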
@@ -147,17 +149,18 @@
}

/**
-   * train XGBoost model with the DataFrame-represented data
-   * @param trainingData the trainingset represented as DataFrame
+   * Train XGBoost model with the DataFrame-represented data
+   *
+   * @param trainingData the training set represented as DataFrame
* @param params Map containing the parameters to configure XGBoost
* @param round the number of iterations
* @param nWorkers the number of xgboost workers, 0 by default which means that the number of
* workers equals to the partition number of trainingData RDD
-   * @param obj the user-defined objective function, null by default
-   * @param eval the user-defined evaluation function, null by default
+   * @param obj An instance of [[ObjectiveTrait]] specifying a custom objective, null by default
+   * @param eval An instance of [[EvalTrait]] specifying a custom evaluation metric, null by default
* @param useExternalMemory indicate whether to use external memory cache, by setting this flag as
* true, the user may save the RAM cost for running XGBoost within Spark
-   * @param missing the value represented the missing value in the dataset
+   * @param missing The value which represents a missing value in the dataset
* @param featureCol the name of input column, "features" as default value
* @param labelCol the name of output column, "label" as default value
* @throws ml.dmlc.xgboost4j.java.XGBoostError when the model training is failed
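
A sketch of calling the DataFrame-based trainer documented above. The method name `trainWithDataFrame`, the SparkSession setup, and the toy data are assumptions for illustration:

    import org.apache.spark.ml.linalg.Vectors
    import org.apache.spark.sql.SparkSession

    val spark = SparkSession.builder().appName("xgb-example").getOrCreate()
    import spark.implicits._

    // Default column names per the doc above: "features" and "label".
    val trainingDF = Seq(
      (Vectors.dense(1.0, 0.5), 1.0),
      (Vectors.dense(0.3, 2.0), 0.0)
    ).toDF("features", "label")

    val model = XGBoost.trainWithDataFrame(
      trainingDF,
      Map("eta" -> 0.1, "objective" -> "binary:logistic"),
      round = 10,   // boosting iterations
      nWorkers = 2) // XGBoost workers; 0 would reuse the input partition count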
@@ -200,14 +203,15 @@ object XGBoost extends Serializable {
}

/**
-   * train XGBoost model with the RDD-represented data
-   * @param trainingData the trainingset represented as RDD
+   * Train XGBoost model with the RDD-represented data
+   *
+   * @param trainingData the training set represented as RDD
* @param params Map containing the configuration entries
* @param round the number of iterations
* @param nWorkers the number of xgboost workers, 0 by default which means that the number of
* workers equals to the partition number of trainingData RDD
-   * @param obj the user-defined objective function, null by default
-   * @param eval the user-defined evaluation function, null by default
+   * @param obj An instance of [[ObjectiveTrait]] specifying a custom objective, null by default
+   * @param eval An instance of [[EvalTrait]] specifying a custom evaluation metric, null by default
* @param useExternalMemory indicate whether to use external memory cache, by setting this flag as
* true, the user may save the RAM cost for running XGBoost within Spark
* @param missing the value represented the missing value in the dataset
@@ -224,8 +228,7 @@
eval: EvalTrait = null,
useExternalMemory: Boolean = false,
missing: Float = Float.NaN): XGBoostModel = {
-    trainWithRDD(trainingData, params, round, nWorkers, obj, eval, useExternalMemory,
-      missing)
+    trainWithRDD(trainingData, params, round, nWorkers, obj, eval, useExternalMemory, missing)
}

private def overrideParamsAccordingToTaskCPUs(
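
Similarly, a hedged sketch of the RDD-based entry point; `MLLabeledPoint` is the `org.apache.spark.ml.feature.LabeledPoint` imported at the top of this file, and `sc` is assumed to be a live SparkContext:

    import org.apache.spark.ml.feature.{LabeledPoint => MLLabeledPoint}
    import org.apache.spark.ml.linalg.Vectors

    val trainingRDD = sc.parallelize(Seq(
      MLLabeledPoint(1.0, Vectors.dense(1.0, 0.5)),
      MLLabeledPoint(0.0, Vectors.dense(0.3, 2.0))))

    val model = XGBoost.trainWithRDD(
      trainingRDD,
      Map("eta" -> 0.1, "objective" -> "binary:logistic"),
      round = 10,
      nWorkers = 2)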
@@ -256,18 +259,19 @@
}

/**
-   * various of train()
-   * @param trainingData the trainingset represented as RDD
+   * Train XGBoost model with the RDD-represented data
+   *
+   * @param trainingData the training set represented as RDD
* @param params Map containing the configuration entries
* @param round the number of iterations
* @param nWorkers the number of xgboost workers, 0 by default which means that the number of
* workers equals to the partition number of trainingData RDD
-   * @param obj the user-defined objective function, null by default
-   * @param eval the user-defined evaluation function, null by default
+   * @param obj An instance of [[ObjectiveTrait]] specifying a custom objective, null by default
+   * @param eval An instance of [[EvalTrait]] specifying a custom evaluation metric, null by default
* @param useExternalMemory indicate whether to use external memory cache, by setting this flag as
* true, the user may save the RAM cost for running XGBoost within Spark
-   * @param missing the value represented the missing value in the dataset
-   * @throws ml.dmlc.xgboost4j.java.XGBoostError when the model training is failed
+   * @param missing The value which represents a missing value in the dataset
+   * @throws ml.dmlc.xgboost4j.java.XGBoostError when the model training has failed
* @return XGBoostModel when successful training
*/
@throws(classOf[XGBoostError])
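
Where a custom metric is wanted for the `eval` parameter documented above, a minimal sketch, assuming `EvalTrait` exposes `getMetric` and `eval(predicts, dmat)` as in xgboost4j's Scala API:

    import ml.dmlc.xgboost4j.scala.{DMatrix, EvalTrait}

    // Fraction of rows whose thresholded prediction disagrees with the label.
    class ErrorRateEval extends EvalTrait {
      override def getMetric: String = "custom_error"

      override def eval(predicts: Array[Array[Float]], dmat: DMatrix): Float = {
        val labels = dmat.getLabel
        val wrong = predicts.zip(labels).count { case (pred, label) =>
          (if (pred(0) > 0.5f) 1.0f else 0.0f) != label
        }
        wrong.toFloat / labels.length
      }
    }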
@@ -300,19 +304,19 @@
missing: Float = Float.NaN): XGBoostModel = {
if (params.contains("tree_method")) {
require(params("tree_method") != "hist", "xgboost4j-spark does not support fast histogram" +
" for now")
" for now")
}
require(nWorkers > 0, "you must specify more than 0 workers")
if (obj != null) {
require(params.get("obj_type").isDefined, "parameter \"obj_type\" is not defined," +
" you have to specify the objective type as classification or regression with a" +
" customized objective function")
" you have to specify the objective type as classification or regression with a" +
" customized objective function")
}
val trackerConf = params.get("tracker_conf") match {
case None => TrackerConf()
case Some(conf: TrackerConf) => conf
case _ => throw new IllegalArgumentException("parameter \"tracker_conf\" must be an " +
"instance of TrackerConf.")
"instance of TrackerConf.")
}
val timeoutRequestWorkers: Long = params.get("timeout_request_workers") match {
case None => 0L
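
Per the `require` above, a caller supplying a custom `ObjectiveTrait` must also declare the task kind through `obj_type`. A sketch; the accepted values "classification" and "regression" are inferred from the error message above, and `myCustomObj` is hypothetical:

    // myCustomObj is assumed to implement ml.dmlc.xgboost4j.scala.ObjectiveTrait.
    val paramsWithCustomObj = Map(
      "eta" -> 0.1,
      "obj_type" -> "classification" // or "regression"
    )
    // Then: XGBoost.trainWithRDD(trainingRDD, paramsWithCustomObj,
    //   round = 10, nWorkers = 2, obj = myCustomObj)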
@@ -339,8 +343,7 @@
val isClsTask = isClassificationTask(params)
val trackerReturnVal = parallelismTracker.execute(tracker.waitFor(0L))
logger.info(s"Rabit returns with exit code $trackerReturnVal")
-    val model = postTrackerReturnProcessing(trackerReturnVal, boosters, overriddenParams,
-      sparkJobThread, isClsTask)
+    val model = postTrackerReturnProcessing(trackerReturnVal, boosters, sparkJobThread, isClsTask)
if (isClsTask){
model.asInstanceOf[XGBoostClassificationModel].numOfClasses =
params.getOrElse("num_class", "2").toString.toInt
@@ -352,10 +355,13 @@
}

private def postTrackerReturnProcessing(
-      trackerReturnVal: Int, distributedBoosters: RDD[Booster],
-      params: Map[String, Any], sparkJobThread: Thread, isClassificationTask: Boolean):
-    XGBoostModel = {
+      trackerReturnVal: Int,
+      distributedBoosters: RDD[Booster],
+      sparkJobThread: Thread,
+      isClassificationTask: Boolean): XGBoostModel = {
if (trackerReturnVal == 0) {
+      // Copies of the finished model reside in each partition of the `distributedBoosters`.
+      // Any of them can be used to create the model. Here, just choose the first partition.
val xgboostModel = XGBoostModel(distributedBoosters.first(), isClassificationTask)
distributedBoosters.unpersist(false)
xgboostModel
@@ -365,7 +371,7 @@
sparkJobThread.interrupt()
}
} catch {
-        case ie: InterruptedException =>
+        case _: InterruptedException =>
logger.info("spark job thread is interrupted")
}
throw new XGBoostError("XGBoostModel training failed")
@@ -380,8 +386,10 @@
}

private def setGeneralModelParams(
-      featureCol: String, labelCol: String, predCol: String, xgBoostModel: XGBoostModel):
-    XGBoostModel = {
+      featureCol: String,
+      labelCol: String,
+      predCol: String,
+      xgBoostModel: XGBoostModel): XGBoostModel = {
xgBoostModel.setFeaturesCol(featureCol)
xgBoostModel.setLabelCol(labelCol)
xgBoostModel.setPredictionCol(predCol)
@@ -422,13 +430,17 @@
case "_reg_" =>
val xgBoostModel = new XGBoostRegressionModel(SXGBoost.loadModel(dataInStream))
setGeneralModelParams(featureCol, labelCol, predictionCol, xgBoostModel)
+      case other =>
+        throw new XGBoostError(s"Unknown model type $other. Supported types " +
+          s"are: ['_reg_', '_cls_'].")
}
}
}

private class Watches private(val train: DMatrix, val test: DMatrix) {

def toMap: Map[String, DMatrix] = Map("train" -> train, "test" -> test)
-      .filter { case (_, matrix) => matrix.rowNum > 0 }
+    .filter { case (_, matrix) => matrix.rowNum > 0 }

def size: Int = toMap.size

@@ -440,6 +452,7 @@ private class Watches private(val train: DMatrix, val test: DMatrix) {
}

private object Watches {

def apply(
params: Map[String, Any],
labeledPoints: Iterator[XGBLabeledPoint],
@@ -16,21 +16,23 @@

package ml.dmlc.xgboost4j

-/** Labeled training data point. */
+/**
+ * Labeled training data point.
+ *
+ * @param label Label of this point.
+ * @param indices Feature indices of this point or `null` if the data is dense.
+ * @param values Feature values of this point.
+ * @param weight Weight of this point.
+ * @param group Group of this point (used for ranking) or -1.
+ * @param baseMargin Initial prediction on this point or `Float.NaN`
+ */
case class LabeledPoint(
-    /** Label of this point. */
    label: Float,
-    /** Feature indices of this point or `null` if the data is dense. */
    indices: Array[Int],
-    /** Feature values of this point. */
    values: Array[Float],
-    /** Weight of this point. */
-    weight: Float = 1.0f,
-    /** Group of this point (used for ranking) or -1. */
+    weight: Float = 1f,
    group: Int = -1,
-    /** Initial prediction on this point or `Float.NaN`. */
-    baseMargin: Float = Float.NaN
-) extends Serializable {
+    baseMargin: Float = Float.NaN) extends Serializable {
require(indices == null || indices.length == values.length,
"indices and values must have the same number of elements")

