Skip to content

Commit

Permalink
Rewrite approx (#7214)
Browse files Browse the repository at this point in the history
This PR rewrites the approx tree method to use codebase from hist for better performance and code sharing.

The rewrite has many benefits:
- Support for both `max_leaves` and `max_depth`.
- Support for `grow_policy`.
- Support for mono constraint.
- Support for feature weights.
- Support for easier bin configuration (`max_bin`).
- Support for categorical data.
- Faster performance for most of the datasets. (many times faster)
- Support for prediction cache.
- Significantly better performance for external memory.
- Unites the code base between approx and hist.
  • Loading branch information
trivialfis authored Jan 10, 2022
1 parent ed95e77 commit 0015031
Show file tree
Hide file tree
Showing 22 changed files with 635 additions and 264 deletions.
1 change: 1 addition & 0 deletions amalgamation/xgboost-all0.cc
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@
#include "../src/tree/updater_refresh.cc"
#include "../src/tree/updater_sync.cc"
#include "../src/tree/updater_histmaker.cc"
#include "../src/tree/updater_approx.cc"
#include "../src/tree/constraints.cc"

// linear
Expand Down
3 changes: 2 additions & 1 deletion demo/guide-python/categorical.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,8 @@
=====================================
Experimental support for categorical data. After 1.5 XGBoost `gpu_hist` tree method has
experimental support for one-hot encoding based tree split.
experimental support for one-hot encoding based tree split, and in 1.6 `approx` supported
was added.
In before, users need to run an encoder themselves before passing the data into XGBoost,
which creates a sparse matrix and potentially increase memory usage. This demo showcases
Expand Down
20 changes: 17 additions & 3 deletions doc/parameter.rst
Original file line number Diff line number Diff line change
Expand Up @@ -154,7 +154,7 @@ Parameters for Tree Booster

* ``sketch_eps`` [default=0.03]

- Only used for ``tree_method=approx``.
- Only used for ``updater=grow_local_histmaker``.
- This roughly translates into ``O(1 / sketch_eps)`` number of bins.
Compared to directly select number of bins, this comes with theoretical guarantee with sketch accuracy.
- Usually user does not have to tune this.
Expand Down Expand Up @@ -238,13 +238,27 @@ Parameters for Tree Booster
list is a group of indices of features that are allowed to interact with each other.
See :doc:`/tutorials/feature_interaction_constraint` for more information.

Additional parameters for ``hist`` and ``gpu_hist`` tree method
================================================================
Additional parameters for ``hist``, ``gpu_hist`` and ``approx`` tree method
===========================================================================

* ``single_precision_histogram``, [default= ``false``]

- Use single precision to build histograms instead of double precision.

Additional parameters for ``approx`` tree method
================================================

* ``max_cat_to_onehot``

.. versionadded:: 1.6

.. note:: The support for this parameter is experimental.

- A threshold for deciding whether XGBoost should use one-hot encoding based split for
categorical data. When number of categories is lesser than the threshold then one-hot
encoding is chosen, otherwise the categories will be partitioned into children nodes.
Only relevant for regression and binary classification with `approx` tree method.

Additional parameters for Dart Booster (``booster=dart``)
=========================================================

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ class XGBoostRegressorSuite extends FunSuite with PerTest {
"max_depth" -> "6",
"silent" -> "1",
"objective" -> "reg:squarederror",
"max_bin" -> 16,
"max_bin" -> 64,
"tree_method" -> treeMethod)

val model1 = ScalaXGBoost.train(trainingDM, paramMap, round)
Expand Down
22 changes: 16 additions & 6 deletions python-package/xgboost/sklearn.py
Original file line number Diff line number Diff line change
Expand Up @@ -267,6 +267,16 @@ def inner(y_score: np.ndarray, dmatrix: DMatrix) -> Tuple[str, float]:
callbacks = [xgb.callback.EarlyStopping(rounds=early_stopping_rounds,
save_best=True)]
max_cat_to_onehot : bool
.. versionadded:: 1.6.0
A threshold for deciding whether XGBoost should use one-hot encoding based split
for categorical data. When number of categories is lesser than the threshold then
one-hot encoding is chosen, otherwise the categories will be partitioned into
children nodes. Only relevant for regression and binary classification and
`approx` tree method.
kwargs : dict, optional
Keyword arguments for XGBoost Booster object. Full documentation of parameters
can be found :doc:`here </parameter>`.
Expand Down Expand Up @@ -483,6 +493,7 @@ def __init__(
eval_metric: Optional[Union[str, List[str], Callable]] = None,
early_stopping_rounds: Optional[int] = None,
callbacks: Optional[List[TrainingCallback]] = None,
max_cat_to_onehot: Optional[int] = None,
**kwargs: Any
) -> None:
if not SKLEARN_INSTALLED:
Expand Down Expand Up @@ -522,6 +533,7 @@ def __init__(
self.eval_metric = eval_metric
self.early_stopping_rounds = early_stopping_rounds
self.callbacks = callbacks
self.max_cat_to_onehot = max_cat_to_onehot
if kwargs:
self.kwargs = kwargs

Expand Down Expand Up @@ -800,8 +812,8 @@ def _duplicated(parameter: str) -> None:
_duplicated("callbacks")
callbacks = self.callbacks if self.callbacks is not None else callbacks

# lastly check categorical data support.
if self.enable_categorical and params.get("tree_method", None) != "gpu_hist":
tree_method = params.get("tree_method", None)
if self.enable_categorical and tree_method not in ("gpu_hist", "approx"):
raise ValueError(
"Experimental support for categorical data is not implemented for"
" current tree method yet."
Expand Down Expand Up @@ -876,8 +888,7 @@ def fit(
feature_weights :
Weight for each feature, defines the probability of each feature being
selected when colsample is being used. All values must be greater than 0,
otherwise a `ValueError` is thrown. Only available for `hist`, `gpu_hist` and
`exact` tree methods.
otherwise a `ValueError` is thrown.
callbacks :
.. deprecated: 1.6.0
Expand Down Expand Up @@ -1750,8 +1761,7 @@ def fit(
feature_weights :
Weight for each feature, defines the probability of each feature being
selected when colsample is being used. All values must be greater than 0,
otherwise a `ValueError` is thrown. Only available for `hist`, `gpu_hist` and
`exact` tree methods.
otherwise a `ValueError` is thrown.
callbacks :
.. deprecated: 1.6.0
Expand Down
6 changes: 3 additions & 3 deletions src/common/threading_utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -130,8 +130,7 @@ class BlockedSpace2d {
template <typename Func>
void ParallelFor2d(const BlockedSpace2d& space, int nthreads, Func func) {
const size_t num_blocks_in_space = space.Size();
nthreads = std::min(nthreads, omp_get_max_threads());
nthreads = std::max(nthreads, 1);
CHECK_GE(nthreads, 1);

dmlc::OMPException exc;
#pragma omp parallel num_threads(nthreads)
Expand Down Expand Up @@ -277,9 +276,10 @@ inline int32_t OmpSetNumThreadsWithoutHT(int32_t* p_threads) {

inline int32_t OmpGetNumThreads(int32_t n_threads) {
if (n_threads <= 0) {
n_threads = omp_get_num_procs();
n_threads = std::min(omp_get_num_procs(), omp_get_max_threads());
}
n_threads = std::min(n_threads, OmpGetThreadLimit());
n_threads = std::max(n_threads, 1);
return n_threads;
}
} // namespace common
Expand Down
2 changes: 1 addition & 1 deletion src/gbm/gbtree.cc
Original file line number Diff line number Diff line change
Expand Up @@ -168,7 +168,7 @@ void GBTree::ConfigureUpdaters() {
// calling this function.
break;
case TreeMethod::kApprox:
tparam_.updater_seq = "grow_histmaker,prune";
tparam_.updater_seq = "grow_histmaker";
break;
case TreeMethod::kExact:
tparam_.updater_seq = "grow_colmaker,prune";
Expand Down
2 changes: 1 addition & 1 deletion src/tree/param.h
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ struct TrainParam : public XGBoostParameter<TrainParam> {
enum TreeGrowPolicy { kDepthWise = 0, kLossGuide = 1 };
int grow_policy;

uint32_t max_cat_to_onehot{1};
uint32_t max_cat_to_onehot{4};

//----- the rest parameters are less important ----
// minimum amount of hessian(weight) allowed in a child
Expand Down
1 change: 1 addition & 0 deletions src/tree/tree_model.cc
Original file line number Diff line number Diff line change
Expand Up @@ -973,6 +973,7 @@ void RegTree::SaveCategoricalSplit(Json* p_out) const {
}
size_t size = categories.size() - begin;
categories_sizes.emplace_back(static_cast<Integer::Int>(size));
CHECK_NE(size, 0);
}
}

Expand Down
1 change: 1 addition & 0 deletions src/tree/tree_updater.cc
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ DMLC_REGISTRY_LINK_TAG(updater_refresh);
DMLC_REGISTRY_LINK_TAG(updater_prune);
DMLC_REGISTRY_LINK_TAG(updater_quantile_hist);
DMLC_REGISTRY_LINK_TAG(updater_histmaker);
DMLC_REGISTRY_LINK_TAG(updater_approx);
DMLC_REGISTRY_LINK_TAG(updater_sync);
#ifdef XGBOOST_USE_CUDA
DMLC_REGISTRY_LINK_TAG(updater_gpu_hist);
Expand Down
Loading

0 comments on commit 0015031

Please sign in to comment.