#' Diagnostics for Pareto smoothed importance sampling (PSIS)
#'
#' Print a diagnostic table summarizing the estimated Pareto shape parameters
#' and PSIS effective sample sizes, find the indexes of observations for which
#' the estimated Pareto shape parameter \eqn{k} is larger than some
#' `threshold` value, or plot observation indexes vs. diagnostic estimates.
#' The **Details** section below provides a brief overview of the
#' diagnostics, but we recommend consulting Vehtari, Gelman, and Gabry (2017)
#' and Vehtari, Simpson, Gelman, Yao, and Gabry (2024) for full details.
#'
#' @name pareto-k-diagnostic
#' @param x An object created by [loo()] or [psis()].
#' @param threshold For `pareto_k_ids()`, `threshold` is the minimum \eqn{k}
#' value to flag (default is the sample size `S` dependent threshold
#' `min(1 - 1 / log10(S), 0.7)`). For `mcse_loo()`, if any \eqn{k} estimates
#' are greater than `threshold` the MCSE estimate is returned as `NA`.
#' See **Details** for the motivation behind these defaults.
#'
#' @details
#'
#' The reliability and approximate convergence rate of the PSIS-based
#' estimates can be assessed using the estimates for the shape
#' parameter \eqn{k} of the generalized Pareto distribution. The
#' diagnostic threshold for Pareto \eqn{k} depends on the sample size
#' \eqn{S} (the sample size dependent threshold was introduced by
#' Vehtari et al. (2024); before that, fixed thresholds of 0.5 and 0.7
#' were recommended). For simplicity, the **loo** package uses the
#' nominal sample size \eqn{S} when computing the sample size specific
#' threshold. This gives an optimistic threshold if the effective
#' sample size is less than 2200, but if MCMC-ESS > S/2 the difference
#' is usually negligible. Thinning of MCMC draws can be used to
#' improve the ratio ESS/S.
#'
#' * If \eqn{k < min(1 - 1 / log10(S), 0.7)}, where \eqn{S} is the
#' sample size, the PSIS estimate and the corresponding Monte Carlo
#' standard error estimate are reliable.
#'
#' * If \eqn{1 - 1 / log10(S) <= k < 0.7}, the PSIS estimate and the
#' corresponding Monte Carlo standard error estimate are not
#' reliable, but increasing the (effective) sample size \eqn{S} above
#' 2200 may help (this raises the sample size specific threshold above
#' 0.7, since \eqn{1 - 1 / log10(2200) > 0.7}, so that the bias specific
#' threshold 0.7 dominates).
#'
#' * If \eqn{0.7 <= k < 1}, the PSIS estimate and the corresponding Monte
#' Carlo standard error have large bias and are not reliable. Increasing
#' the sample size may reduce the variability in the \eqn{k} estimate, which
#' may also result in a lower \eqn{k} estimate.
#'
#' * If \eqn{k \geq 1}{k >= 1}, the target distribution is estimated to
#' have a non-finite mean. The PSIS estimate and the corresponding Monte
#' Carlo standard error are not well defined. Increasing the sample size
#' may reduce the variability in the \eqn{k} estimate, which
#' may also result in a lower \eqn{k} estimate.
#'
#' \subsection{What if the estimated tail shape parameter \eqn{k}
#' exceeds the diagnostic threshold?}{ Importance sampling is likely to
#' work less well if the marginal posterior \eqn{p(\theta^s | y)} and
#' LOO posterior \eqn{p(\theta^s | y_{-i})} are very different, which
#' is more likely to happen with a non-robust model and highly
#' influential observations. If the estimated tail shape parameter
#' \eqn{k} exceeds the diagnostic threshold, the user should be
#' warned. (Note: If \eqn{k} is greater than the diagnostic threshold
#' then WAIC is also likely to fail, but WAIC lacks an equally accurate
#' diagnostic.) When using PSIS in the context of approximate LOO-CV,
#' we recommend one of the following actions:
#'
#' * With some additional computations, it is possible to transform
#' the MCMC draws from the posterior distribution to obtain more
#' reliable importance sampling estimates. This results in a smaller
#' shape parameter \eqn{k}. See [loo_moment_match()] and the
#' vignette *Avoiding model refits in leave-one-out cross-validation
#' with moment matching* for an example of this.
#'
#' * Sampling from a leave-one-out mixture distribution (see the
#' vignette *Mixture IS leave-one-out cross-validation for
#' high-dimensional Bayesian models*), directly from \eqn{p(\theta^s
#' | y_{-i})} for the problematic observations \eqn{i}, or using
#' \eqn{K}-fold cross-validation (see the vignette *Holdout
#' validation and K-fold cross-validation of Stan programs with the
#' loo package*) will generally be more stable.
#'
#' * Using a model that is more robust to anomalous observations will
#' generally make approximate LOO-CV more stable.
#'
#' }
#'
#' \subsection{Observation influence statistics}{ The estimated shape parameter
#' \eqn{k} for each observation can be used as a measure of the observation's
#' influence on the posterior distribution of the model. These can be obtained with
#' `pareto_k_influence_values()`.
#' }
#'
#' \subsection{Effective sample size and error estimates}{ When the draws
#' from the proposal distribution are obtained via MCMC, the **loo**
#' package also computes estimates for the Monte Carlo error and the effective
#' sample size for importance sampling, which are more accurate for PSIS than
#' for IS and TIS (see Vehtari et al. (2024) for details). However, the PSIS
#' effective sample size estimate will be
#' **over-optimistic when the estimate of \eqn{k} is greater than**
#' \eqn{min(1-1/log10(S), 0.7)}, where \eqn{S} is the sample size.
#' }
#'
#' @seealso
#' * [psis()] for the implementation of the PSIS algorithm.
#' * The [FAQ page](https://mc-stan.org/loo/articles/online-only/faq.html) on
#' the __loo__ website for answers to frequently asked questions.
#'
#' @template loo-and-psis-references
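#' @examples
#' # A minimal sketch, assuming the example log-likelihood array shipped
#' # with loo (assumed to be laid out as draws x chains x observations).
#' LLarr <- example_loglik_array()
#' loo1 <- loo(LLarr, r_eff = relative_eff(exp(LLarr)))
#'
#' # Counts, proportions, and minimum ESS within each Pareto k group
#' pareto_k_table(loo1)
#'
#' # Raw k estimates, then the indexes of observations above the default
#' # sample size dependent threshold or above a user-chosen one
#' head(pareto_k_values(loo1))
#' pareto_k_ids(loo1)
#' pareto_k_ids(loo1, threshold = 0.5)
#'
#' # Pareto k and PSIS ESS diagnostic plots
#' plot(loo1, label_points = TRUE)
#' plot(loo1, diagnostic = "ESS")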
#'
NULL
#' @rdname pareto-k-diagnostic
#' @export
#' @return `pareto_k_table()` returns an object of class
#' `"pareto_k_table"`, which is a matrix with columns `"Count"`,
#' `"Proportion"`, and `"Min. n_eff"`, and has its own print method.
#'
pareto_k_table <- function(x) {
  k <- pareto_k_values(x)
  n_eff <- try(psis_n_eff_values(x), silent = TRUE)
  if (inherits(n_eff, "try-error")) {
    n_eff <- rep(NA, length(k))
  }
  S <- dim(x)[1]
  k_threshold <- ps_khat_threshold(S)
  kcut <- k_cut(k, k_threshold)
  # PSIS ESS is over-optimistic above the threshold, so it is not reported
  n_eff[k > k_threshold] <- NA
  min_n_eff <- min_n_eff_by_k(n_eff, kcut)
  count <- table(kcut)
  out <- cbind(
    Count = count,
    Proportion = prop.table(count),
    "Min. n_eff" = min_n_eff
  )
  attr(out, "k_threshold") <- k_threshold
  structure(out, class = c("pareto_k_table", class(out)))
}
#' @export
print.pareto_k_table <- function(x, digits = 1, ...) {
  count <- x[, "Count"]
  k_threshold <- attr(x, "k_threshold")
  if (sum(count[2:3]) == 0) {
    cat(paste0("\nAll Pareto k estimates are good (k < ",
               round(k_threshold, 2), ").\n"))
  } else {
    tab <- cbind(
      " " = rep("", 3),
      " " = c("(good)", "(bad)", "(very bad)"),
      "Count" = .fr(count, 0),
      "Pct. " = paste0(.fr(100 * x[, "Proportion"], digits), "%"),
      # Report ESS, as the term n_eff has been deprecated
      "Min. ESS" = round(x[, "Min. n_eff"])
    )
    tab2 <- rbind(tab)
    cat("Pareto k diagnostic values:\n")
    rownames(tab2) <- format(rownames(tab2), justify = "right")
    print(tab2, quote = FALSE)
  }
  # Print methods return their input invisibly in every branch
  invisible(x)
}
#' @rdname pareto-k-diagnostic
#' @export
#' @return `pareto_k_ids()` returns an integer vector indicating which
#' observations have Pareto \eqn{k} estimates above `threshold`.
#'
pareto_k_ids <- function(x, threshold = NULL) {
  if (is.null(threshold)) {
    S <- dim(x)[1]
    threshold <- ps_khat_threshold(S)
  }
  k <- pareto_k_values(x)
  which(k > threshold)
}
#' @rdname pareto-k-diagnostic
#' @export
#' @return `pareto_k_values()` returns a vector of the estimated Pareto
#' \eqn{k} parameters. These indicate the reliability of the PSIS estimates.
pareto_k_values <- function(x) {
  k <- x$diagnostics[["pareto_k"]]
  if (is.null(k)) {
    # for compatibility with objects from loo < 2.0.0
    k <- x[["pareto_k"]]
  }
  if (is.null(k)) {
    stop("No Pareto k estimates found.", call. = FALSE)
  }
  return(k)
}
#' @rdname pareto-k-diagnostic
#' @export
#' @return `pareto_k_influence_values()` returns a vector of the estimated Pareto
#' \eqn{k} parameters. These represent the influence of each observation on
#' the model's posterior distribution.
pareto_k_influence_values <- function(x) {
  if ("influence_pareto_k" %in% colnames(x$pointwise)) {
    k <- x$pointwise[, "influence_pareto_k"]
  } else {
    stop("No Pareto k influence estimates found.", call. = FALSE)
  }
  return(k)
}
#' @rdname pareto-k-diagnostic
#' @export
#' @return `psis_n_eff_values()` returns a vector of the estimated PSIS
#' effective sample sizes.
psis_n_eff_values <- function(x) {
  n_eff <- x$diagnostics[["n_eff"]]
  if (is.null(n_eff)) {
    # Report ESS, as the term n_eff has been deprecated
    stop("No PSIS ESS estimates found.", call. = FALSE)
  }
  return(n_eff)
}
#' @rdname pareto-k-diagnostic
#' @export
#' @return `mcse_loo()` returns the Monte Carlo standard error (MCSE)
#' estimate for PSIS-LOO. MCSE will be NA if any Pareto \eqn{k} values are
#' above `threshold`.
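#' @examples
#' # A minimal sketch, assuming the example log-likelihood array shipped
#' # with loo. With threshold = Inf the k check never triggers, so the
#' # result equals the square root of the summed squared pointwise MCSEs.
#' LLarr <- example_loglik_array()
#' loo1 <- loo(LLarr, r_eff = relative_eff(exp(LLarr)))
#' all.equal(mcse_loo(loo1, threshold = Inf),
#'           sqrt(sum(loo1$pointwise[, "mcse_elpd_loo"]^2)))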
#'
mcse_loo <- function(x, threshold = NULL) {
  stopifnot(is.psis_loo(x))
  S <- dim(x)[1]
  if (is.null(threshold)) {
    k_threshold <- ps_khat_threshold(S)
  } else {
    k_threshold <- threshold
  }
  if (any(pareto_k_values(x) > k_threshold, na.rm = TRUE)) {
    return(NA)
  }
  mc_var <- x$pointwise[, "mcse_elpd_loo"]^2
  sqrt(sum(mc_var))
}
#' @rdname pareto-k-diagnostic
#' @aliases plot.loo
#' @export
#' @param label_points,... For the `plot()` method, if `label_points` is
#' `TRUE` the observation numbers corresponding to any values of \eqn{k}
#' greater than the diagnostic threshold will be displayed in the plot.
#' Any arguments specified in `...` will be passed to [graphics::text()]
#' and can be used to control the appearance of the labels.
#' @param diagnostic For the `plot` method, which diagnostic should be
#' plotted? The options are `"k"` for Pareto \eqn{k} estimates (the
#' default), or `"ESS"` or `"n_eff"` for PSIS effective sample size estimates.
#' @param main For the `plot()` method, a title for the plot.
#'
#' @return The `plot()` method is called for its side effect and does not
#' return anything. If `x` is the result of a call to [loo()]
#' or [psis()] then `plot(x, diagnostic)` produces a plot of
#' the estimates of the Pareto shape parameters (`diagnostic = "k"`) or
#' estimates of the PSIS effective sample sizes (`diagnostic = "ESS"`).
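#' @examples
#' # A minimal sketch with a psis() object built from the example data
#' # shipped with loo. Arguments in `...` (here `col`) are passed to
#' # graphics::text(), so they only affect the labels drawn for points
#' # whose k exceeds the threshold.
#' log_ratios <- -1 * example_loglik_array()
#' psis1 <- psis(log_ratios, r_eff = relative_eff(exp(-log_ratios)))
#' plot(psis1, label_points = TRUE, col = "red")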
#'
plot.psis_loo <- function(x,
                          diagnostic = c("k", "ESS", "n_eff"),
                          ...,
                          label_points = FALSE,
                          main = "PSIS diagnostic plot") {
  diagnostic <- match.arg(diagnostic)
  k <- pareto_k_values(x)
  k[is.na(k)] <- 0 # FIXME when reloo is changed to make NA k values -Inf
  k_inf <- !is.finite(k)
  if (any(k_inf)) {
    warning(signif(100 * mean(k_inf), 2),
            "% of Pareto k estimates are Inf/NA/NaN and not plotted.")
  }
  if (diagnostic == "ESS" || diagnostic == "n_eff") {
    n_eff <- psis_n_eff_values(x)
  } else {
    n_eff <- NULL
  }
  S <- dim(x)[1]
  k_threshold <- ps_khat_threshold(S)
  plot_diagnostic(
    k = k,
    n_eff = n_eff,
    threshold = k_threshold,
    ...,
    label_points = label_points,
    main = main
  )
}
#' @export
#' @noRd
#' @rdname pareto-k-diagnostic
plot.loo <- plot.psis_loo
#' @export
#' @rdname pareto-k-diagnostic
plot.psis <- function(x, diagnostic = c("k", "ESS", "n_eff"), ...,
                      label_points = FALSE,
                      main = "PSIS diagnostic plot") {
  plot.psis_loo(x, diagnostic = diagnostic, ...,
                label_points = label_points, main = main)
}
# internal ----------------------------------------------------------------
plot_diagnostic <-
  function(k,
           n_eff = NULL,
           threshold = 0.7,
           ...,
           label_points = FALSE,
           main = "PSIS diagnostic plot") {
    use_n_eff <- !is.null(n_eff)
    graphics::plot(
      x = if (use_n_eff) n_eff else k,
      xlab = "Data point",
      # Report ESS, as the term n_eff has been deprecated
      ylab = if (use_n_eff) "PSIS ESS" else "Pareto shape k",
      type = "n",
      bty = "l",
      yaxt = "n",
      main = main
    )
    graphics::axis(side = 2, las = 1)

    in_range <- function(x, lb_ub) {
      x >= lb_ub[1L] & x <= lb_ub[2L]
    }

    if (!use_n_eff) {
      # Horizontal reference lines at 0, the k threshold, and 1,
      # drawn only if they fall within the range of the k estimates
      krange <- range(k, na.rm = TRUE)
      breaks <- c(0, threshold, 1)
      hex_clrs <- c("#C79999", "#7C0000")
      ltys <- c(3, 2, 1)
      for (j in seq_along(breaks)) {
        val <- breaks[j]
        if (in_range(val, krange))
          graphics::abline(
            h = val,
            col = ifelse(val == 0, "darkgray", hex_clrs[j - 1]),
            lty = ltys[j],
            lwd = 1
          )
      }
    }

    # Colour by k group: (-Inf, threshold], (threshold, 1], (1, Inf)
    breaks <- c(-Inf, threshold, 1)
    hex_clrs <- c("#6497b1", "#005b96", "#03396c")
    clrs <- ifelse(
      in_range(k, breaks[1:2]),
      hex_clrs[1],
      ifelse(in_range(k, breaks[2:3]), hex_clrs[2], hex_clrs[3])
    )

    if (all(k < threshold) || !label_points) {
      graphics::points(x = if (use_n_eff) n_eff else k,
                       col = clrs, pch = 3, cex = .6)
      return(invisible())
    } else {
      # Observations above the threshold are drawn as their index labels;
      # the rest are drawn as points at their original x positions
      sel <- !in_range(k, breaks[1:2])
      keep <- which(!sel)
      graphics::points(
        x = keep,
        y = if (use_n_eff) n_eff[keep] else k[keep],
        col = clrs[keep], pch = 3, cex = .6
      )
      dots <- list(...)
      txt_args <- c(
        list(
          x = seq_along(k)[sel],
          y = if (use_n_eff) n_eff[sel] else k[sel],
          labels = seq_along(k)[sel]
        ),
        if (length(dots)) dots
      )
      if (!("adj" %in% names(txt_args))) txt_args$adj <- 2 / 3
      if (!("cex" %in% names(txt_args))) txt_args$cex <- 0.75
      if (!("col" %in% names(txt_args))) txt_args$col <- clrs[sel]
      do.call(graphics::text, txt_args)
    }
  }
#' Convert numeric Pareto k values to a factor variable.
#'
#' @noRd
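#' @param k Vector of Pareto k estimates.
#' @param threshold Sample size dependent k threshold separating the first
#'   two groups (the upper break of the second group is fixed at 1).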
#' @return A factor variable (the same length as k) with 3 levels.
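#' @examples
#' # Illustrative k values, one in each of the three groups; the result
#' # has one element per level: (-Inf, 0.7], (0.7, 1], (1, Inf)
#' k_cut(c(0.2, 0.75, 1.3), threshold = 0.7)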
#'
k_cut <- function(k, threshold) {
  cut(
    k,
    breaks = c(-Inf, threshold, 1, Inf),
    labels = c(paste0("(-Inf, ", round(threshold, 2), "]"),
               paste0("(", round(threshold, 2), ", 1]"),
               "(1, Inf)")
  )
}
#' Calculate the minimum PSIS n_eff within groups defined by Pareto k values
#'
#' @noRd
#' @param n_eff Vector of PSIS n_eff estimates.
#' @param kcut Factor returned by the k_cut() function.
#' @return Vector of length `nlevels(kcut)` containing the minimum n_eff within
#' each k group. If there are no k values in a group the corresponding element
#' of the returned vector is NA.
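#' @examples
#' # Illustrative values: no k falls in (0.7, 1], so that group's minimum
#' # is NA; the result is 250, NA, 30 named by the three k_cut() levels
#' kcut <- k_cut(c(0.2, 0.3, 1.2), threshold = 0.7)
#' min_n_eff_by_k(c(400, 250, 30), kcut)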
min_n_eff_by_k <- function(n_eff, kcut) {
  n_eff_split <- split(n_eff, f = kcut)
  n_eff_split <- sapply(n_eff_split, function(x) {
    # Some k groups might be empty: split() returns numeric(0) for them,
    # which is replaced with NA
    if (!length(x)) NA else x
  })
  sapply(n_eff_split, min)
}
#' Pareto-smoothing k-hat threshold
#'
#' Given sample size S, computes the khat threshold for a reliable Pareto
#' smoothed estimate (one with a small probability of large error). See
#' section 3.2.4, equation (13). Sample sizes 100, 320, 1000, 2200,
#' 10000 correspond to thresholds 0.5, 0.6, 0.67, 0.7, 0.75. Although
#' with a bigger sample size S we can achieve estimates with a small
#' probability of large error, it is difficult to get accurate MCSE
#' estimates as the bias starts to dominate when k > 0.7 (see Section 3.2.3).
#' Thus the sample size dependent k-hat threshold is capped at 0.7.
#' @param S sample size
#' @param ... unused
#' @return threshold
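#' @examples
#' # A quick check of the sample sizes quoted above; the last two values
#' # are capped at 0.7 by min() even though 1 - 1/log10(S) exceeds 0.7
#' round(sapply(c(100, 320, 1000, 2200, 10000), ps_khat_threshold), 2)
#' # 0.50 0.60 0.67 0.70 0.70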
#' @noRd
ps_khat_threshold <- function(S, ...) {
  min(1 - 1 / log10(S), 0.7)
}