forked from rasbt/machine-learning-notes
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request rasbt#29 from rasbt/xgboost-cloud-gpu
xgboost-gpu
- Loading branch information
Showing
4 changed files
with
84 additions
and
3 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,24 @@ | ||
# Training an XGBoost Classifier Using Cloud GPUs Without Worrying About Infrastructure | ||
|
||
|
||
|
||
Code accompanying the blog article: [Training an XGBoost Classifier Using Cloud GPUs Without Worrying About Infrastructure]() | ||
|
||
|
||
|
||
Run code as follows | ||
|
||
|
||
|
||
```pip install lightning | ||
# run XGBoost classifier locally | ||
python my_xgboost_classifier.py | ||
# run XGBoost classifier locally via Lightning (if you have a GPU) | ||
pip install lightning | ||
lightning run app xgboost-cloud-gpu.py --setup | ||
# run XGBoost in Lightning cloud on a V100 | ||
lightning run app xgboost-cloud-gpu.py --cloud | ||
``` | ||
|
30 changes: 30 additions & 0 deletions
30
cloud-resources/xgboost-lightning-gpu/my_xgboost_classifier.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,30 @@ | ||
from sklearn import datasets | ||
from sklearn.model_selection import train_test_split | ||
from sklearn.metrics import accuracy_score | ||
from xgboost import XGBClassifier | ||
from joblib import dump | ||
|
||
|
||
def run_classifier(save_as="my_model.joblib", use_gpu=False): | ||
digits = datasets.load_digits() | ||
features, targets = digits.images, digits.target | ||
features = features.reshape(-1, 8*8) | ||
|
||
X_train, X_test, y_train, y_test = train_test_split(features, targets, test_size=0.2, random_state=123) | ||
|
||
if use_gpu: | ||
model = XGBClassifier(tree_method='gpu_hist', gpu_id=0) | ||
else: | ||
model = XGBClassifier() | ||
|
||
model.fit(X_train, y_train) | ||
y_pred = model.predict(X_test) | ||
|
||
accuracy = accuracy_score(y_test, y_pred) | ||
print(f"Accuracy: {accuracy * 100.0:.2f}%") | ||
|
||
dump(model, filename=save_as) | ||
|
||
|
||
if __name__ == "__main__": | ||
run_classifier() |
26 changes: 26 additions & 0 deletions
26
cloud-resources/xgboost-lightning-gpu/xgboost-cloud-gpu.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,26 @@ | ||
#!pip install xgboost | ||
#!pip install scikit-learn | ||
|
||
import lightning as L | ||
from lightning.app.storage import Drive | ||
from my_xgboost_classifier import run_classifier | ||
|
||
|
||
class RunCode(L.LightningWork): | ||
def __init__(self): | ||
|
||
# available GPUs and costs: https://lightning.ai/pricing/consumption-rates | ||
super().__init__(cloud_compute=L.CloudCompute("gpu-fast", disk_size=10)) | ||
|
||
# storage for outputs | ||
self.model_storage = Drive("lit://checkpoints") | ||
|
||
def run(self): | ||
# run model code | ||
model_path = "my_model.joblib" | ||
run_classifier(save_as=model_path, use_gpu=True) | ||
self.model_storage.put(model_path) | ||
|
||
|
||
component = RunCode() | ||
app = L.LightningApp(component) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters