Skip to content

Commit

Permalink
support parsnip interface for svm_linear()
Browse files Browse the repository at this point in the history
  • Loading branch information
simonpcouch committed Jul 29, 2024
1 parent 9eabc26 commit 94c7665
Show file tree
Hide file tree
Showing 4 changed files with 79 additions and 0 deletions.
61 changes: 61 additions & 0 deletions R/mod-svm_linear.R
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
#' m
#'
#' predict(m, matrix(rnorm(12), ncol = 3))
#' @keywords internal
#' @export
linfa_svm_linear <- function(x, y, cost = 1) {
check_x(x, y)
Expand All @@ -39,3 +40,63 @@ linfa_svm_linear <- function(x, y, cost = 1) {
predict.linfa_svm_linear <- function(object, newdata) {
predict_svm_linear(object$fit, c(newdata), n_features = ncol(object$ptype))
}

# nocov start

make_svm_linear_linfa <- function() {
parsnip::set_model_engine(
model = "svm_linear",
mode = "classification",
eng = "linfa"
)

parsnip::set_dependency(
model = "svm_linear",
eng = "linfa",
pkg = "rinfa",
mode = "classification"
)

parsnip::set_fit(
model = "svm_linear",
eng = "linfa",
mode = "classification",
value = list(
interface = "matrix",
protect = c("x", "y"),
func = c(pkg = "rinfa", fun = "linfa_svm_linear"),
defaults = list()
)
)

parsnip::set_encoding(
model = "svm_linear",
mode = "classification",
eng = "linfa",
options = list(
predictor_indicators = "none",
compute_intercept = FALSE,
remove_intercept = FALSE,
allow_sparse_x = FALSE
)
)

parsnip::set_pred(
model = "svm_linear",
eng = "linfa",
mode = "classification",
type = "class",
value = list(
pre = NULL,
post = NULL,
func = c(fun = "predict"),
args = list(
object = quote(object$fit),
newdata = quote(new_data)
)
)
)
}

# nocov end

1 change: 1 addition & 0 deletions R/zzz.R
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
make_logistic_reg_linfa()
make_multinom_reg_linfa()
make_naive_Bayes_linfa()
make_svm_linear_linfa()
}


Expand Down
1 change: 1 addition & 0 deletions man/linfa_svm_linear.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

16 changes: 16 additions & 0 deletions tests/testthat/test-mod-svm_linear.R
Original file line number Diff line number Diff line change
Expand Up @@ -11,3 +11,19 @@ test_that("linfa_svm_linear works", {
expect_length(p_linfa, 4)
})

test_that("parsnip interface works", {
set.seed(1)
x <- matrix(rnorm(300), ncol = 3)
y <- sample(1:4, 100, replace = TRUE)
newdata <- matrix(rnorm(12), ncol = 3)

m_linfa <- linfa_svm_linear(x, y)
p_linfa <- predict(m_linfa, newdata)

m_parsnip <- fit(svm_linear(engine = "linfa", mode = "classification"), y ~ ., cbind(as.data.frame(x), y = as.factor(y)))
p_parsnip <- predict(m_parsnip, as.data.frame(newdata))

expect_s3_class(m_parsnip, c("_linfa_svm_linear", "model_fit"))
expect_equal(p_linfa, as.integer(p_parsnip[[".pred_class"]]))
})

0 comments on commit 94c7665

Please sign in to comment.