Skip to content

Fixing bug in heteroskedastic BCF and standardizing the use of variance instead of standard deviation throughout the interface #166

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 7 commits into the base branch from the feature branch
on May 7, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
170 changes: 85 additions & 85 deletions R/bart.R

Large diffs are not rendered by default.

310 changes: 155 additions & 155 deletions R/bcf.R

Large diffs are not rendered by default.

10 changes: 5 additions & 5 deletions R/kernel.R
Original file line number Diff line number Diff line change
Expand Up @@ -157,7 +157,7 @@ computeForestLeafVariances <- function(model_object, forest_type, forest_inds=NU
if (!model_object$model_params$include_mean_forest) {
stop("Mean forest was not sampled in the bart model provided")
}
if (!model_object$model_params$sample_sigma_leaf) {
if (!model_object$model_params$sample_sigma2_leaf) {
stop("Leaf scale parameter was not sampled for the mean forest in the bart model provided")
}
leaf_scale_vector <- model_object$sigma2_leaf_samples
Expand All @@ -170,15 +170,15 @@ computeForestLeafVariances <- function(model_object, forest_type, forest_inds=NU
} else {
stopifnot(forest_type %in% c("prognostic", "treatment", "variance"))
if (forest_type=="prognostic") {
if (!model_object$model_params$sample_sigma_leaf_mu) {
if (!model_object$model_params$sample_sigma2_leaf_mu) {
stop("Leaf scale parameter was not sampled for the prognostic forest in the bcf model provided")
}
leaf_scale_vector <- model_object$sigma_leaf_mu_samples
leaf_scale_vector <- model_object$sigma2_leaf_mu_samples
} else if (forest_type=="treatment") {
if (!model_object$model_params$sample_sigma_leaf_tau) {
if (!model_object$model_params$sample_sigma2_leaf_tau) {
stop("Leaf scale parameter was not sampled for the treatment effect forest in the bcf model provided")
}
leaf_scale_vector <- model_object$sigma_leaf_tau_samples
leaf_scale_vector <- model_object$sigma2_leaf_tau_samples
} else if (forest_type=="variance") {
if (!model_object$model_params$include_variance_forest) {
stop("Variance forest was not sampled in the bcf model provided")
Expand Down
2 changes: 1 addition & 1 deletion R/serialization.R
Original file line number Diff line number Diff line change
Expand Up @@ -545,7 +545,7 @@ loadRandomEffectSamplesCombinedJsonString <- function(json_string_list, json_rfx
#' Load a vector from json
#'
#' @param json_object Object of class `CppJson`
#' @param json_vector_label Label referring to a particular vector (i.e. "sigma2_samples") in the overall json hierarchy
#' @param json_vector_label Label referring to a particular vector (i.e. "sigma2_global_samples") in the overall json hierarchy
#' @param subfolder_name (Optional) Name of the subfolder / hierarchy under which vector sits
#'
#' @return R vector
Expand Down
4 changes: 4 additions & 0 deletions R/stochtree-package.R
Original file line number Diff line number Diff line change
@@ -1,11 +1,15 @@
## usethis namespace: start
#' @importFrom stats coef
#' @importFrom stats dnorm
#' @importFrom stats lm
#' @importFrom stats model.matrix
#' @importFrom stats predict
#' @importFrom stats qgamma
#' @importFrom stats qnorm
#' @importFrom stats pnorm
#' @importFrom stats resid
#' @importFrom stats rnorm
#' @importFrom stats runif
#' @importFrom stats sd
#' @importFrom stats sigma
#' @importFrom stats var
Expand Down
4 changes: 2 additions & 2 deletions demo/debug/causal_inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,8 +87,8 @@
plt.axline((0, 0), slope=1, color="black", linestyle=(0, (3,3)))
plt.show()

sigma_df_mcmc = pd.DataFrame(np.concatenate((np.expand_dims(np.arange(bcf_model.num_samples),axis=1), np.expand_dims(bcf_model.global_var_samples,axis=1)), axis = 1), columns=["Sample", "Sigma"])
sns.scatterplot(data=sigma_df_mcmc, x="Sample", y="Sigma")
sigma_df_mcmc = pd.DataFrame(np.concatenate((np.expand_dims(np.arange(bcf_model.num_samples),axis=1), np.expand_dims(bcf_model.global_var_samples,axis=1)), axis = 1), columns=["Sample", "Sigma^2"])
sns.scatterplot(data=sigma_df_mcmc, x="Sample", y="Sigma^2")
plt.show()

b_df_mcmc = pd.DataFrame(np.concatenate((np.expand_dims(np.arange(bcf_model.num_samples),axis=1), np.expand_dims(bcf_model.b0_samples,axis=1), np.expand_dims(bcf_model.b1_samples,axis=1)), axis = 1), columns=["Sample", "Beta_0", "Beta_1"])
Expand Down
12 changes: 6 additions & 6 deletions demo/debug/supervised_learning.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,8 +66,8 @@ def outcome_mean(X, W):
plt.axline((0, 0), slope=1, color="black", linestyle=(0, (3,3)))
plt.show()

sigma_df_mcmc = pd.DataFrame(np.concatenate((np.expand_dims(np.arange(bart_model.num_samples),axis=1), np.expand_dims(bart_model.global_var_samples,axis=1)), axis = 1), columns=["Sample", "Sigma"])
sns.scatterplot(data=sigma_df_mcmc, x="Sample", y="Sigma")
sigma_df_mcmc = pd.DataFrame(np.concatenate((np.expand_dims(np.arange(bart_model.num_samples),axis=1), np.expand_dims(bart_model.global_var_samples,axis=1)), axis = 1), columns=["Sample", "Sigma^2"])
sns.scatterplot(data=sigma_df_mcmc, x="Sample", y="Sigma^2")
plt.show()

# Compute the test set RMSE
Expand All @@ -89,8 +89,8 @@ def outcome_mean(X, W):
plt.axline((0, 0), slope=1, color="black", linestyle=(0, (3,3)))
plt.show()

sigma_df_mcmc = pd.DataFrame(np.concatenate((np.expand_dims(np.arange(bart_model.num_samples),axis=1), np.expand_dims(bart_model.global_var_samples,axis=1)), axis = 1), columns=["Sample", "Sigma"])
sns.scatterplot(data=sigma_df_mcmc, x="Sample", y="Sigma")
sigma_df_mcmc = pd.DataFrame(np.concatenate((np.expand_dims(np.arange(bart_model.num_samples),axis=1), np.expand_dims(bart_model.global_var_samples,axis=1)), axis = 1), columns=["Sample", "Sigma^2"])
sns.scatterplot(data=sigma_df_mcmc, x="Sample", y="Sigma^2")
plt.show()

# Compute the test set RMSE
Expand All @@ -110,8 +110,8 @@ def outcome_mean(X, W):
plt.axline((0, 0), slope=1, color="black", linestyle=(0, (3,3)))
plt.show()

sigma_df_mcmc = pd.DataFrame(np.concatenate((np.expand_dims(np.arange(bart_model.num_samples),axis=1), np.expand_dims(bart_model.global_var_samples,axis=1)), axis = 1), columns=["Sample", "Sigma"])
sns.scatterplot(data=sigma_df_mcmc, x="Sample", y="Sigma")
sigma_df_mcmc = pd.DataFrame(np.concatenate((np.expand_dims(np.arange(bart_model.num_samples),axis=1), np.expand_dims(bart_model.global_var_samples,axis=1)), axis = 1), columns=["Sample", "Sigma^2"])
sns.scatterplot(data=sigma_df_mcmc, x="Sample", y="Sigma^2")
plt.show()

# Compute the test set RMSE
Expand Down
36 changes: 18 additions & 18 deletions demo/notebooks/prototype_interface.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -346,9 +346,9 @@
"forest_preds_mcmc = forest_preds[:, num_warmstart:num_samples]\n",
"\n",
"# Global error variance\n",
"sigma_samples = np.sqrt(global_var_samples) * y_std\n",
"sigma_samples_gfr = sigma_samples[:num_warmstart]\n",
"sigma_samples_mcmc = sigma_samples[num_warmstart:num_samples]"
"sigma2_samples = global_var_samples * y_std * y_std\n",
"sigma2_samples_gfr = sigma2_samples[:num_warmstart]\n",
"sigma2_samples_mcmc = sigma2_samples[num_warmstart:num_samples]"
]
},
{
Expand Down Expand Up @@ -384,13 +384,13 @@
" np.concatenate(\n",
" (\n",
" np.expand_dims(np.arange(num_warmstart), axis=1),\n",
" np.expand_dims(sigma_samples_gfr, axis=1),\n",
" np.expand_dims(sigma2_samples_gfr, axis=1),\n",
" ),\n",
" axis=1,\n",
" ),\n",
" columns=[\"Sample\", \"Sigma\"],\n",
" columns=[\"Sample\", \"Sigma^2\"],\n",
")\n",
"sns.scatterplot(data=sigma_df_gfr, x=\"Sample\", y=\"Sigma\")\n",
"sns.scatterplot(data=sigma_df_gfr, x=\"Sample\", y=\"Sigma^2\")\n",
"plt.show()"
]
},
Expand Down Expand Up @@ -427,13 +427,13 @@
" np.concatenate(\n",
" (\n",
" np.expand_dims(np.arange(num_samples - num_warmstart), axis=1),\n",
" np.expand_dims(sigma_samples_mcmc, axis=1),\n",
" np.expand_dims(sigma2_samples_mcmc, axis=1),\n",
" ),\n",
" axis=1,\n",
" ),\n",
" columns=[\"Sample\", \"Sigma\"],\n",
" columns=[\"Sample\", \"Sigma^2\"],\n",
")\n",
"sns.scatterplot(data=sigma_df_mcmc, x=\"Sample\", y=\"Sigma\")\n",
"sns.scatterplot(data=sigma_df_mcmc, x=\"Sample\", y=\"Sigma^2\")\n",
"plt.show()"
]
},
Expand Down Expand Up @@ -909,9 +909,9 @@
"forest_preds_tau_mcmc = forest_preds_tau[:, num_warmstart:num_samples]\n",
"\n",
"# Global error variance\n",
"sigma_samples = np.sqrt(global_var_samples) * y_std\n",
"sigma_samples_gfr = sigma_samples[:num_warmstart]\n",
"sigma_samples_mcmc = sigma_samples[num_warmstart:num_samples]\n",
"sigma2_samples = global_var_samples * y_std * y_std\n",
"sigma2_samples_gfr = sigma2_samples[:num_warmstart]\n",
"sigma2_samples_mcmc = sigma2_samples[num_warmstart:num_samples]\n",
"\n",
"# Adaptive coding parameters\n",
"b_1_samples_gfr = b_1_samples[:num_warmstart] * y_std\n",
Expand Down Expand Up @@ -969,13 +969,13 @@
" np.concatenate(\n",
" (\n",
" np.expand_dims(np.arange(num_warmstart), axis=1),\n",
" np.expand_dims(sigma_samples_gfr, axis=1),\n",
" np.expand_dims(sigma2_samples_gfr, axis=1),\n",
" ),\n",
" axis=1,\n",
" ),\n",
" columns=[\"Sample\", \"Sigma\"],\n",
" columns=[\"Sample\", \"Sigma^2\"],\n",
")\n",
"sns.scatterplot(data=sigma_df_gfr, x=\"Sample\", y=\"Sigma\")\n",
"sns.scatterplot(data=sigma_df_gfr, x=\"Sample\", y=\"Sigma^2\")\n",
"plt.show()"
]
},
Expand Down Expand Up @@ -1050,13 +1050,13 @@
" np.concatenate(\n",
" (\n",
" np.expand_dims(np.arange(num_samples - num_warmstart), axis=1),\n",
" np.expand_dims(sigma_samples_mcmc, axis=1),\n",
" np.expand_dims(sigma2_samples_mcmc, axis=1),\n",
" ),\n",
" axis=1,\n",
" ),\n",
" columns=[\"Sample\", \"Sigma\"],\n",
" columns=[\"Sample\", \"Sigma^2\"],\n",
")\n",
"sns.scatterplot(data=sigma_df_mcmc, x=\"Sample\", y=\"Sigma\")\n",
"sns.scatterplot(data=sigma_df_mcmc, x=\"Sample\", y=\"Sigma^2\")\n",
"plt.show()"
]
},
Expand Down
4 changes: 2 additions & 2 deletions demo/notebooks/serialization.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -173,9 +173,9 @@
" ),\n",
" axis=1,\n",
" ),\n",
" columns=[\"Sample\", \"Sigma\"],\n",
" columns=[\"Sample\", \"Sigma^2\"],\n",
")\n",
"sns.scatterplot(data=sigma_df_mcmc, x=\"Sample\", y=\"Sigma\")\n",
"sns.scatterplot(data=sigma_df_mcmc, x=\"Sample\", y=\"Sigma^2\")\n",
"plt.show()"
]
},
Expand Down
16 changes: 8 additions & 8 deletions demo/notebooks/supervised_learning.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -172,9 +172,9 @@
" ),\n",
" axis=1,\n",
" ),\n",
" columns=[\"Sample\", \"Sigma\"],\n",
" columns=[\"Sample\", \"Sigma^2\"],\n",
")\n",
"sns.scatterplot(data=sigma_df_mcmc, x=\"Sample\", y=\"Sigma\")\n",
"sns.scatterplot(data=sigma_df_mcmc, x=\"Sample\", y=\"Sigma^2\")\n",
"plt.show()"
]
},
Expand Down Expand Up @@ -260,9 +260,9 @@
" ),\n",
" axis=1,\n",
" ),\n",
" columns=[\"Sample\", \"Sigma\"],\n",
" columns=[\"Sample\", \"Sigma^2\"],\n",
")\n",
"sns.scatterplot(data=sigma_df_mcmc, x=\"Sample\", y=\"Sigma\")\n",
"sns.scatterplot(data=sigma_df_mcmc, x=\"Sample\", y=\"Sigma^2\")\n",
"plt.show()"
]
},
Expand Down Expand Up @@ -346,9 +346,9 @@
" ),\n",
" axis=1,\n",
" ),\n",
" columns=[\"Sample\", \"Sigma\"],\n",
" columns=[\"Sample\", \"Sigma^2\"],\n",
")\n",
"sns.scatterplot(data=sigma_df_mcmc, x=\"Sample\", y=\"Sigma\")\n",
"sns.scatterplot(data=sigma_df_mcmc, x=\"Sample\", y=\"Sigma^2\")\n",
"plt.show()"
]
},
Expand All @@ -371,7 +371,7 @@
],
"metadata": {
"kernelspec": {
"display_name": "venv",
"display_name": "stochtree-dev",
"language": "python",
"name": "python3"
},
Expand All @@ -385,7 +385,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.12.9"
"version": "3.10.16"
}
},
"nbformat": 4,
Expand Down
Loading
Loading