Skip to content

Commit

Permalink
add warmup information to sampler class and pass to print method
Browse files Browse the repository at this point in the history
  • Loading branch information
njtierney committed Dec 17, 2024
1 parent d475f88 commit b4df84a
Show file tree
Hide file tree
Showing 4 changed files with 15 additions and 2 deletions.
2 changes: 2 additions & 0 deletions R/greta_mcmc_list.R
Original file line number Diff line number Diff line change
Expand Up @@ -62,13 +62,15 @@ window.greta_mcmc_list <- function(x, start, end, thin, ...) {
#' @export
print.greta_mcmc_list <- function(x, ..., n = 5){

n_warmup <- n_warmup(x)
n_chain <- coda::nchain(x)
n_iter <- coda::niter(x)
n_thin <- coda::thin(x)
cli::cli_h1("MCMC draws from {.pkg greta}")
cli::cli_bullets(
c(
"*" = "Iterations = {n_iter}",
"*" = "Warmup = {n_warmup}",
"*" = "Chains = {n_chain}",
"*" = "Thinning = {n_thin}"
)
Expand Down
8 changes: 6 additions & 2 deletions R/inference.R
Original file line number Diff line number Diff line change
Expand Up @@ -394,7 +394,10 @@ run_samplers <- function(samplers,
greta_stash$trace_log_files <- trace_log_files
greta_stash$percentage_log_files <- percentage_log_files
greta_stash$progress_bar_log_files <- progress_bar_log_files
greta_stash$mcmc_info <- list(n_samples = n_samples)
greta_stash$mcmc_info <- list(
n_samples = n_samples,
warmup = warmup
)
}

if (plan_is$parallel) {
Expand Down Expand Up @@ -514,7 +517,8 @@ stashed_samples <- function() {
model_info <- list(
raw_draws = free_state_draws,
samplers = samplers,
model = samplers[[1]]$model
model = samplers[[1]]$model,
warmup = samplers[[1]]$warmup
)

values_draws <- as_greta_mcmc_list(values_draws, model_info)
Expand Down
2 changes: 2 additions & 0 deletions R/inference_class.R
Original file line number Diff line number Diff line change
Expand Up @@ -244,6 +244,7 @@ sampler <- R6Class(
n_chains = 1,
numerical_rejections = 0,
thin = 1,
warmup = 1,

# tuning information
mean_accept_stat = 0.5,
Expand Down Expand Up @@ -336,6 +337,7 @@ sampler <- R6Class(
one_by_one, plan_is, n_cores, float_type,
trace_batch_size,
from_scratch = TRUE) {
self$warmup <- warmup
self$thin <- thin
dag <- self$model$dag

Expand Down
5 changes: 5 additions & 0 deletions R/utils.R
Original file line number Diff line number Diff line change
Expand Up @@ -1174,3 +1174,8 @@ are_initials <- function(x){
FUN.VALUE = logical(1)
)
}

n_warmup <- function(x){
x_info <- attr(x, "model_info")
x_info$warmup
}

0 comments on commit b4df84a

Please sign in to comment.