Skip to content

Commit

Permalink
Add dbarts
Browse files Browse the repository at this point in the history
  • Loading branch information
imalenica committed May 6, 2018
1 parent d5219ea commit f86c053
Show file tree
Hide file tree
Showing 47 changed files with 566 additions and 91 deletions.
3 changes: 2 additions & 1 deletion DESCRIPTION
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,8 @@ Suggests:
Remotes:
github::tlverse/delayed,
github::osofr/condensier,
github::jeremyrcoyle/hal9001
github::jeremyrcoyle/hal9001,
github::vdorie/dbarts
License: GPL-3
URL: https://sl3.tlverse.org
BugReports: https://github.com/tlverse/sl3/issues
Expand Down
1 change: 1 addition & 0 deletions NAMESPACE
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ export(Lrnr_bilstm)
export(Lrnr_condensier)
export(Lrnr_cv)
export(Lrnr_dBart)
export(Lrnr_dbarts)
export(Lrnr_define_interactions)
export(Lrnr_expSmooth)
export(Lrnr_glm)
Expand Down
106 changes: 106 additions & 0 deletions R/Lrnr_dBart.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,106 @@
#' Discrete Bayesian Additive Regression Tree sampler
#'
#' This learner implements the BART algorithm in C++, using the \code{dbarts} package.
#'
#' @docType class
#'
#' @importFrom R6 R6Class
#' @importFrom stats predict
#' @importFrom assertthat assert_that is.count is.flag
#'
#' @export
#'
#' @keywords data
#'
#' @return Learner object with methods for training and prediction. See
#' \code{\link{Lrnr_base}} for documentation on learners.
#'
#' @format \code{\link{R6Class}} object.
#'
#' @family Learners
#'
#' @section Parameters:
#' \describe{
#' \item{\code{Y}}{Outcome variable}
#' \item{\code{X}}{Covariate dataframe}
#' \item{\code{test}}{An optional matrix or data frame with the same number of
#' predictors as data, or formula in backwards compatibility mode. If column names
#' are present, a matching algorithm is used.}
#' \item{\code{subset}}{An optional vector specifying a subset of observations to
#' be used in the fitting process.}
#' \item{\code{weights}}{An optional vector of weights to be used in the fitting
#' process. When present, BART fits a model with observations y | x ~ N(f(x), σ^2 / w),
#' where f(x) is the unknown function.}
#' \item{\code{offset}}{An optional vector specifying an offset from 0 for the
#' relationship between the underlying function, f(x), and the response y.
#' Only is useful for binary responses, in which case the model fit is to assume
#' P(Y = 1 | X = x) = Φ(f(x) + offset), where Φ is the standard normal cumulative
#' distribution function.}
#' \item{\code{offset.test}}{The equivalent of offset for test observations.
#' Will attempt to use offset when applicable.}
#' \item{\code{verbose}}{A logical determining if additional output is printed to
#' the console. See dbartsControl.}
#' \item{\code{n.samples}}{A positive integer setting the default number of
#' posterior samples to be returned for each run of the sampler.
#' Can be overriden at run-time. See dbartsControl.}
#' \item{\code{tree.prior}}{An expression of the form cgm or cgm(power, base)
#' setting the tree prior used in fitting.}
#' \item{\code{node.prior}}{An expression of the form normal or normal(k) that
#' sets the prior used on the averages within nodes.}
#' \item{\code{resid.prior}}{An expression of the form chisq or chisq(df, quant)
#' that sets the prior used on the residual/error variance.}
#' \item{\code{control}}{An object inheriting from dbartsControl, created by the dbartsControl function.}
#' \item{\code{sigma}}{A positive numeric estimate of the residual standard deviation.
#' If NA, a linear model is used with all of the predictors to obtain one.}
#' }
#'
#' @template common_parameters
#

