The goal of xplainfi is to collect common feature importance methods under a unified and extensible interface.
For now, it is built specifically around mlr3, whose existing abstractions for learners, tasks, measures, and so on greatly simplify the implementation of importance measures.
You can install the development version of xplainfi like so:
# install.packages("pak")
pak::pak("jemus42/xplainfi")
Here is a basic example of how to calculate permutation feature importance (PFI) for a given learner and task, using repeated cross-validation as the resampling strategy and permuting each feature 5 times within each resampling iteration:
library(xplainfi)
library(mlr3)
library(mlr3learners)
task = tsk("german_credit")
learner = lrn("classif.ranger", num.trees = 100)
measure = msr("classif.ce")
pfi = PFI$new(
  task = task,
  learner = learner,
  measure = measure,
  resampling = rsmp("repeated_cv", folds = 3, repeats = 2),
  iters_perm = 5
)
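To make the quantity being computed more concrete, here is a minimal base R sketch of the general idea behind permutation feature importance: fit a model, measure its loss on held-out data, then shuffle one feature at a time and record how much the loss increases. This is not how xplainfi implements PFI. The dataset (mtcars), the linear model, the mean squared error loss, and the helper names `mse` and `permutation_importance` are all assumptions chosen for illustration.

```r
set.seed(1)

# Simple holdout split: 10 test rows, the rest for training
test_idx <- sample(nrow(mtcars), size = 10)
train <- mtcars[-test_idx, ]
test <- mtcars[test_idx, ]

features <- c("wt", "hp", "qsec")
model <- lm(mpg ~ wt + hp + qsec, data = train)

# Loss function (hypothetical helper): mean squared error on a data set
mse <- function(model, data) {
  mean((data$mpg - predict(model, newdata = data))^2)
}

# Baseline loss with all features intact
loss_orig <- mse(model, test)

# Hypothetical helper: permute one feature n_perm times and average
# the difference between permuted loss and baseline loss
permutation_importance <- function(feature, n_perm = 5) {
  diffs <- vapply(seq_len(n_perm), function(i) {
    permuted <- test
    permuted[[feature]] <- sample(permuted[[feature]])
    mse(model, permuted) - loss_orig
  }, numeric(1))
  mean(diffs)
}

importance <- vapply(features, permutation_importance, numeric(1))
print(importance)
```

A positive value means that shuffling the feature made predictions worse, and values near zero or below suggest the model does not rely on that feature. Here a single holdout split stands in for the resampling strategy and `n_perm` plays the role of `iters_perm`.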
Compute and print PFI scores:
pfi$compute()
#> Key: <feature>
#> feature importance
#> <char> <num>
#> 1: age 9.929091e-04
#> 2: amount 1.288294e-02
#> 3: credit_history 1.218554e-02
#> 4: duration 1.598605e-02
#> 5: employment_duration 3.890717e-03
#> 6: foreign_worker -1.202700e-03
#> 7: housing -8.016999e-04
#> 8: installment_rate 3.599408e-03
#> 9: job -1.002799e-03
#> 10: number_credits -2.402103e-03
#> 11: other_debtors 5.898713e-03
#> 12: other_installment_plans -9.095922e-04
#> 13: people_liable 5.994018e-07
#> 14: personal_status_sex -1.807496e-03
#> 15: present_residence 6.944070e-04
#> 16: property 1.291111e-03
#> 17: purpose 2.486918e-03
#> 18: savings 1.819694e-02
#> 19: status 3.978829e-02
#> 20: telephone 1.293209e-03
#> feature importance
Retrieve the scores later from pfi$importance.
When PFI is computed with multiple resampling iterations and/or multiple permutation iterations, the individual scores can be retrieved as a data.table:
pfi$scores
#> Key: <feature, iter_rsmp, iter_perm>
#> feature iter_rsmp iter_perm classif.ce_orig classif.ce_perm importance
#> <char> <int> <int> <num> <num> <num>
#> 1: age 1 1 0.2095808 0.2305389 0.020958084
#> 2: age 1 2 0.2095808 0.2335329 0.023952096
#> 3: age 1 3 0.2095808 0.2275449 0.017964072
#> 4: age 1 4 0.2095808 0.2215569 0.011976048
#> 5: age 1 5 0.2095808 0.2155689 0.005988024
#> ---
#> 596: telephone 6 1 0.2612613 0.2432432 -0.018018018
#> 597: telephone 6 2 0.2612613 0.2552553 -0.006006006
#> 598: telephone 6 3 0.2612613 0.2612613 0.000000000
#> 599: telephone 6 4 0.2612613 0.2522523 -0.009009009
#> 600: telephone 6 5 0.2612613 0.2402402 -0.021021021
Here, iter_rsmp is the resampling iteration, which runs from 1 to 3 * 2 = 6 for 2 repeats of 3-fold cross-validation, and iter_perm is the permutation iteration, which runs from 1 to 5 in this case. With 20 features, this gives 20 * 6 * 5 = 600 rows. While pfi$importance contains the means across all iterations, pfi$scores allows you to manually aggregate them in any way you see fit.
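As a sketch of what such a manual aggregation could look like, the following base R example rebuilds the ten rows printed above (the `age` rows of resampling iteration 1 and the `telephone` rows of iteration 6) as a plain data frame and summarises them per feature. The helper name `summarise_importance` is hypothetical, and because only a subset of the 600 rows is used, the resulting means will not match pfi$importance.

```r
# The ten rows shown in the printed output of pfi$scores
scores <- data.frame(
  feature = rep(c("age", "telephone"), each = 5),
  iter_rsmp = rep(c(1L, 6L), each = 5),
  iter_perm = rep(1:5, times = 2),
  classif.ce_orig = rep(c(0.2095808, 0.2612613), each = 5),
  classif.ce_perm = c(
    0.2305389, 0.2335329, 0.2275449, 0.2215569, 0.2155689,
    0.2432432, 0.2552553, 0.2612613, 0.2522523, 0.2402402
  ),
  importance = c(
    0.020958084, 0.023952096, 0.017964072, 0.011976048, 0.005988024,
    -0.018018018, -0.006006006, 0.000000000, -0.009009009, -0.021021021
  )
)

# In these rows, importance is the permuted loss minus the original loss
# (up to the rounding of the printed values)
recomputed <- scores$classif.ce_perm - scores$classif.ce_orig
consistent <- all(abs(recomputed - scores$importance) < 1e-6)

# Hypothetical helper: per-feature mean, median, and standard deviation
summarise_importance <- function(scores) {
  by_feature <- split(scores$importance, scores$feature)
  data.frame(
    feature = names(by_feature),
    mean = vapply(by_feature, mean, numeric(1)),
    median = vapply(by_feature, median, numeric(1)),
    sd = vapply(by_feature, sd, numeric(1)),
    row.names = NULL
  )
}

result <- summarise_importance(scores)
print(result)
```

Since pfi$scores is a data.table, the same kind of summary can presumably be written more compactly with data.table's grouping syntax, for example `pfi$scores[, .(median = median(importance)), by = feature]`.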
In the simplest case, you run PFI with a single resampling iteration (holdout) and a single permutation iteration, and pfi$importance will contain the same values as pfi$scores.
pfi_single = PFI$new(
  task = task,
  learner = learner,
  measure = measure
)
pfi_single$compute()
#> Key: <feature>
#> feature importance
#> <char> <num>
#> 1: age 0.003003003
#> 2: amount 0.012012012
#> 3: credit_history 0.024024024
#> 4: duration 0.012012012
#> 5: employment_duration 0.006006006
#> 6: foreign_worker 0.000000000
#> 7: housing 0.006006006
#> 8: installment_rate 0.024024024
#> 9: job -0.003003003
#> 10: number_credits -0.003003003
#> 11: other_debtors 0.012012012
#> 12: other_installment_plans 0.006006006
#> 13: people_liable 0.009009009
#> 14: personal_status_sex 0.003003003
#> 15: present_residence 0.006006006
#> 16: property 0.003003003
#> 17: purpose 0.015015015
#> 18: savings 0.003003003
#> 19: status 0.054054054
#> 20: telephone -0.003003003
#> feature importance
pfi_single$scores
#> Key: <feature, iter_rsmp, iter_perm>
#> feature iter_rsmp iter_perm classif.ce_orig classif.ce_perm
#> <char> <int> <int> <num> <num>
#> 1: age 1 1 0.2732733 0.2762763
#> 2: amount 1 1 0.2732733 0.2852853
#> 3: credit_history 1 1 0.2732733 0.2972973
#> 4: duration 1 1 0.2732733 0.2852853
#> 5: employment_duration 1 1 0.2732733 0.2792793
#> 6: foreign_worker 1 1 0.2732733 0.2732733
#> 7: housing 1 1 0.2732733 0.2792793
#> 8: installment_rate 1 1 0.2732733 0.2972973
#> 9: job 1 1 0.2732733 0.2702703
#> 10: number_credits 1 1 0.2732733 0.2702703
#> 11: other_debtors 1 1 0.2732733 0.2852853
#> 12: other_installment_plans 1 1 0.2732733 0.2792793
#> 13: people_liable 1 1 0.2732733 0.2822823
#> 14: personal_status_sex 1 1 0.2732733 0.2762763
#> 15: present_residence 1 1 0.2732733 0.2792793
#> importance
#> <num>
#> 1: 0.003003003
#> 2: 0.012012012
#> 3: 0.024024024
#> 4: 0.012012012
#> 5: 0.006006006
#> 6: 0.000000000
#> 7: 0.006006006
#> 8: 0.024024024
#> 9: -0.003003003
#> 10: -0.003003003
#> 11: 0.012012012
#> 12: 0.006006006
#> 13: 0.009009009
#> 14: 0.003003003
#> 15: 0.006006006
#> [ reached getOption("max.print") -- omitted 6 rows ]