Skip to content

Commit

Permalink
Add XGB explainability output (NVIDIA#3044)
Browse files Browse the repository at this point in the history
* Add XGB explainability output

* typo fix

* format fix
  • Loading branch information
ZiyueXu77 authored Oct 15, 2024
1 parent 3dc385f commit f717350
Show file tree
Hide file tree
Showing 6 changed files with 36 additions and 5 deletions.
8 changes: 8 additions & 0 deletions examples/advanced/finance-end-to-end/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -363,3 +363,11 @@ Below is the output of last round of training (starting round = 0)
[9] eval-auc:0.72318 train-auc:0.72241
```
As shown, GNN embeddings help to promote the model performance by providing extra features beyond the hand-crafted ones.

For model explainability, our XGBoost training code will generate the feature importance plot of the XGBoost model with regard to validation data:
For normalized data without GNN features, the feature importance plot is shown below:
![feature_importance](./figures/shap_beeswarm_base.png)
For normalized data with GNN embeddings, the feature importance plot is shown below:
![feature_importance](./figures/shap_beeswarm_gnn.png)

As shown, the GNN embeddings provide additional features that are important for the model.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,19 @@ def __init__(self, root_dir: str, file_postfix: str):
self.file_postfix = file_postfix
for name in self.dataset_names:
self.base_file_names[name] = name + file_postfix
self.numerical_columns = [f"V_{i}" for i in range(64)]

self.numerical_columns = [
"Timestamp",
"Amount",
"trans_volume",
"total_amount",
"average_amount",
"hist_trans_volume",
"hist_total_amount",
"hist_average_amount",
"x2_y1",
"x3_y2",
] + [f"V_{i}" for i in range(64)]

def initialize(
self, client_id: str, rank: int, data_split_mode: xgb.core.DataSplitMode = xgb.core.DataSplitMode.ROW
Expand All @@ -40,11 +52,10 @@ def initialize(
def load_data(self) -> Tuple[xgb.DMatrix, xgb.DMatrix]:
data = {}
for ds_name in self.dataset_names:
print("\nloading for site = ", self.client_id, f"{ds_name} dataset")
print("\nloading for site = ", self.client_id, f"{ds_name} dataset \n")
file_name = os.path.join(self.root_dir, self.client_id, self.base_file_names[ds_name])
print(file_name)
print(self.numerical_columns)
print("\n")
df = pd.read_csv(file_name)
data_num = len(data)

Expand Down
4 changes: 2 additions & 2 deletions examples/advanced/finance-end-to-end/nvflare/xgb_job_embed.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,8 +107,8 @@ def define_parser():
"--file_postfix",
type=str,
nargs="?",
default="_embedding.csv",
help="file ending postfix, such as '.csv', or '_embedding.csv'",
default="_combined.csv",
help="file ending postfix, such as '.csv', or '_combined.csv'",
)

parser.add_argument("-co", "--config_only", action="store_true", help="config only mode, will not run simulator")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@
import os
from typing import Tuple

import matplotlib.pyplot as plt
import shap
import xgboost as xgb
from xgboost import callback

Expand Down Expand Up @@ -222,6 +224,16 @@ def run(self, ctx: dict):
bst.save_model(os.path.join(self._model_dir, self.model_file_name))
xgb.collective.communicator_print("Finished training\n")

# Save explanability outputs based on val_data
explainer = shap.TreeExplainer(bst)
explanation = explainer(val_data)

# save the beeswarm plot to png file
shap.plots.beeswarm(explanation, show=False)
img = plt.gcf()
img.subplots_adjust(left=0.3, right=0.9, bottom=0.3, top=0.9)
img.savefig(os.path.join(self._model_dir, "shap_beeswarm.png"), bbox_inches="tight")

self._stopped = True

def stop(self):
Expand Down

0 comments on commit f717350

Please sign in to comment.