-
-
Notifications
You must be signed in to change notification settings - Fork 34
/
Copy pathpsislw.R
197 lines (180 loc) · 6 KB
/
psislw.R
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
#' Pareto smoothed importance sampling (deprecated, old version)
#'
#' As of version `2.0.0` this function is **deprecated**. Please use the
#' [psis()] function for the new PSIS algorithm.
#'
#' @export
#' @param lw A matrix or vector of log weights. For computing LOO, `lw =
#' -log_lik`, the *negative* of an \eqn{S} (simulations) by \eqn{N} (data
#' points) pointwise log-likelihood matrix.
#' @param wcp The proportion of importance weights to use for the generalized
#' Pareto fit. The `100*wcp`\% largest weights are used as the sample
#' from which to estimate the parameters of the generalized Pareto
#' distribution.
#' @param wtrunc For truncating very large weights to \eqn{S}^`wtrunc`. Set
#' to zero for no truncation.
#' @param cores The number of cores to use for parallelization. This defaults to
#' the option `mc.cores` which can be set for an entire R session by
#' `options(mc.cores = NUMBER)`, the old option `loo.cores` is now
#' deprecated but will be given precedence over `mc.cores` until it is
#' removed. **As of version 2.0.0, the default is now 1 core if
#' `mc.cores` is not set, but we recommend using as many (or close to as
#' many) cores as possible.**
#' @param llfun,llargs See [loo.function()].
#' @param ... Ignored when `psislw()` is called directly. The `...` is
#' only used internally when `psislw()` is called by the [loo()]
#' function.
#'
#' @return A named list with components `lw_smooth` (modified log weights) and
#' `pareto_k` (estimated generalized Pareto shape parameter(s) k).
#'
#' @seealso [pareto-k-diagnostic] for PSIS diagnostics.
#'
#' @template loo-and-psis-references
#'
#' @importFrom parallel mclapply makePSOCKcluster stopCluster parLapply
#'
psislw <- function(lw, wcp = 0.2, wtrunc = 3/4,
                   cores = getOption("mc.cores", 1),
                   llfun = NULL, llargs = NULL,
                   ...) {
  .Deprecated("psis")
  cores <- loo_cores(cores)

  # minimal cutoff value. there must be at least 5 log-weights larger than this
  # in order to fit the gPd to the tail
  MIN_CUTOFF <- -700
  MIN_TAIL_LENGTH <- 5

  # Smooth a single vector of log weights. Returns a list with elements
  # `lw_new` (truncated and renormalized log weights) and `k` (the shape
  # estimate from the gPd fit, or Inf when no fit was attempted).
  .psis <- function(lw_i) {
    # shift so that the largest log weight is 0
    x <- lw_i - max(lw_i)
    cutoff <- lw_cutpoint(x, wcp, MIN_CUTOFF)
    above_cut <- x > cutoff
    x_tail <- x[above_cut]
    tail_len <- length(x_tail)
    # computed once; used both for the branch and to pick the warning text
    tail_is_constant <- all(x_tail == x_tail[1])
    if (tail_len < MIN_TAIL_LENGTH || tail_is_constant) {
      if (tail_is_constant) {
        warning(
          "All tail values are the same. ",
          "Weights are truncated but not smoothed.",
          call. = FALSE
        )
      } else {
        warning(
          "Too few tail samples to fit generalized Pareto distribution.\n",
          "Weights are truncated but not smoothed.",
          call. = FALSE
        )
      }
      # no smoothing: keep the shifted weights and flag k as Inf
      x_new <- x
      k <- Inf
    } else {
      # store order of tail samples, fit gPd to the right tail samples, compute
      # order statistics for the fit, remap back to the original order, join
      # body and gPd smoothed tail
      tail_ord <- order(x_tail)
      exp_cutoff <- exp(cutoff)
      fit <- gpdfit(exp(x_tail) - exp_cutoff, wip = FALSE, min_grid_pts = 80)
      k <- fit$k
      sigma <- fit$sigma
      prb <- (seq_len(tail_len) - 0.5) / tail_len
      qq <- qgpd(prb, k, sigma) + exp_cutoff
      smoothed_tail <- rep.int(0, tail_len)
      smoothed_tail[tail_ord] <- log(qq)
      # body entries stay as in x; only the tail entries are replaced
      x_new <- x
      x_new[above_cut] <- smoothed_tail
    }
    # truncate (if wtrunc > 0) and renormalize,
    # return log weights and pareto k
    lw_new <- lw_normalize(lw_truncate(x_new, wtrunc))
    nlist(lw_new, k)
  }

  # Process the i-th data point: obtain its log weights (from `llfun` or from
  # column i of `lw`), smooth them, and return either the smoothed weights or,
  # when called from loo(), the log-sum-exp of ll_i + smoothed log weights.
  .psis_loop <- function(i) {
    if (LL_FUN) {
      ll_i <- llfun(i = i,
                    data = llargs$data[i, , drop = FALSE],
                    draws = llargs$draws)
      lw_i <- -1 * ll_i
    } else {
      lw_i <- lw[, i]
      ll_i <- -1 * lw_i
    }
    psis <- .psis(lw_i)
    if (FROM_LOO) {
      nlist(lse = logSumExp(ll_i + psis$lw_new), k = psis$k)
    } else {
      psis
    }
  }

  # COMPUTE_LOOS is an internal flag passed through `...`
  dots <- list(...)
  FROM_LOO <- if ("COMPUTE_LOOS" %in% names(dots))
    dots$COMPUTE_LOOS else FALSE

  if (!missing(lw)) {
    if (!is.matrix(lw))
      lw <- as.matrix(lw)
    N <- ncol(lw)
    LL_FUN <- FALSE
  } else {
    if (is.null(llfun) || is.null(llargs))
      stop("Either 'lw' or 'llfun' and 'llargs' must be specified.")
    N <- llargs$N
    LL_FUN <- TRUE
  }

  # seq_len(N) rather than 1:N, so that N == 0 iterates over nothing instead of
  # over c(1, 0)
  idx <- seq_len(N)
  if (cores == 1) {
    # don't call functions from parallel package if cores=1
    out <- lapply(X = idx, FUN = .psis_loop)
  } else {
    # parallelize
    if (.Platform$OS.type != "windows") {
      out <- mclapply(X = idx, FUN = .psis_loop, mc.cores = cores)
    } else {
      # nocov start
      cl <- makePSOCKcluster(cores)
      # add = TRUE so this handler does not replace any other exit handler
      on.exit(stopCluster(cl), add = TRUE)
      out <- parLapply(cl, X = idx, fun = .psis_loop)
      # nocov end
    }
  }

  # second element of each result is k in both the LOO and non-LOO case
  pareto_k <- vapply(out, "[[", 2L, FUN.VALUE = numeric(1))
  psislw_warnings(pareto_k)

  if (FROM_LOO) {
    loos <- vapply(out, "[[", 1L, FUN.VALUE = numeric(1))
    nlist(loos, pareto_k)
  } else {
    # number of draws per data point: llargs$S or the row count of lw
    funval <- if (LL_FUN) llargs$S else nrow(lw)
    lw_smooth <- vapply(out, "[[", 1L, FUN.VALUE = numeric(funval))
    out <- nlist(lw_smooth, pareto_k)
    class(out) <- c("psis", "list")
    return(out)
  }
}
# internal ----------------------------------------------------------------
# Cutoff on the log-weight scale above which values count as the tail.
#
# y:       numeric vector of log weights.
# wcp:     proportion of the largest values to treat as the tail.
# min_cut: lower bound for the returned cutoff; reset to -700 when it lies
#          below log(.Machine$double.xmin).
#
# Returns the larger of the (1 - wcp) quantile of y and the lower bound.
lw_cutpoint <- function(y, wcp, min_cut) {
  smallest_log <- log(.Machine$double.xmin)
  lower_bound <- if (min_cut < smallest_log) -700 else min_cut
  upper_quantile <- quantile(y, probs = 1 - wcp, names = FALSE)
  max(upper_quantile, lower_bound)
}
# Cap log weights from above.
#
# y:      numeric vector of log weights.
# wtrunc: truncation exponent; a value of 0 disables truncation and y is
#         returned untouched.
#
# Entries of y greater than wtrunc * log(S) - log(S) + logSumExp(y), where S
# is length(y), are replaced by that bound.
lw_truncate <- function(y, wtrunc) {
  if (wtrunc != 0) {
    log_n <- log(length(y))
    upper <- wtrunc * log_n - log_n + logSumExp(y)
    too_large <- y > upper
    y[too_large] <- upper
  }
  y
}
# Normalize log weights by subtracting logSumExp(y) from every entry.
lw_normalize <- function(y) {
  log_total <- logSumExp(y)
  y - log_total
}
# warnings about pareto k values ------------------------------------------
# Emit at most one warning about the Pareto k estimates in `k`: the
# "too high" message if any value exceeds 0.7, otherwise the "slightly high"
# message if any value exceeds 0.5. Nothing is emitted when neither holds.
psislw_warnings <- function(k) {
  any_too_high <- any(k > 0.7)
  if (any_too_high) {
    .warn("Some Pareto k diagnostic values are too high. ", .k_help())
  } else {
    any_slightly_high <- any(k > 0.5)
    if (any_slightly_high) {
      .warn("Some Pareto k diagnostic values are slightly high. ", .k_help())
    }
  }
}