Skip to content

Commit

Permalink
Extract base_values from explainer
Browse files Browse the repository at this point in the history
  • Loading branch information
fridrichmrtn committed Sep 13, 2022
1 parent 26f25ac commit 1408298
Showing 1 changed file with 13 additions and 5 deletions.
18 changes: 13 additions & 5 deletions code/interpretation/compute-shap.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from sklearn.cluster import AgglomerativeClustering
from mlflow.sklearn import load_model, save_model
import pyspark.sql.functions as f
from pyspark.sql.types import StructField, StructType, IntegerType
from pyspark.sql.types import StructField, StructType, IntegerType, FloatType

# COMMAND ----------

def _get_data(dataset_name, time_step):
    """Load the customer-model table and split it into train/test sets.

    Rows with a time step strictly after `time_step` form the training
    set; rows exactly at `time_step` form the test set, downsampled to
    at most 1000 rows.

    Returns a dict with "train" and "test" Spark DataFrames.
    """
    df = spark.table(f"churndb.{dataset_name}_customer_model")
    train = df.where(f.col("time_step")>time_step)
    test = df.where(f.col("time_step")==time_step)
    # orderBy before sample pins a deterministic row order, so the seeded
    # 1000-row subsample is reproducible across runs and repartitions.
    test = test.orderBy("row_id").sample(fraction=1.0, seed=1)\
        .limit(1000).repartition(20)
    return {"train":train, "test":test}

Expand All @@ -60,11 +60,14 @@ def _get_explainer(model, df):
def _compute_shap(explainer, df):
    """Compute SHAP values for each row of `df` via mapInPandas.

    Returns a Spark DataFrame whose columns are row_id, one SHAP value
    per feature column, and the explainer's base (expected) value.
    """
    def _get_shap(iterator, explainer=explainer):
        # Runs once per pandas partition; X is a pandas DataFrame.
        for X in iterator:
            # Call the explainer once and keep the result object so the
            # per-feature values and base_values come from the same call.
            shap_instance = explainer(X.loc[:, _get_features(X)])
            shap_values = np.column_stack((X.loc[:, "row_id"].values,
                shap_instance.values, shap_instance.base_values))
            yield pd.DataFrame(shap_values)
    # Output schema: row_id + feature fields reused from the input schema,
    # followed by the appended base_values column.
    # NOTE: `fld` (not `f`) avoids shadowing the pyspark.sql.functions
    # alias imported at module level.
    schema = StructType([fld for fld in df.schema.fields
                         if fld.name in ["row_id"] + _get_features(df)]
                        + [StructField("base_values", FloatType(), False)])
    return df.mapInPandas(_get_shap, schema=schema)

def glue_shap(dataset_name, time_step, pipe):
Expand All @@ -85,6 +88,11 @@ def glue_shap(dataset_name, time_step, pipe):

# COMMAND ----------

# Debug probe (kept for reference): appears to check that per-row SHAP
# contributions (columns 1:-1) sum to sensible per-row totals for one
# pipeline — TODO confirm intended invariant before relying on it.
# spark.table("churndb.retailrocket_shap_values").where(f.col("pipe")=="svm_rbf_class")\
# .toPandas().iloc[:,1:-1].sum(axis=1).sort_values()

# COMMAND ----------

# Recompute SHAP values for the rees46 dataset at time step 0 for each
# fitted pipeline, dropping any previous results first.
spark.sql("DROP TABLE IF EXISTS churndb.rees46_shap_values;")
glue_shap("rees46", 0, "rf_class")
glue_shap("rees46", 0, "gbm_class")
glue_shap("rees46", 0, "gbm_reg")

0 comments on commit 1408298

Please sign in to comment.