forked from scikit-learn/scikit-learn
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
ENH Add Categorical support for HistGradientBoosting (scikit-learn#18394)
Co-authored-by: Nicolas Hug <[email protected]>
Co-authored-by: Olivier Grisel <[email protected]>
Co-authored-by: Olivier Grisel <[email protected]>
- Loading branch information
1 parent
04c080a
commit b4453f1
Showing
24 changed files
with
2,206 additions
and
182 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,90 @@ | ||
import argparse | ||
from time import time | ||
|
||
from sklearn.model_selection import train_test_split | ||
from sklearn.datasets import fetch_openml | ||
from sklearn.metrics import accuracy_score, roc_auc_score | ||
from sklearn.experimental import enable_hist_gradient_boosting # noqa | ||
from sklearn.ensemble import HistGradientBoostingClassifier | ||
from sklearn.ensemble._hist_gradient_boosting.utils import ( | ||
get_equivalent_estimator) | ||
|
||
|
||
# Command-line options: model size, learning rate, binning, and switches to
# also benchmark LightGBM, skip the prediction step, or enable verbose output.
parser = argparse.ArgumentParser()
parser.add_argument('--n-leaf-nodes', type=int, default=31)
parser.add_argument('--n-trees', type=int, default=100)
parser.add_argument('--lightgbm', action="store_true", default=False)
parser.add_argument('--learning-rate', type=float, default=.1)
parser.add_argument('--max-bins', type=int, default=255)
parser.add_argument('--no-predict', action="store_true", default=False)
parser.add_argument('--verbose', action="store_true", default=False)
args = parser.parse_args()

# Short module-level aliases for the parsed options.
n_leaf_nodes = args.n_leaf_nodes
n_trees = args.n_trees
lr = args.learning_rate
max_bins = args.max_bins
verbose = args.verbose
|
||
|
||
def fit(est, data_train, target_train, libname, **fit_params):
    """Fit ``est`` on the training data and print the elapsed time.

    Parameters
    ----------
    est : estimator
        Object exposing ``fit(X, y, **fit_params)``.
    data_train, target_train
        Training inputs and targets, forwarded to ``est.fit`` unchanged.
    libname : str
        Label shown in the progress message (e.g. ``'sklearn'``).
    **fit_params
        Extra keyword arguments forwarded to ``est.fit``.
    """
    print(f"Fitting a {libname} model...")
    # Only the call to ``est.fit`` is inside the timed region.
    tic = time()
    est.fit(data_train, target_train, **fit_params)
    toc = time()
    print(f"fitted in {toc - tic:.3f}s")
|
||
|
||
def predict(est, data_test, target_test):
    """Time prediction on the test set and print ROC AUC and accuracy.

    Returns immediately when the ``--no-predict`` command-line flag was given.

    Parameters
    ----------
    est : estimator
        Fitted object exposing ``predict`` and ``predict_proba``.
    data_test, target_test
        Test inputs and the matching true targets.
    """
    if args.no_predict:
        return
    # The timed region covers both the class and the probability predictions;
    # the metric computations below are not timed.
    tic = time()
    predicted_test = est.predict(data_test)
    predicted_proba_test = est.predict_proba(data_test)
    toc = time()
    # Column 1 of the predicted probabilities is used as the ROC AUC score.
    roc_auc = roc_auc_score(target_test, predicted_proba_test[:, 1])
    acc = accuracy_score(target_test, predicted_test)
    print(f"predicted in {toc - tic:.3f}s, "
          f"ROC AUC: {roc_auc:.4f}, ACC: {acc:.4f}")
|
||
|
||
# Load the data as plain arrays (not a DataFrame).
data = fetch_openml(data_id=179, as_frame=False)  # adult dataset
X, y = data.data, data.target

# Report how the features split between categorical and numerical.
n_features = X.shape[1]
n_categorical_features = len(data.categories)
n_numerical_features = n_features - n_categorical_features
print(f"Number of features: {n_features}")
print(f"Number of categorical features: {n_categorical_features}")
print(f"Number of numerical features: {n_numerical_features}")

X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=.2,
                                                    random_state=0)

# Note: no need to use an OrdinalEncoder because categorical features are
# already clean
# Boolean mask, one entry per feature, True where the feature is categorical.
is_categorical = [name in data.categories for name in data.feature_names]
est = HistGradientBoostingClassifier(
    loss='binary_crossentropy',
    learning_rate=lr,
    max_iter=n_trees,
    max_bins=max_bins,
    max_leaf_nodes=n_leaf_nodes,
    categorical_features=is_categorical,
    early_stopping=False,
    random_state=0,
    verbose=verbose
)

fit(est, X_train, y_train, 'sklearn')
predict(est, X_test, y_test)

