Commit

document glm and glm_fast
nhejazi committed Oct 9, 2017
1 parent d474919 commit fb6c26a
Showing 19 changed files with 389 additions and 169 deletions.
94 changes: 47 additions & 47 deletions DESCRIPTION
@@ -1,58 +1,58 @@
 Package: sl3
-Title: Machine Learning Pipelines
-Version: 0.2
-Date: 2017-08-01
+Title: Pipelines for Machine Learning and Super Learning
+Version: 0.2.1.0
 Authors@R: c(
-  person("Jeremy", "Coyle", email = "[email protected]",
-         role = c("aut", "cre", "cph")),
-  person("Nima", "Hejazi", email = "[email protected]", role = "aut"),
-  person("Oleg", "Sofrygin", email="[email protected]", role = "aut"),
-  person("Ivana", "Malenica", email = "[email protected]", role = "aut")
+  person("Jeremy", "Coyle", email = "[email protected]",
+         role = c("aut", "cre", "cph")),
+  person("Nima", "Hejazi", email = "[email protected]", role = "aut"),
+  person("Oleg", "Sofrygin", email="[email protected]", role = "aut"),
+  person("Ivana", "Malenica", email = "[email protected]", role = "aut")
   )
 Maintainer: Jeremy Coyle <[email protected]>
 Description: Implements the super learner prediction method and contains a
-  library of prediction algorithms to be used in the super learner.
-License: GPL-3
-URL: https://github.com/jeremyrcoyle/sl3
-BugReports: https://github.com/jeremyrcoyle/sl3/issues
-Depends:
-  R (>= 2.14.0)
+  library of prediction algorithms to be used in the super learner.
+Depends: R (>= 2.14.0)
 Imports:
-  data.table,
-  assertthat,
-  origami,
-  future,
-  R6,
-  uuid,
-  memoise,
-  digest,
-  BBmisc,
-  delayed,
-  methods,
-  stats
+  data.table,
+  assertthat,
+  origami,
+  future,
+  R6,
+  uuid,
+  memoise,
+  digest,
+  BBmisc,
+  delayed,
+  methods,
+  stats
 Suggests:
-  testthat,
-  knitr,
-  rmarkdown,
-  glmnet,
-  devtools,
-  dplyr,
-  rgl,
-  Rsolnp,
-  condensier,
-  cvAUC,
-  h2o,
-  xgboost,
-  forecast,
-  nloptr,
-  rugarch,
-  nnls,
-  SuperLearner,
-  tsDyn,
-  randomForest
+  testthat,
+  knitr,
+  rmarkdown,
+  glmnet,
+  devtools,
+  dplyr,
+  rgl,
+  Rsolnp,
+  condensier,
+  cvAUC,
+  h2o,
+  xgboost,
+  forecast,
+  nloptr,
+  rugarch,
+  nnls,
+  SuperLearner,
+  tsDyn,
+  randomForest
 Remotes:
-  github::osofr/condensier,
-  github::jeremyrcoyle/delayed
+  github::osofr/condensier,
+  github::jeremyrcoyle/delayed
+License: GPL-3
+URL: https://github.com/jeremyrcoyle/sl3
+BugReports: https://github.com/jeremyrcoyle/sl3/issues
 Encoding: UTF-8
 LazyData: yes
 LazyLoad: yes
 VignetteBuilder: knitr
 RoxygenNote: 6.0.1.9000
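
Because condensier and delayed are declared under Remotes rather than available on CRAN, an install of the development package has to resolve them from GitHub. A minimal installation sketch, not part of the diff, assuming the devtools package is available (its GitHub installer reads the Remotes field of the package it installs):

# install sl3 from GitHub; the Remotes entries (github::osofr/condensier and
# github::jeremyrcoyle/delayed) should be resolved automatically
install.packages("devtools")                  # only if devtools is not already installed
devtools::install_github("jeremyrcoyle/sl3")
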
7 changes: 7 additions & 0 deletions NAMESPACE
@@ -17,6 +17,7 @@ export(Lrnr_h2o_mutator)
 export(Lrnr_mean)
 export(Lrnr_nnls)
 export(Lrnr_pkg_SuperLearner)
+export(Lrnr_pkg_SuperLearner_method)
 export(Lrnr_pkg_SuperLearner_screener)
 export(Lrnr_pkg_condensier_logisfitR6)
 export(Lrnr_randomForest)
@@ -47,10 +48,16 @@ importFrom(R6,R6Class)
 importFrom(assertthat,assert_that)
 importFrom(assertthat,is.count)
 importFrom(assertthat,is.flag)
+importFrom(data.table,data.table)
 importFrom(memoise,cache_filesystem)
 importFrom(memoise,cache_memory)
 importFrom(memoise,memoise)
+importFrom(speedglm,speedglm.wfit)
 importFrom(stats,arima)
+importFrom(stats,family)
+importFrom(stats,glm)
+importFrom(stats,glm.fit)
+importFrom(stats,predict)
 importFrom(stats,weighted.mean)
 importFrom(utils,packageVersion)
 importFrom(utils,str)
85 changes: 50 additions & 35 deletions R/Lrnr_glm.R
@@ -1,39 +1,54 @@

+#' GLM Fits
+#'
+#' This learner provides fitting procedures for generalized linear models by way
+#' of a wrapper relying on \code{stats::glm}.
+#'
 #' @docType class
+#'
+#' @keywords data
+#'
+#' @return \code{\link{Lrnr_base}} object with methods for training and
+#'  prediction.
+#'
+#' @format \code{\link{R6Class}} object.
+#'
+#' @field family A \code{family} object from package \code{stats} describing the
+#'  error family of the model to be fit. See the documentation for the package
+#'  \code{stats} for details, or consult \code{stats::family} directly.
+#' @field ... Additional arguments.
+#'
 #' @importFrom R6 R6Class
+#' @importFrom stats glm predict family
+#'
 #' @export
-#' @rdname undocumented_learner
-Lrnr_glm <- R6Class(classname = "Lrnr_glm", inherit = Lrnr_base, portable = TRUE,
-                    class = TRUE, public = list(
-    initialize = function(family = gaussian(), ...) {
-      params <- list(family = family, ...)
-      super$initialize(params = params, ...)
-    }),
-  private = list(
-    .train = function(task) {
-      params <- self$params
-      family <- params[["family"]]
-      if (is.character(family)) {
-        family <- get(family, mode = "function", envir = parent.frame())
-        family <- family()
-      }
-      # todo: if possible have this use task$Xmat with glm.fit or speedglm
-      Y <- task$Y
-      fit_object <- glm(Y ~ ., data = task$X, family = family, weights = task$weights)
-      return(fit_object)
-
-    },
-    .predict = function(task = NULL) {
-      predictions <- predict(private$.fit_object, newdata = task$X, type = "response")
-      return(predictions)
-    }
-  ),
+#
+Lrnr_glm <- R6Class(classname = "Lrnr_glm", inherit = Lrnr_base,
+                    portable = TRUE, class = TRUE,
+  public = list(
+    initialize = function(family = gaussian(), ...) {
+      params <- list(family = family, ...)
+      super$initialize(params = params, ...)
+    }
+  ),
+  private = list(
+    .train = function(task) {
+      params <- self$params
+      family <- params[["family"]]
+      if (is.character(family)) {
+        family <- get(family, mode = "function", envir = parent.frame())
+        family <- stats::family()
+      }
+      # TODO: if possible have this use task$Xmat with glm.fit or speedglm
+      Y <- task$Y
+      fit_object <- stats::glm(Y ~ ., data = task$X, family = family,
+                               weights = task$weights)
+      return(fit_object)
+    },
+    .predict = function(task = NULL) {
+      predictions <- stats::predict(private$.fit_object, newdata = task$X,
+                                    type = "response")
+      return(predictions)
+    }
+  ),
 )
-
-
-
-
-
-
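
For orientation, the newly documented learner is driven through the training and prediction methods inherited from Lrnr_base, as referenced in the roxygen block above. A minimal usage sketch, not part of the commit: the sl3_Task constructor arguments and the cpp_imputed example data set (and its column names) are assumptions based on the package's conventions, not something this diff establishes.

library(sl3)

# assumed example data; any data.frame with numeric covariates and an outcome works
data(cpp_imputed)
task <- sl3_Task$new(
  data = cpp_imputed,
  covariates = c("apgar1", "apgar5", "parity", "gagebrth"),
  outcome = "haz"
)

lrnr_glm <- Lrnr_glm$new(family = gaussian())
glm_fit <- lrnr_glm$train(task)   # dispatches to the private .train() shown above
preds <- glm_fit$predict(task)    # .predict() returns response-scale predictions
head(preds)
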
90 changes: 52 additions & 38 deletions R/Lrnr_glm_fast.R
@@ -11,9 +11,9 @@
 #if warning is in ignoreWarningList, ignore it; otherwise post it as usual
 SuppressGivenWarnings <- function(expr, warningsToIgnore) {
   h <- function (w) {
-    if (w$message %in% warningsToIgnore) invokeRestart( "muffleWarning" )
+    if (w$message %in% warningsToIgnore) invokeRestart("muffleWarning")
   }
-  withCallingHandlers(expr, warning = h )
+  withCallingHandlers(expr, warning = h)
 }
 
 GetWarningsToSuppress <- function(update.step=FALSE) {
@@ -25,37 +25,47 @@ GetWarningsToSuppress <- function(update.step=FALSE) {
   return(warnings.to.suppress)
 }
 
-## Define the design matrix for GLM regression. Returns a data.table.
-## Allows to subset the taks$X data.table by a smaller set of covariates if spec'ed in params
-## Can define interaction columns if spec'ed in params
-# defineX <- function(task, params) {
-#   covariates <- task$nodes$covariates
-#   if ("covariates" %in% names(params) && !is.null(params[["covariates"]])) {
-#     covariates <- intersect(covariates, params$covariates)
-#   }
-#   X <- cbind(Intercept = 1L, task$X[,covariates, with=FALSE, drop=FALSE])
-#   if (!is.null(params[["interactions"]])) {
-#     # print("adding interactions in GLM:"); str(params[["interactions"]])
-#     ## this is a hack to fix pointer allocation problem (so that X can be modified by reference inside add_interactions_toDT())
-#     ## see this for more: http://stackoverflow.com/questions/28078640/adding-new-columns-to-a-data-table-by-reference-within-a-function-not-always-wor
-#     ## and this: http://stackoverflow.com/questions/36434717/adding-column-to-nested-r-data-table-by-reference
-#     data.table::setDF(X)
-#     data.table::setDT(X)
-#     add_interactions_toDT(X, params[["interactions"]])
-#   }
-#   return(X)
-# }
-
+#' Faster GLM Fits
+#'
+#' This learner provides faster fitting procedures for generalized linear models
+#' by way of a wrapper relying on the \code{speedglm} package.
+#'
 #' @docType class
+#'
+#' @keywords data
+#'
+#' @return \code{\link{Lrnr_base}} object with methods for training and
+#'  prediction.
+#'
+#' @format \code{\link{R6Class}} object.
+#'
+#' @field family A \code{family} object from package \code{stats} describing the
+#'  error family of the model to be fit. See the documentation for the package
+#'  \code{speedglm} for details.
+#' @field method The type of matrix decomposition to be used in the model
+#'  fitting process. See documentation for the package \code{speedglm} for
+#'  further details.
+#' @field covariates Extra covariate terms to be passed to the model fitting
+#'  process. See documentation of the \code{speedglm} package for details.
+#' @field ... Additional arguments.
+#'
+#' @importFrom R6 R6Class
+#' @importFrom data.table data.table
+#' @importFrom stats glm.fit
+#' @importFrom speedglm speedglm.wfit
+#' @importFrom assertthat assert_that is.count is.flag
+#'
 #' @export
-#' @rdname undocumented_learner
-Lrnr_glm_fast <- R6Class(classname = "Lrnr_glm_fast", inherit = Lrnr_base, portable = TRUE, class = TRUE,
+#
+Lrnr_glm_fast <- R6Class(classname = "Lrnr_glm_fast", inherit = Lrnr_base,
+                         portable = TRUE, class = TRUE,
   public = list(
     initialize = function(family = gaussian(),
                           method = c('Cholesky', 'eigen','qr'),
                           covariates = NULL,
                           ...) {
-      params <- list(family = family, method = method[1L], covariates = covariates, ...)
+      params <- list(family = family, method = method[1L],
+                     covariates = covariates, ...)
       super$initialize(params = params, ...)
     }
   ),
@@ -75,17 +85,20 @@ Lrnr_glm_fast <- R6Class(classname = "Lrnr_glm_fast", inherit = Lrnr_base, porta
 
       SuppressGivenWarnings({
         fit_object <- try(speedglm::speedglm.wfit(X = as.matrix(X),
-                                                  y = task$Y,
-                                                  method = method,
-                                                  family = family,
-                                                  trace = FALSE,
-                                                  weights = task$weights),
+                                                  y = task$Y,
+                                                  method = method,
+                                                  family = family,
+                                                  trace = FALSE,
+                                                  weights = task$weights),
                           silent = TRUE)
       }, GetWarningsToSuppress())
 
-      if (inherits(fit_object, "try-error")) { # if failed, fall back on stats::glm
-        ## todo: find example where speedglm fails, and this code runs, add to tests
-        if (verbose) message("speedglm::speedglm.wfit failed, falling back on stats:glm.fit; ", fit_object)
+      if (inherits(fit_object, "try-error")) {
+        # if failed, fall back on stats::glm
+        ## TODO: find example where speedglm fails and this runs, add to tests
+        if (verbose) {
+          message("speedglm::speedglm.wfit failed, falling back on stats:glm.fit; ", fit_object)
+        }
         ctrl <- glm.control(trace = FALSE)
         SuppressGivenWarnings({
           fit_object <- stats::glm.fit(x = X,
@@ -103,7 +116,6 @@ Lrnr_glm_fast <- R6Class(classname = "Lrnr_glm_fast", inherit = Lrnr_base, porta
         fit_object$effects <- NULL
         fit_object$qr <- NULL
       }
-
      fit_object[["linkinv_fun"]] <- linkinv_fun
      return(fit_object)
     },
@@ -115,11 +127,13 @@ Lrnr_glm_fast <- R6Class(classname = "Lrnr_glm_fast", inherit = Lrnr_base, porta
       if (nrow(X) > 0) {
         coef <- private$.fit_object$coef
         if (!all(is.na(coef))) {
-          eta <- as.matrix(X[, which(!is.na(coef)), drop = FALSE, with = FALSE]) %*% coef[!is.na(coef)]
+          eta <- as.matrix(X[, which(!is.na(coef)), drop = FALSE,
+                             with = FALSE]) %*% coef[!is.na(coef)]
           predictions <- as.vector(private$.fit_object$linkinv_fun(eta))
         }
       }
       return(data.table::data.table(predictions))
-    },
+    },
     .required_packages = c("speedglm")
-  ), )
+  ), )
+
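
Because initialize() keeps only method[1L], "Cholesky" is the default decomposition unless another is requested. A matching usage sketch, not part of the commit, assuming the task object from the Lrnr_glm example above and an installed speedglm:

lrnr_glm_fast <- Lrnr_glm_fast$new(family = gaussian(), method = "Cholesky")
fast_fit <- lrnr_glm_fast$train(task)   # speedglm::speedglm.wfit, with a stats::glm.fit fallback
fast_preds <- fast_fit$predict(task)    # applies linkinv_fun to X %*% coef, as in .predict() above
head(fast_preds)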