Fixing bugs in BCF predict and sampler reload functionality #170

Open · wants to merge 2 commits into base: main
32 changes: 17 additions & 15 deletions R/bcf.R
@@ -1104,6 +1104,21 @@ bcf <- function(X_train, Z_train, y_train, propensity_train = NULL, rfx_group_id
            global_model_config$update_global_error_variance(current_sigma2)
        }
    } else if (has_prev_model) {
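        # Restore warm-started adaptive coding parameters from the previous model
        # and refresh the tau basis before the forest models are reset below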
        if (adaptive_coding) {
            if (!is.null(previous_b_1_samples)) {
                current_b_1 <- previous_b_1_samples[previous_model_warmstart_sample_num]
            }
            if (!is.null(previous_b_0_samples)) {
                current_b_0 <- previous_b_0_samples[previous_model_warmstart_sample_num]
            }
            tau_basis_train <- (1-Z_train)*current_b_0 + Z_train*current_b_1
            forest_dataset_train$update_basis(tau_basis_train)
            if (has_test) {
                tau_basis_test <- (1-Z_test)*current_b_0 + Z_test*current_b_1
                forest_dataset_test$update_basis(tau_basis_test)
            }
            forest_model_tau$propagate_basis_update(forest_dataset_train, outcome_train, active_forest_tau)
        }
        resetActiveForest(active_forest_mu, previous_forest_samples_mu, previous_model_warmstart_sample_num - 1)
        resetForestModel(forest_model_mu, active_forest_mu, forest_dataset_train, outcome_train, TRUE)
        resetActiveForest(active_forest_tau, previous_forest_samples_tau, previous_model_warmstart_sample_num - 1)
@@ -1122,21 +1137,6 @@ bcf <- function(X_train, Z_train, y_train, propensity_train = NULL, rfx_group_id
            current_leaf_scale_tau <- as.matrix(leaf_scale_tau_double)
            forest_model_config_tau$update_leaf_model_scale(current_leaf_scale_tau)
        }
        if (adaptive_coding) {
            if (!is.null(previous_b_1_samples)) {
                current_b_1 <- previous_b_1_samples[previous_model_warmstart_sample_num]
            }
            if (!is.null(previous_b_0_samples)) {
                current_b_0 <- previous_b_0_samples[previous_model_warmstart_sample_num]
            }
            tau_basis_train <- (1-Z_train)*current_b_0 + Z_train*current_b_1
            forest_dataset_train$update_basis(tau_basis_train)
            if (has_test) {
                tau_basis_test <- (1-Z_test)*current_b_0 + Z_test*current_b_1
                forest_dataset_test$update_basis(tau_basis_test)
            }
            forest_model_tau$propagate_basis_update(forest_dataset_train, outcome_train, active_forest_tau)
        }
        if (has_rfx) {
            if (is.null(previous_rfx_samples)) {
                warning("`previous_model_json` did not have any random effects samples, so the RFX sampler will be run from scratch while the forests and any other parameters are warm started")
@@ -1618,6 +1618,8 @@ predict.bcfmodel <- function(object, X, Z, propensity = NULL, rfx_group_ids = NU
    # Add propensities to covariate set if necessary
    if (object$model_params$propensity_covariate != "none") {
        X_combined <- cbind(X, propensity)
    } else {
        X_combined <- X
    }

    # Create prediction datasets
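Taken together, the changes above (1) apply the warm-started adaptive coding parameters b_0 and b_1 to the treatment basis before the tau forest model is reset when continuing a sampler from `previous_model_json`, and (2) define `X_combined` in `predict.bcfmodel` when a model was fit with `propensity_covariate = "none"`. A minimal sketch of the two code paths these fixes exercise; the simulated data and parameter values below are illustrative only and not part of this PR:

library(stochtree)

# Illustrative simulated data (hypothetical; only the bcf()/predict() calls below
# touch the code paths changed in this PR)
n <- 250
p <- 5
X <- matrix(runif(n*p), ncol = p)
pi_x <- 0.2 + 0.6*X[,1]
Z <- rbinom(n, 1, pi_x)
y <- sin(2*pi*X[,2]) + 0.5*Z*X[,3] + rnorm(n, 0, 0.5)

# Short GFR run with adaptive coding and no propensity covariate
general_params <- list(adaptive_coding = TRUE, propensity_covariate = "none",
                       keep_gfr = TRUE)
bcf_gfr <- stochtree::bcf(
    X_train = X, Z_train = Z, y_train = y, propensity_train = pi_x,
    num_gfr = 10, num_burnin = 0, num_mcmc = 0,
    general_params = general_params
)
bcf_json <- saveBCFModelToJsonString(bcf_gfr)

# Continue sampling from the last GFR draw: the reload fix pushes the
# warm-started b_0 / b_1 values into the tau basis before the tau forest
# model is reset
bcf_mcmc <- stochtree::bcf(
    X_train = X, Z_train = Z, y_train = y, propensity_train = pi_x,
    num_gfr = 0, num_burnin = 0, num_mcmc = 100,
    previous_model_json = bcf_json,
    previous_model_warmstart_sample_num = 10,
    general_params = general_params
)

# The predict fix assigns X_combined <- X when propensity_covariate == "none",
# so this call no longer references an undefined X_combined
preds <- predict(bcf_mcmc, X = X, Z = Z)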
84 changes: 84 additions & 0 deletions tools/debug/bart_continue_sampler_debug.R
@@ -0,0 +1,84 @@
# Load libraries
library(stochtree)

# Sampler settings
num_chains <- 1
num_gfr <- 10
num_burnin <- 0
num_mcmc <- 20
num_trees <- 100

# Generate the data
n <- 500
p_x <- 10
snr <- 2
X <- matrix(runif(n*p_x), ncol = p_x)
f_XW <- sin(4*pi*X[,1]) + cos(4*pi*X[,2]) + sin(4*pi*X[,3]) + cos(4*pi*X[,4])
noise_sd <- sd(f_XW) / snr
y <- f_XW + rnorm(n, 0, 1)*noise_sd

# Split data into test and train sets
test_set_pct <- 0.2
n_test <- round(test_set_pct*n)
n_train <- n - n_test
test_inds <- sort(sample(1:n, n_test, replace = FALSE))
train_inds <- (1:n)[!((1:n) %in% test_inds)]
X_test <- as.data.frame(X[test_inds,])
X_train <- as.data.frame(X[train_inds,])
y_test <- y[test_inds]
y_train <- y[train_inds]
f_XW_test <- f_XW[test_inds]
f_XW_train <- f_XW[train_inds]

# Run the GFR algorithm
general_params <- list(sample_sigma2_global = T)
mean_forest_params <- list(num_trees = num_trees, alpha = 0.95,
beta = 2.0, max_depth = -1,
min_samples_leaf = 1,
sample_sigma2_leaf = F,
sigma2_leaf_init = 1.0/num_trees)
xbart_model <- stochtree::bart(
X_train = X_train, y_train = y_train, X_test = X_test,
num_gfr = num_gfr, num_burnin = 0, num_mcmc = 0,
general_params = general_params,
mean_forest_params = mean_forest_params
)

# Inspect results
plot(rowMeans(xbart_model$y_hat_test), y_test); abline(0,1)
cat(paste0("RMSE = ", sqrt(mean((rowMeans(xbart_model$y_hat_test) - y_test)^2)), "\n"))
cat(paste0("Interval coverage = ", mean((apply(xbart_model$y_hat_test, 1, quantile, probs=0.025) <= f_XW_test) & (apply(xbart_model$y_hat_test, 1, quantile, probs=0.975) >= f_XW_test)), "\n"))
plot(xbart_model$sigma2_global_samples)
xbart_model_string <- stochtree::saveBARTModelToJsonString(xbart_model)

# Run the BART MCMC sampler, initialized from the XBART sampler
general_params <- list(sample_sigma2_global = T)
mean_forest_params <- list(num_trees = num_trees, alpha = 0.95,
beta = 2.0, max_depth = -1,
min_samples_leaf = 1,
sample_sigma2_leaf = F,
sigma2_leaf_init = 1.0/num_trees)
bart_model <- stochtree::bart(
X_train = X_train, y_train = y_train, X_test = X_test,
num_gfr = 0, num_burnin = num_burnin, num_mcmc = num_mcmc,
general_params = general_params, mean_forest_params = mean_forest_params,
previous_model_json = xbart_model_string,
previous_model_warmstart_sample_num = num_gfr
)

# Inspect the results
plot(rowMeans(bart_model$y_hat_test), y_test); abline(0,1)
cat(paste0("RMSE = ", sqrt(mean((rowMeans(bart_model$y_hat_test) - y_test)^2)), "\n"))
cat(paste0("Interval coverage = ", mean((apply(bart_model$y_hat_test, 1, quantile, probs=0.025) <= f_XW_test) & (apply(bart_model$y_hat_test, 1, quantile, probs=0.975) >= f_XW_test)), "\n"))
plot(bart_model$sigma2_global_samples)

# Compare to a single chain of MCMC samples initialized at root
bart_model_root <- stochtree::bart(
X_train = X_train, y_train = y_train, X_test = X_test,
num_gfr = 0, num_burnin = 0, num_mcmc = num_mcmc,
general_params = general_params, mean_forest_params = mean_forest_params
)
plot(rowMeans(bart_model_root$y_hat_test), y_test); abline(0,1)
cat(paste0("RMSE = ", sqrt(mean((rowMeans(bart_model_root$y_hat_test) - y_test)^2)), "\n"))
cat(paste0("Interval coverage = ", mean((apply(bart_model_root$y_hat_test, 1, quantile, probs=0.025) <= f_XW_test) & (apply(bart_model_root$y_hat_test, 1, quantile, probs=0.975) >= f_XW_test)), "\n"))
plot(bart_model_root$sigma2_global_samples)
210 changes: 210 additions & 0 deletions tools/debug/bcf_401k_data_debug.R
@@ -0,0 +1,210 @@
################################################################################
## Investigation of GFR vs MCMC fit issues on the 401k dataset
################################################################################

# Load libraries and set seed
library(stochtree)
library(DoubleML)
library(BART)
library(tidyverse)
# seed = 102
# set.seed(seed)

# Load 401k data
dat = DoubleML::fetch_401k(return_type = "data.frame")
dat_orig = dat

# Trim outliers
dat = dat %>% filter(abs(inc)<quantile(abs(inc), 0.9))

# Isolate covariates and convert to df
x = dat %>% dplyr::select(-c(e401, net_tfa))

# Convert to df and define categorical data types
xdf = data.frame(x)
xdf_st = xdf %>%
mutate(age=factor(age, ordered=TRUE),
inc = factor(inc, ordered=TRUE),
educ = factor(educ, ordered=TRUE),
fsize = factor(fsize, ordered=TRUE),
marr=factor(marr, ordered=TRUE),
twoearn=factor(twoearn, ordered=TRUE),
db=factor(db, ordered=TRUE),
pira=factor(pira, ordered=TRUE),
hown=factor(hown, ordered=TRUE))

# Isolate treatment and outcome
z = dat %>% dplyr::select(e401) %>% as.matrix()
y = dat %>% dplyr::select(net_tfa) %>% as.matrix()

# Define a "jittered" version of the original (integer-valued) x columns
# in which all categories are "upper-jittered" with uniform [0, eps] noise
# except for the largest category which is "lower-jittered" with [-eps, 0] noise
x_jitter = x
for (j in 1:ncol(x)) {
min_diff <- min(diff(sort(x[,j]))[diff(sort(x[,j])) > 0])
jitter_param <- min_diff / 3.0
has_max_category <- x[,j] == max(x[,j])
x_jitter[has_max_category,j] <- x[has_max_category,j] + runif(sum(has_max_category), -jitter_param, 0.0)
x_jitter[!has_max_category,j] <- x[!has_max_category,j] + runif(sum(!has_max_category), 0.0, jitter_param)
}
# Visualize jitters
# for (j in 1:ncol(x)) {
# plot(x[,j], x_jitter[,j], ylab = "jittered", xlab = "original")
# unique_xs <- unique(x[,j])
# for (i in unique_xs) {
# abline(h = unique_xs[i], col = "red", lty = 3)
# }
# }

# Fit a p(z = 1 | x) model for propensity features
ps_fit = pbart(x.train = xdf,
y.train = z, ntree = 200, numcut=1000, ndpost = 100,
usequants = TRUE, k = 2.0, nskip = 100, keepevery=1)
g = colMeans(pnorm(ps_fit$yhat.train))
psf = pnorm(ps_fit$yhat.train)

# Test-train split
n <- nrow(x)
test_set_pct <- 0.2
n_test <- round(test_set_pct*n)
n_train <- n - n_test
test_inds <- sort(sample(1:n, n_test, replace = FALSE))
train_inds <- (1:n)[!((1:n) %in% test_inds)]
xdf_st_test <- xdf_st[test_inds,]
xdf_st_train <- xdf_st[train_inds,]
x_test <- x[test_inds,]
x_train <- x[train_inds,]
x_jitter_test <- x_jitter[test_inds,]
x_jitter_train <- x_jitter[train_inds,]
pi_test <- g[test_inds]
pi_train <- g[train_inds]
z_test <- z[test_inds,]
z_train <- z[train_inds,]
y_test <- y[test_inds,]
y_train <- y[train_inds,]
y_train_scale <- scale(y_train)
y_train_sd <- attr(y_train_scale, "scaled:scale")
y_train_mean <- attr(y_train_scale, "scaled:center")
y_test_scale <- (y_test - y_train_mean) / y_train_sd
var(y_train_scale)
var(y_test_scale)

# Fit BCF with GFR algorithm on the jittered covariates
# and save model to JSON
num_gfr <- 1000
general_params <- list(
adaptive_coding = FALSE, propensity_covariate = "none",
keep_every = 1, verbose = TRUE, keep_gfr = TRUE
)
bcf_model_gfr <- stochtree::bcf(
X_train = xdf_st_train, Z_train = c(z_train),
y_train = c(y_train_scale), propensity_train = pi_train,
X_test = xdf_st_test, Z_test = c(z_test),
propensity_test = pi_test, num_gfr = num_gfr, num_burnin = 0,
num_mcmc = 0, general_params = general_params
)
fit_json_gfr = saveBCFModelToJsonString(bcf_model_gfr)

# Run MCMC chain from the last GFR sample, setting covariate
# equal to an interpolation between the original x and x_jitter
# (alpha = 0 is 100% x_jitter and alpha = 1 is 100% x)
# alpha <- 1.0
# x_jitter_new_train <- (alpha) * x_train + (1-alpha) * x_jitter_train
# x_jitter_new_test <- (alpha) * x_test + (1-alpha) * x_jitter_test
x_jitter_new_train <- xdf_st_train
x_jitter_new_test <- xdf_st_test
num_mcmc <- 10000
bcf_model_mcmc <- stochtree::bcf(
X_train = x_jitter_new_train, Z_train = c(z_train),
y_train = c(y_train_scale), propensity_train = pi_train,
X_test = x_jitter_new_test, Z_test = c(z_test),
propensity_test = pi_test,
num_gfr = 0, num_burnin = 0, num_mcmc = num_mcmc,
previous_model_json = fit_json_gfr,
previous_model_warmstart_sample_num = num_gfr,
general_params = general_params
)

# Inspect the "in-sample sigma" via the traceplot
# of the global error variance parameter
combined_sigma <- c(bcf_model_gfr$sigma2_global_samples,
bcf_model_mcmc$sigma2_global_samples)
plot(combined_sigma, ylab = "sigma2", xlab = "sample num",
main = "Global error var traceplot")

# Inspect the "out-of-sample sigma" by compute the MSE
# of the yhat on the test set
yhat_combined_train <- cbind(
bcf_model_gfr$y_hat_train,
bcf_model_mcmc$y_hat_train
)
yhat_combined_test <- cbind(
bcf_model_gfr$y_hat_test,
bcf_model_mcmc$y_hat_test
)
num_samples <- ncol(yhat_combined_train)
train_mses <- rep(NA, num_samples)
for (i in 1:num_samples) {
train_mses[i] <- mean((yhat_combined_train[,i] - y_train_scale)^2)
}
test_mses <- rep(NA, num_samples)
for (i in 1:num_samples) {
test_mses[i] <- mean((yhat_combined_test[,i] - y_test_scale)^2)
}
max_y <- max(c(max(train_mses, test_mses)))
min_y <- min(c(min(train_mses, test_mses)))
plot(test_mses, ylab = "outcome MSE", xlab = "sample num",
main = "Outcome MSE Traceplot", ylim = c(min_y, max_y))
points(train_mses, col = "blue")
legend("right", legend = c("Out-of-Sample", "In-Sample"),
col = c("black", "blue"), pch = c(1,1))

# Run some one-off pred vs actual plots
plot(yhat_combined_test[,11000], y_test_scale); abline(0,1,col="red",lty=3)
plot(bcf_model_mcmc$y_hat_train[,10000], y_train_scale); abline(0,1,col="red",lty=3)
plot(bcf_model_mcmc$y_hat_test[,10000], y_test_scale); abline(0,1,col="red",lty=3)
plot(bcf_model_gfr$y_hat_train[,1000], y_train_scale); abline(0,1,col="red",lty=3)
plot(bcf_model_gfr$y_hat_test[,1000], y_test_scale); abline(0,1,col="red",lty=3)
plot(bcf_model_gfr$y_hat_train[,10], y_train_scale); abline(0,1,col="red",lty=3)
plot(bcf_model_gfr$y_hat_test[,10], y_test_scale); abline(0,1,col="red",lty=3)

# Run MCMC chain from root
num_mcmc <- 10000
bcf_model_mcmc_root <- stochtree::bcf(
X_train = xdf_st_train, Z_train = c(z_train),
y_train = c(y_train_scale), propensity_train = pi_train,
X_test = xdf_st_test, Z_test = c(z_test),
propensity_test = pi_test,
num_gfr = 0, num_burnin = 0, num_mcmc = num_mcmc,
general_params = general_params
)

# Inspect the "in-sample sigma" via the traceplot
# of the global error variance parameter
sigma_trace <- bcf_model_mcmc_root$sigma2_global_samples
plot(sigma_trace, ylab = "sigma2", xlab = "sample num",
main = "Global error var traceplot")

# Inspect the "out-of-sample sigma" by compute the MSE
# of the yhat on the test set
yhat_combined_train <- cbind(
bcf_model_mcmc_root$y_hat_train
)
yhat_combined_test <- cbind(
bcf_model_mcmc_root$y_hat_test
)
num_samples <- ncol(yhat_combined_train)
train_mses <- rep(NA, num_samples)
for (i in 1:num_samples) {
train_mses[i] <- mean((yhat_combined_train[,i] - y_train_scale)^2)
}
test_mses <- rep(NA, num_samples)
for (i in 1:num_samples) {
test_mses[i] <- mean((yhat_combined_test[,i] - y_test_scale)^2)
}
max_y <- max(c(max(train_mses, test_mses)))
min_y <- min(c(min(train_mses, test_mses)))
plot(test_mses, ylab = "outcome MSE", xlab = "sample num",
main = "Test set outcome MSEs", ylim = c(min_y, max_y))
points(train_mses, col = "blue")