Lrnr_dBart <- R6Class(
  classname = "Lrnr_dBart",
  inherit = Lrnr_base, portable = TRUE, class = TRUE,
  public = list(
    #' Construct the learner.
    #'
    #' Prior/control arguments default to NULL and are only forwarded to
    #' dbarts when supplied by the user: the dbarts defaults (cgm, normal,
    #' chisq) are symbols resolved by non-standard evaluation inside the
    #' dbarts namespace and cannot be evaluated in this environment.
    #' Note: the prior-typo'd `ode.prior` argument is corrected to the
    #' documented name `node.prior`.
    initialize = function(offset.test = NULL, verbose = FALSE,
                          n.samples = 800L, tree.prior = NULL,
                          node.prior = NULL, resid.prior = NULL,
                          control = NULL, sigma = NA_real_, ...) {
      params <- args_to_list()
      # Drop unsupplied (NULL) arguments so dbarts falls back to its own
      # defaults instead of receiving NULLs it does not expect.
      params <- params[!vapply(params, is.null, logical(1))]
      super$initialize(params = params, ...)
    }
  ),

  private = list(
    # dbarts::bart fits continuous and binary outcomes; "categorical" is not
    # supported by the underlying sampler and is intentionally not listed.
    .properties = c("continuous", "binomial", "weights"),

    # Fit dbarts::bart on the task's covariates and formatted outcome.
    .train = function(task) {
      args <- self$params
      outcome_type <- self$get_outcome_type(task)

      # dbarts::bart takes training data via x.train / y.train
      args$x.train <- as.data.frame(task$X)
      args$y.train <- outcome_type$format(task$Y)

      if (task$has_node("weights")) {
        args$weights <- task$weights
      }

      if (task$has_node("offset")) {
        args$offset <- task$offset
      }

      # keeptrees must be TRUE for stats::predict() to work on the fit
      args$keeptrees <- TRUE

      # call_with_args keeps only the arguments dbarts::bart understands
      fit_object <- call_with_args(dbarts::bart, args)

      return(fit_object)
    },

    # Predict on new data; predict.bart returns a matrix of posterior draws
    # (samples x observations), which is averaged into point predictions.
    .predict = function(task) {
      # NOTE(review): the original passed `new_data =`, which is not a formal
      # of predict.bart; the documented argument is `newdata`.
      draws <- stats::predict(
        private$.fit_object,
        newdata = data.frame(task$X)
      )
      # Collapse posterior draws to one prediction per observation.
      predictions <- colMeans(as.matrix(draws))

      return(predictions)
    },
    # The sampler lives in the "dbarts" package ("dbart" does not exist and
    # rJava is not used by dbarts).
    .required_packages = c("dbarts")
  )
)
127 changes: 127 additions & 0 deletions R/Lrnr_dbarts.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,127 @@
#' Discrete Bayesian Additive Regression Tree sampler
#'
#' This learner implements the BART algorithm in C++, using the \code{dbarts} package.
#' BART is a Bayesian sum-of-trees model in which each tree is constrained
#' by a prior to be a weak learner.
#'
#' @docType class
#'
#' @importFrom R6 R6Class
#' @importFrom stats predict
#' @importFrom assertthat assert_that is.count is.flag
#'
#' @export
#'
#' @keywords data
#'
#' @return Learner object with methods for training and prediction. See
#' \code{\link{Lrnr_base}} for documentation on learners.
#'
#' @format \code{\link{R6Class}} object.
#'
#' @family Learners
#'
#' @section Parameters:
#' \describe{
#' \item{\code{x.test}}{Explanatory variables for test (out of sample) data.
#' \code{bart} will generate draws of \eqn{f(x)} for each \eqn{x} which is a row of \code{x.test}.}
#' \item{\code{sigest}}{For continuous response models, an estimate of the error variance, \eqn{\sigma^2},
#' used to calibrate an inverse-chi-squared prior used on that parameter. If not supplied,
#' the least-squares estimate is derived instead. See \code{sigquant} for more information.
#' Not applicable when \eqn{y} is binary.}
#' \item{\code{sigdf}}{Degrees of freedom for error variance prior.
#' Not applicable when \eqn{y} is binary.}
#' \item{\code{sigquant}}{The quantile of the error variance prior that the rough estimate
#' (\code{sigest}) is placed at. The closer the quantile is to 1, the more aggressive the fit
#' will be as you are putting more prior weight on error standard deviations (\eqn{\sigma})
#' less than the rough estimate. Not applicable when \eqn{y} is binary.}
#' \item{\code{k}}{For numeric \eqn{y}, \code{k} is the number of prior standard deviations
#' \eqn{E(Y|x) = f(x)} is away from \eqn{\pm 0.5}{+/- 0.5}. The response (\code{y.train}) is
#' internally scaled to range from \eqn{-0.5} to \eqn{0.5}. For binary \eqn{y}, \code{k} is
#' the number of prior standard deviations \eqn{f(x)} is away from \eqn{\pm 3}{+/- 3}.
#' In both cases, the bigger \eqn{k} is, the more conservative the fitting will be.}
#' \item{\code{power}}{Power parameter for tree prior.}
#' \item{\code{base}}{Base parameter for tree prior.}
#' \item{\code{binaryOffset}}{Used for binary \eqn{y}. When present, the model is
#' \eqn{P(Y = 1 \mid x) = \Phi(f(x) + \mathrm{binaryOffset})}{P(Y = 1 | x) = \Phi(f(x) + binaryOffset)},
#' allowing fits with probabilities shrunk towards values other than \eqn{0.5}.}
#' \item{\code{weights}}{An optional vector of weights to be used in the fitting process.
#' When present, BART fits a model with observations \eqn{y \mid x \sim N(f(x),
#' \sigma^2 / w)}{y | x ~ N(f(x), \sigma^2 / w)}, where \eqn{f(x)} is the unknown function.}
#' \item{\code{ntree}}{The number of trees in the sum-of-trees formulation.}
#' \item{\code{ndpost}}{The number of posterior draws after burn in,
#' \code{ndpost / keepevery} will actually be returned.}
#' \item{\code{nskip}}{Number of MCMC iterations to be treated as burn in.}
#' \item{\code{printevery}}{As the MCMC runs, a message is printed every \code{printevery} draws.}
#' \item{\code{keepevery}}{Every \code{keepevery} draw is kept to be returned to the user.
#' Useful for \dQuote{thinning} samples.}
#' \item{\code{keeptrainfits}}{If \code{TRUE} the draws of \eqn{f(x)} for \eqn{x} corresponding
#' to the rows of \code{x.train} are returned.}
#' \item{\code{usequants}}{When \code{TRUE}, determine tree decision rules using estimated
#' quantiles derived from the \code{x.train} variables. When \code{FALSE}, splits are
#' determined using values equally spaced across the range of a variable. See details for more information.}
#' \item{\code{numcut}}{The maximum number of possible values used in decision rules (see \code{usequants}, details).
#' If a single number, it is recycled for all variables; otherwise must be a vector of length
#' equal to \code{ncol(x.train)}. Fewer rules may be used if a covariate lacks enough unique values.}
#' \item{\code{printcutoffs}}{The number of cutoff rules to printed to screen before the MCMC is run.
#' Given a single integer, the same value will be used for all variables. If 0, nothing is printed.}
#' \item{\code{verbose}}{Logical; if \code{FALSE} suppress printing.}
#' \item{\code{nchain}}{Integer specifying how many independent tree sets and fits should be calculated.}
#' \item{\code{nthread}}{Integer specifying how many threads to use. Depending on the CPU architecture,
#' using more than the number of chains can degrade performance for small/medium data sets.
#' As such some calculations may be executed single threaded regardless.}
#' \item{\code{combinechains}}{Logical; if \code{TRUE}, samples will be returned in arrays of
#' dimensions equal to \code{nchain} \eqn{\times} \code{ndpost} \eqn{\times} number of observations.}
#' \item{\code{keeptrees}}{Logical; must be \code{TRUE} in order to use \code{predict} with the result of a \code{bart} fit.}
#' \item{\code{keepcall}}{Logical; if \code{FALSE}, returned object will have \code{call} set to
#' \code{call("NULL")}, otherwise the call used to instantiate BART.}
#' }
#'
#' @template common_parameters
#

