forked from tlverse/sl3
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathtest-cv.R
30 lines (23 loc) · 1015 Bytes
/
test-cv.R
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
# Tests for cross-validation fold handling: which folds an sl3_Task exposes,
# and which folds Lrnr_cv ends up training with when the task and/or the
# learner supply them.
context("test-cv.R -- Cross-validation fold handling")
library(origami)

# Shared fixture: a task built WITHOUT an explicit `folds` argument.
data(cpp_imputed)
covars <- c(
  "apgar1", "apgar5", "parity", "gagebrth", "mage", "meducyrs", "sexn"
)
outcome <- "haz"
task <- sl3_Task$new(cpp_imputed, covariates = covars, outcome = outcome)

# Test bodies are wrapped in braces: testthat's 3rd edition warns on unbraced
# code and needs the braces to report accurate file/line info on failure.
test_that("task will self-generate folds for 10-fold CV", {
  expect_length(task$folds, 10)
})

# Lrnr_cv constructed without `folds`, trained on the task above.
glm_learner <- Lrnr_glm$new()
cv_glm <- Lrnr_cv$new(glm_learner)
cv_glm_fit <- cv_glm$train(task)

test_that("Lrnr_cv will use folds from task", {
  expect_equal(task$folds, cv_glm_fit$fit_object$folds)
})

# A second task that is handed 5 explicit folds at construction time.
folds <- make_folds(cpp_imputed, V = 5)
task_2 <- sl3_Task$new(
  cpp_imputed,
  covariates = covars,
  outcome = outcome,
  folds = folds
)

test_that("task will accept custom folds", {
  expect_length(task_2$folds, 5)
})

# Lrnr_cv given its own 10 folds, trained on the 5-fold task.
cv_glm_2 <- Lrnr_cv$new(glm_learner, folds = make_folds(cpp_imputed, V = 10))
cv_glm_fit_2 <- cv_glm_2$train(task_2)

test_that("Lrnr_cv can override folds from task", {
  fitted_folds <- cv_glm_fit_2$fit_object$folds
  # Guard against a vacuous pass: expect_equal(NULL, NULL) succeeds, so also
  # require the learner's 10 folds (not the task's 5) to be present in the fit.
  expect_length(fitted_folds, 10)
  expect_equal(cv_glm_fit_2$params$folds, fitted_folds)
})