Skip to content

Commit

Permalink
support parsnip interface for naive_Bayes()
Browse files Browse the repository at this point in the history
  • Loading branch information
simonpcouch committed Jul 29, 2024
1 parent 0e70486 commit 9eabc26
Show file tree
Hide file tree
Showing 4 changed files with 80 additions and 3 deletions.
66 changes: 63 additions & 3 deletions R/mod-naive_Bayes.R
Original file line number Diff line number Diff line change
Expand Up @@ -12,16 +12,15 @@
#' m
#'
#' predict(m, matrix(rnorm(12), ncol = 3))
#' @keywords internal
#' @export
linfa_naive_Bayes <- function(x, y, smoothness = 1e-9) {
check_x(x, y)
check_y(y, "classification")
# TODO: this is probably not the way... parsnip requires that the outcome
# is a factor, but linfa takes outcomes as integers
if (inherits(y, "factor")) {
# TODO: this is gross, but - 1 aligns levels(y) with y if y was coerced
# from integer
y <- as.integer(y) - 1L
y <- as.integer(y)
}

fit <-
Expand All @@ -42,3 +41,64 @@ linfa_naive_Bayes <- function(x, y, smoothness = 1e-9) {
predict.linfa_naive_Bayes <- function(object, newdata) {
predict_naive_Bayes(object$fit, c(newdata), n_features = ncol(object$ptype))
}


# nocov start

make_naive_Bayes_linfa <- function() {
parsnip::set_model_engine(
model = "naive_Bayes",
mode = "classification",
eng = "linfa"
)

parsnip::set_dependency(
model = "naive_Bayes",
eng = "linfa",
pkg = "rinfa",
mode = "classification"
)

parsnip::set_fit(
model = "naive_Bayes",
eng = "linfa",
mode = "classification",
value = list(
interface = "matrix",
protect = c("x", "y"),
func = c(pkg = "rinfa", fun = "linfa_naive_Bayes"),
defaults = list()
)
)

parsnip::set_encoding(
model = "naive_Bayes",
mode = "classification",
eng = "linfa",
options = list(
predictor_indicators = "none",
compute_intercept = FALSE,
remove_intercept = FALSE,
allow_sparse_x = FALSE
)
)

parsnip::set_pred(
model = "naive_Bayes",
eng = "linfa",
mode = "classification",
type = "class",
value = list(
pre = NULL,
post = NULL,
func = c(fun = "predict"),
args = list(
object = quote(object$fit),
newdata = quote(new_data)
)
)
)
}

# nocov end

1 change: 1 addition & 0 deletions R/zzz.R
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
make_linear_reg_linfa()
make_logistic_reg_linfa()
make_multinom_reg_linfa()
make_naive_Bayes_linfa()
}


Expand Down
1 change: 1 addition & 0 deletions man/linfa_naive_Bayes.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

15 changes: 15 additions & 0 deletions tests/testthat/test-mod-naive_Bayes.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
test_that("parsnip interface works", {
set.seed(1)
x <- matrix(rnorm(300), ncol = 3)
y <- sample(1:4, 100, replace = TRUE)
newdata <- matrix(rnorm(12), ncol = 3)

m_linfa <- linfa_naive_Bayes(x, y)
p_linfa <- predict(m_linfa, newdata)

m_parsnip <- fit(naive_Bayes(engine = "linfa"), y ~ ., cbind(as.data.frame(x), y = as.factor(y)))
p_parsnip <- predict(m_parsnip, as.data.frame(newdata))

expect_s3_class(m_parsnip, c("_linfa_naive_Bayes", "model_fit"))
expect_equal(p_linfa, as.integer(p_parsnip[[".pred_class"]]))
})

0 comments on commit 9eabc26

Please sign in to comment.