Skip to content

Commit

Permalink
support parsnip interface for decision_tree()
Browse files Browse the repository at this point in the history
  • Loading branch information
simonpcouch committed Jul 29, 2024
1 parent 109ab48 commit 0e70486
Show file tree
Hide file tree
Showing 3 changed files with 80 additions and 3 deletions.
65 changes: 62 additions & 3 deletions R/mod-decision_tree.R
Original file line number Diff line number Diff line change
Expand Up @@ -20,9 +20,7 @@ linfa_decision_tree <- function(x, y, cost_complexity = 0.00001,
# TODO: this is probably not the way... parsnip requires that the outcome
# is a factor, but linfa takes outcomes as integers
if (inherits(y, "factor")) {
# TODO: this is gross, but - 1 aligns levels(y) with y if y was coerced
# from integer
y <- as.integer(y) - 1L
y <- as.integer(y)
}

check_integer(tree_depth)
Expand All @@ -47,3 +45,64 @@ linfa_decision_tree <- function(x, y, cost_complexity = 0.00001,
predict.linfa_decision_tree <- function(object, newdata) {
predict_decision_tree(object$fit, c(newdata), n_features = ncol(object$ptype))
}


# nocov start

make_decision_tree_linfa <- function() {
parsnip::set_model_engine(
model = "decision_tree",
mode = "classification",
eng = "linfa"
)

parsnip::set_dependency(
model = "decision_tree",
eng = "linfa",
pkg = "rinfa",
mode = "classification"
)

parsnip::set_fit(
model = "decision_tree",
eng = "linfa",
mode = "classification",
value = list(
interface = "matrix",
protect = c("x", "y"),
func = c(pkg = "rinfa", fun = "linfa_decision_tree"),
defaults = list()
)
)

parsnip::set_encoding(
model = "decision_tree",
mode = "classification",
eng = "linfa",
options = list(
predictor_indicators = "none",
compute_intercept = FALSE,
remove_intercept = FALSE,
allow_sparse_x = FALSE
)
)

parsnip::set_pred(
model = "decision_tree",
eng = "linfa",
mode = "classification",
type = "class",
value = list(
pre = NULL,
post = NULL,
func = c(fun = "predict"),
args = list(
object = quote(object$fit),
newdata = quote(new_data)
)
)
)
}

# nocov end

1 change: 1 addition & 0 deletions R/zzz.R
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
# been loaded.

.onLoad <- function(libname, pkgname) {
make_decision_tree_linfa()
make_linear_reg_linfa()
make_logistic_reg_linfa()
make_multinom_reg_linfa()
Expand Down
17 changes: 17 additions & 0 deletions tests/testthat/test-mod-decision_tree.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
test_that("parsnip interface works", {
set.seed(1)
x <- matrix(rnorm(300), ncol = 3)
y <- sample(1:4, 100, replace = TRUE)
newdata <- matrix(rnorm(12), ncol = 3)

m_linfa <- linfa_decision_tree(x, y)
p_linfa <- predict(m_linfa, newdata)

skip("TODO: i see intermittent test failures. \nthese fits don't seem to be deterministic? or maybe hyperparameters differ slightly?")

m_parsnip <- fit(decision_tree(engine = "linfa", mode = "classification"), y ~ ., cbind(as.data.frame(x), y = as.factor(y)))
p_parsnip <- predict(m_parsnip, as.data.frame(newdata))

expect_s3_class(m_parsnip, c("_linfa_decision_tree", "model_fit"))
expect_equal(p_linfa, as.integer(p_parsnip[[".pred_class"]]))
})

0 comments on commit 0e70486

Please sign in to comment.