Lrnr_dbarts <- R6Class(
  classname = "Lrnr_dbarts",
  inherit = Lrnr_base, portable = TRUE, class = TRUE,
  public = list(
    #' Construct the learner.
    #'
    #' keeptrees defaults to TRUE because stats::predict() on a bart fit
    #' requires the trees to have been retained.
    initialize = function(ndpost = 20, nskip = 5,
                          ntree = 5L, verbose = FALSE, keeptrees = TRUE, ...) {
      super$initialize(params = args_to_list(), ...)
    }
  ),

  private = list(
    # dbarts::bart supports continuous and binary outcomes only; the
    # previous "categorical" entry advertised support the sampler lacks.
    .properties = c("continuous", "binomial", "weights"),

    # Fit dbarts::bart on the task's covariates and formatted outcome.
    .train = function(task) {
      args <- self$params
      outcome_type <- self$get_outcome_type(task)

      # dbarts::bart takes training data via x.train / y.train
      args$x.train <- as.data.frame(task$X)
      args$y.train <- outcome_type$format(task$Y)

      if (task$has_node("weights")) {
        args$weights <- task$weights
      }

      if (task$has_node("offset")) {
        # NOTE(review): dbarts::bart has no `offset` formal (binary models
        # use `binaryOffset`); call_with_args silently drops this entry.
        # Kept for forward compatibility — confirm intended mapping.
        args$offset <- task$offset
      }

      # call_with_args keeps only the arguments dbarts::bart understands
      fit_object <- call_with_args(dbarts::bart, args)

      return(fit_object)
    },

    # Predict on new data; predict.bart returns posterior draws
    # (samples x observations), which are averaged into point predictions.
    .predict = function(task) {
      draws <- stats::predict(
        private$.fit_object,
        data.frame(task$X)
      )
      # Collapse posterior draws to one prediction per observation; the
      # raw draw matrix is not a valid sl3 prediction vector.
      predictions <- colMeans(as.matrix(draws))

      return(predictions)
    },
    .required_packages = c("dbarts")
  )
)
2 changes: 1 addition & 1 deletion man/Custom_chain.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion man/Lrnr_HarmonicReg.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions man/Lrnr_arima.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion man/Lrnr_bartMachine.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions man/Lrnr_base.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion man/Lrnr_bilstm.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion man/Lrnr_condensier.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion man/Lrnr_cv.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Loading

0 comments on commit f86c053

Please sign in to comment.