Skip to content

Commit

Permalink
Merge pull request tlverse#156 from tlverse/split-specific
Browse files Browse the repository at this point in the history
initial fold specific predictions
  • Loading branch information
jeremyrcoyle authored Aug 7, 2018
2 parents 4bdfd99 + 97150f4 commit f2edc5f
Show file tree
Hide file tree
Showing 4 changed files with 57 additions and 3 deletions.
4 changes: 4 additions & 0 deletions R/Lrnr_base.R
Original file line number Diff line number Diff line change
Expand Up @@ -166,6 +166,10 @@ Lrnr_base <- R6Class(

custom_chain = function(new_chain_fun = NULL) {
private$.custom_chain <- new_chain_fun
},
predict_fold = function(task, fold_number){
warning(self$name, " is not a cv-aware learner, so self$predict_fold reverts to self$predict")
self$predict(task)
}
),

Expand Down
22 changes: 20 additions & 2 deletions R/Lrnr_cv.R
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,24 @@ Lrnr_cv <- R6Class(
print("Lrnr_cv")
print(self$params$learner)
# todo: check if fit
},
predict_fold = function(task, fold_number=0){
if(fold_number!=0){
fold_fit <- self$fit_object$fold_fits[[fold_number]]
return(fold_fit$predict(task))
} else{
return(self$predict(task))
}
},
chain_fold = function(task, fold_number = 0){
predictions <- self$predict_fold(task, fold_number)
# Add predictions as new columns
new_col_names <- task$add_columns(self$fit_uuid, predictions)
# new_covariates = union(names(predictions),task$nodes$covariates)
return(task$next_in_chain(
covariates = names(predictions),
column_names = new_col_names
))
}
),

Expand All @@ -72,7 +90,7 @@ Lrnr_cv <- R6Class(
),

private = list(
.properties = c("wrapper"),
.properties = c("wrapper", "cv"),

.train_sublearners = function(task) {
# prefer folds from params, but default to folds from task
Expand Down Expand Up @@ -156,7 +174,7 @@ Lrnr_cv <- R6Class(
# }
# doing train and predict like this is stupid, but that's the paradigm
# (for now!)
folds <- private$.fit_object$folds
folds <- task$folds
fold_fits <- private$.fit_object$fold_fits

cv_predict <- function(fold, fold_fits, task) {
Expand Down
6 changes: 5 additions & 1 deletion R/Lrnr_sl.R
Original file line number Diff line number Diff line change
Expand Up @@ -134,6 +134,10 @@ Lrnr_sl <- R6Class(
risk_dt[, coefficients := c(coefs, NA)]
}
return(risk_dt)
},
predict_fold = function(task, fold_number=0){
meta_task <- self$fit_object$cv_fit$chain_fold(task,fold_number)
meta_predictions <- self$fit_object$cv_meta_fit$predict(meta_task)
}
),

Expand All @@ -149,7 +153,7 @@ Lrnr_sl <- R6Class(
),

private = list(
.properties = c("wrapper"),
.properties = c("wrapper", "cv"),

.train_sublearners = function(task) {
# prefer folds from params, but default to folds from task
Expand Down
28 changes: 28 additions & 0 deletions tests/testthat/test-sl-fold.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
library(testthat)
context("test_sl.R -- Basic Lrnr_sl functionality")

options(sl3.verbose = TRUE)
library(sl3)
library(origami)
library(SuperLearner)

data(cpp_imputed)
covars <- c("apgar1", "apgar5", "parity", "gagebrth", "mage", "meducyrs", "sexn")
outcome <- "haz"
task <- sl3_Task$new(data.table::copy(cpp_imputed), covariates = covars, outcome = outcome)

glm_learner <- Lrnr_glm$new()
glmnet_learner <- Lrnr_pkg_SuperLearner$new("SL.glmnet")
subset_apgar <- Lrnr_subset_covariates$new(covariates = c("apgar1", "apgar5"))
learners <- list(glm_learner, glmnet_learner, subset_apgar)
sl1 <- make_learner(Lrnr_sl, learners, glm_learner)

sl_fit <- sl1$train(task)

fold1_predict <- sl_fit$predict_fold(task,1)
validation_predict <- sl_fit$predict_fold(task,0)
expect_false(all(fold1_predict==validation_predict))
expect_true(any(fold1_predict==validation_predict))

glm_fit <- glm_learner$train(task)
expect_warning(glm_fold1_predict <- glm_fit$predict_fold(task,1))

0 comments on commit f2edc5f

Please sign in to comment.