if args.lightgbm:
    # Repeat the benchmark with the LightGBM counterpart returned by
    # get_equivalent_estimator for the estimator configured above.
    est = get_equivalent_estimator(est, lib='lightgbm')
    est.set_params(max_cat_to_onehot=1)  # don't use one-hot encoding
    # The fit call below receives column indices instead of a boolean mask.
    categorical_features = [f_idx
                            for (f_idx, is_cat) in enumerate(is_categorical)
                            if is_cat]
    fit(est, X_train, y_train, 'lightgbm',
        categorical_feature=categorical_features)
    predict(est, X_test, y_test)
84 changes: 84 additions & 0 deletions
84
benchmarks/bench_hist_gradient_boosting_categorical_only.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,84 @@ | ||
import argparse | ||
from time import time | ||
|
||
from sklearn.preprocessing import KBinsDiscretizer | ||
from sklearn.datasets import make_classification | ||
from sklearn.experimental import enable_hist_gradient_boosting # noqa | ||
from sklearn.ensemble import HistGradientBoostingClassifier | ||
from sklearn.ensemble._hist_gradient_boosting.utils import ( | ||
get_equivalent_estimator) | ||
|
||
|
||
# Command-line options: model size, synthetic dataset shape (features,
# categories per feature, samples), learning rate, binning, and switches to
# also benchmark LightGBM, skip the prediction step, or enable verbose output.
parser = argparse.ArgumentParser()
parser.add_argument('--n-leaf-nodes', type=int, default=31)
parser.add_argument('--n-trees', type=int, default=100)
parser.add_argument('--n-features', type=int, default=20)
parser.add_argument('--n-cats', type=int, default=20)
parser.add_argument('--n-samples', type=int, default=10_000)
parser.add_argument('--lightgbm', action="store_true", default=False)
parser.add_argument('--learning-rate', type=float, default=.1)
parser.add_argument('--max-bins', type=int, default=255)
parser.add_argument('--no-predict', action="store_true", default=False)
parser.add_argument('--verbose', action="store_true", default=False)
args = parser.parse_args()

# Short module-level aliases for the parsed options.
n_leaf_nodes = args.n_leaf_nodes
n_features = args.n_features
n_categories = args.n_cats
n_samples = args.n_samples
n_trees = args.n_trees
lr = args.learning_rate
max_bins = args.max_bins
verbose = args.verbose
|
||
|
||
def fit(est, data_train, target_train, libname, **fit_params):
    """Fit ``est`` on the given data and print the elapsed time.

    Parameters
    ----------
    est : estimator
        Object exposing ``fit(X, y, **fit_params)``.
    data_train, target_train
        Training inputs and targets, forwarded to ``est.fit`` unchanged.
    libname : str
        Label shown in the progress message (e.g. ``'sklearn'``).
    **fit_params
        Extra keyword arguments forwarded to ``est.fit``.
    """
    print(f"Fitting a {libname} model...")
    # Only the call to ``est.fit`` is inside the timed region.
    tic = time()
    est.fit(data_train, target_train, **fit_params)
    toc = time()
    print(f"fitted in {toc - tic:.3f}s")
|
||
|
||
def predict(est, data_test):
    """Time ``est.predict`` on ``data_test`` and print the elapsed time.

    Returns immediately when the ``--no-predict`` command-line flag was given.
    The predictions themselves are discarded.
    """
    # We don't report accuracy or ROC because the dataset doesn't really make
    # sense: we treat ordered features as un-ordered categories.
    if args.no_predict:
        return
    tic = time()
    est.predict(data_test)
    toc = time()
    print(f"predicted in {toc - tic:.3f}s")
|
||
|
||
# Generate a synthetic classification problem, then discretize every feature
# into ``n_categories`` ordinal bins so that each column can be declared
# categorical.
X, y = make_classification(n_samples=n_samples, n_features=n_features,
                           random_state=0)

X = KBinsDiscretizer(n_bins=n_categories, encode='ordinal').fit_transform(X)

print(f"Number of features: {n_features}")
print(f"Number of samples: {n_samples}")

# Every feature is treated as categorical in this benchmark.
is_categorical = [True] * n_features
est = HistGradientBoostingClassifier(
    loss='binary_crossentropy',
    learning_rate=lr,
    max_iter=n_trees,
    max_bins=max_bins,
    max_leaf_nodes=n_leaf_nodes,
    categorical_features=is_categorical,
    early_stopping=False,
    random_state=0,
    verbose=verbose
)

# The same data is used for fitting and for timing predictions.
fit(est, X, y, 'sklearn')
predict(est, X)

if args.lightgbm:
    # Repeat the benchmark with the LightGBM counterpart returned by
    # get_equivalent_estimator for the estimator configured above.
    est = get_equivalent_estimator(est, lib='lightgbm')
    est.set_params(max_cat_to_onehot=1)  # don't use one-hot encoding
    # The fit call below receives column indices instead of a boolean mask.
    categorical_features = list(range(n_features))
    fit(est, X, y, 'lightgbm',
        categorical_feature=categorical_features)
    predict(est, X)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.