Skip to content

Commit

Permalink
Fix multi predict penalty issue with logistic regression as well.
Browse files Browse the repository at this point in the history
  • Loading branch information
DavisVaughan committed Dec 6, 2018
1 parent 30c74d8 commit 01dc9bc
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 10 deletions.
16 changes: 8 additions & 8 deletions R/logistic_reg.R
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
#'
#' `logistic_reg` is a way to generate a _specification_ of a model
#' before fitting and allows the model to be created using
#' different packages in R, Stan, keras, or via Spark. The main
#' different packages in R, Stan, keras, or via Spark. The main
#' arguments for the model are:
#' \itemize{
#' \item \code{penalty}: The total amount of regularization
Expand Down Expand Up @@ -65,7 +65,7 @@
#' \pkg{keras}
#'
#' \Sexpr[results=rd]{parsnip:::show_fit(parsnip:::logistic_reg(), "keras")}
#'
#'
#' When using `glmnet` models, there is the option to pass
#' multiple values (or no values) to the `penalty` argument.
#' This can have an effect on the model object results. When using
Expand Down Expand Up @@ -240,7 +240,7 @@ organize_glmnet_prob <- function(x, object) {
predict._lognet <- function (object, new_data, type = NULL, opts = list(), ...) {
if (any(names(enquos(...)) == "newdata"))
stop("Did you mean to use `new_data` instead of `newdata`?", call. = FALSE)

object$spec <- eval_args(object$spec)
predict.model_fit(object, new_data = new_data, type = type, opts = opts, ...)
}
Expand All @@ -249,7 +249,7 @@ predict._lognet <- function (object, new_data, type = NULL, opts = list(), ...)
predict_class._lognet <- function (object, new_data, ...) {
if (any(names(enquos(...)) == "newdata"))
stop("Did you mean to use `new_data` instead of `newdata`?", call. = FALSE)

object$spec <- eval_args(object$spec)
predict_class.model_fit(object, new_data = new_data, ...)
}
Expand All @@ -258,7 +258,7 @@ predict_class._lognet <- function (object, new_data, ...) {
predict_classprob._lognet <- function (object, new_data, ...) {
if (any(names(enquos(...)) == "newdata"))
stop("Did you mean to use `new_data` instead of `newdata`?", call. = FALSE)

object$spec <- eval_args(object$spec)
predict_classprob.model_fit(object, new_data = new_data, ...)
}
Expand All @@ -267,7 +267,7 @@ predict_classprob._lognet <- function (object, new_data, ...) {
predict_raw._lognet <- function (object, new_data, opts = list(), ...) {
if (any(names(enquos(...)) == "newdata"))
stop("Did you mean to use `new_data` instead of `newdata`?", call. = FALSE)

object$spec <- eval_args(object$spec)
predict_raw.model_fit(object, new_data = new_data, opts = opts, ...)
}
Expand All @@ -280,10 +280,10 @@ multi_predict._lognet <-
function(object, new_data, type = NULL, penalty = NULL, ...) {
if (any(names(enquos(...)) == "newdata"))
stop("Did you mean to use `new_data` instead of `newdata`?", call. = FALSE)

dots <- list(...)
if (is.null(penalty))
penalty <- object$lambda
penalty <- object$fit$lambda
dots$s <- penalty

if (is.null(type))
Expand Down
11 changes: 9 additions & 2 deletions tests/testthat/test_logistic_reg_glmnet.R
Original file line number Diff line number Diff line change
Expand Up @@ -331,11 +331,18 @@ test_that('submodel prediction', {
mp_res <- multi_predict(class_fit, new_data = wa_churn[1:4, vars], penalty = .1, type = "prob")
mp_res <- do.call("rbind", mp_res$.pred)
expect_equal(mp_res[[".pred_No"]], unname(pred_glmn[,1]))

expect_error(
multi_predict(class_fit, newdata = wa_churn[1:4, vars], penalty = .1, type = "prob"),
multi_predict(class_fit, newdata = wa_churn[1:4, vars], penalty = .1, type = "prob"),
"Did you mean"
)

# Can predict using default penalty. See #108
expect_error(
multi_predict(class_fit, new_data = wa_churn[1:4, vars]),
NA
)

})


0 comments on commit 01dc9bc

Please sign in to comment.