diff --git a/code/interpretation/compute-shap.py b/code/interpretation/compute-shap.py
index d039a32..213f093 100644
--- a/code/interpretation/compute-shap.py
+++ b/code/interpretation/compute-shap.py
@@ -7,7 +7,7 @@
 from sklearn.cluster import AgglomerativeClustering
 from mlflow.sklearn import load_model, save_model
 import pyspark.sql.functions as f
-from pyspark.sql.types import StructField, StructType, IntegerType
+from pyspark.sql.types import StructField, StructType, IntegerType, FloatType
 
 # COMMAND ----------
 
@@ -48,7 +48,7 @@ def _get_data(dataset_name, time_step):
     df = spark.table(f"churndb.{dataset_name}_customer_model")
     train = df.where(f.col("time_step")>time_step)
     test = df.where(f.col("time_step")==time_step)
-    test = test.sample(fraction=1.0, seed=1)\
+    test = test.orderBy("row_id").sample(fraction=1.0, seed=1)\
         .limit(1000).repartition(20)
     return {"train":train, "test":test}
 
@@ -60,11 +60,14 @@
 def _compute_shap(explainer, df):
     def _get_shap(iterator, explainer=explainer):
         for X in iterator:
+            # possibly add expected values here
+            shap_instance = explainer(X.loc[:,_get_features(X)])
             shap_values = np.column_stack((X.loc[:,"row_id"].values,
-                explainer(X.loc[:,_get_features(X)]).values))
+                shap_instance.values, shap_instance.base_values))
             yield pd.DataFrame(shap_values)
     schema = StructType([f for f in df.schema.fields\
-        if f.name in ["row_id"]+_get_features(df)])
+        if f.name in ["row_id"]+_get_features(df)]\
+        +[StructField("base_values", FloatType(), False)])
     return df.mapInPandas(_get_shap, schema=schema)
 
 def glue_shap(dataset_name, time_step, pipe):
@@ -85,6 +88,11 @@ def glue_shap(dataset_name, time_step, pipe):
 
 # COMMAND ----------
 
+# spark.table("churndb.retailrocket_shap_values").where(f.col("pipe")=="svm_rbf_class")\
+#     .toPandas().iloc[:,1:-1].sum(axis=1).sort_values()
+
+# COMMAND ----------
+
 spark.sql("DROP TABLE IF EXISTS churndb.rees46_shap_values;")
-glue_shap("rees46", 0, "rf_class")
+glue_shap("rees46", 0, "gbm_class")
 glue_shap("rees46", 0, "gbm_reg")
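
Note on the added base_values column: SHAP's Explanation object carries both the per-feature attributions (.values) and the per-row expected model output (.base_values), and persisting the latter is what makes the commented-out sanity check meaningful, since baseline plus attributions should reconstruct the model output row by row. Below is a minimal, self-contained sketch of that additivity property, assuming shap >= 0.40 and scikit-learn; the synthetic dataset and GradientBoostingClassifier are illustrative stand-ins, not the repo's MLflow pipeline:

    import numpy as np
    import shap
    from sklearn.datasets import make_classification
    from sklearn.ensemble import GradientBoostingClassifier

    # Illustrative stand-ins -- the notebook loads its fitted pipeline
    # from MLflow instead of training in place.
    X, y = make_classification(n_samples=200, n_features=5, random_state=1)
    model = GradientBoostingClassifier(random_state=1).fit(X, y)

    explainer = shap.TreeExplainer(model)
    explanation = explainer(X)  # Explanation with .values and .base_values

    # Additivity: baseline + per-feature attributions reconstruct the raw
    # margin (log-odds) that TreeExplainer explains for sklearn GBMs.
    reconstructed = np.ravel(explanation.base_values) + explanation.values.sum(axis=1)
    assert np.allclose(reconstructed, model.decision_function(X))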