-
-
Notifications
You must be signed in to change notification settings - Fork 34
/
Copy pathhelpers.R
199 lines (178 loc) · 4.98 KB
/
helpers.R
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
#' Is the current platform Windows?
#'
#' Thin wrapper around `checkmate::test_os()`.
#'
#' @noRd
#' @return The single logical value returned by `checkmate::test_os("windows")`.
os_is_windows <- function() {
  on_windows <- checkmate::test_os("windows")
  on_windows
}
#' Numerically stable `log(mean(exp(x)))`
#'
#' Uses the log-sum-exp identity `log(mean(exp(x))) = logSumExp(x) - log(n)`,
#' which is more stable than exponentiating `x` directly.
#'
#' @noRd
#' @param x A numeric vector.
#' @return A scalar equal to `log(mean(exp(x)))`.
#'
logMeanExp <- function(x) {
  n <- length(x)
  matrixStats::logSumExp(x) - log(n)
}
#' Numerically stable `log(colMeans(exp(x)))`
#'
#' Column-wise counterpart of `logMeanExp()`: each column's log-sum-exp minus
#' the log of the number of rows.
#'
#' @noRd
#' @param x A matrix.
#' @return A vector with one element per column of `x`, each equal to
#'   `logMeanExp()` applied to that column.
#'
colLogMeanExps <- function(x) {
  n_rows <- nrow(x)
  matrixStats::colLogSumExps(x) - log(n_rows)
}
#' Summarise pointwise quantities as estimates with standard errors
#'
#' Each column of `x` holds pointwise contributions to one quantity. The
#' estimate is the column sum, and its standard error is
#' `sqrt(nrow(x) * var(column))`.
#'
#' @noRd
#' @param x A matrix with one row per point and one column per quantity.
#' @return An `ncol(x)` by 2 matrix with columns `"Estimate"` and `"SE"`
#'   and rownames equal to `colnames(x)`.
#'
table_of_estimates <- function(x) {
  n_points <- nrow(x)
  estimates <- matrixStats::colSums2(x)
  std_errors <- sqrt(n_points * matrixStats::colVars(x))
  summary_tab <- cbind(Estimate = estimates, SE = std_errors)
  rownames(summary_tab) <- colnames(x)
  summary_tab
}
# validating and reshaping arrays/matrices -------------------------------
#' Validate log-likelihood (or log-ratio) values
#'
#' Rejects lists, any `NA`/`NaN`, and positive infinity. `-Inf` is allowed.
#' The checks run in that order and the first failure raises an error.
#'
#' @noRd
#' @param x Array/matrix/vector of log-likelihood or log-ratio values.
#' @return `x`, invisibly, if no error is thrown.
#'
validate_ll <- function(x) {
  if (is.list(x)) {
    stop("List not allowed as input.")
  }
  if (anyNA(x)) {
    stop("NAs not allowed in input.")
  }
  # -Inf is a legitimate log-likelihood (zero likelihood); +Inf is not.
  if (any(x == Inf)) {
    stop("All input values must be finite or -Inf.")
  }
  invisible(x)
}
#' Collapse an iter by chain by obs array into an (iter * chain) by obs matrix
#'
#' @noRd
#' @param x A 3-dimensional array to convert.
#' @return An unnamed (iter * chain) by obs matrix. Rows hold all iterations
#'   of chain 1, then all iterations of chain 2, and so on.
#'
llarray_to_matrix <- function(x) {
  stopifnot(is.array(x), length(dim(x)) == 3)
  dims <- dim(x)
  # Merging the first two dimensions is a pure reshape, because R stores
  # arrays in column-major order. Replacing dim() also drops any dimnames.
  dim(x) <- c(prod(dims[-3L]), dims[3L])
  unname(x)
}
#' Convert (iter * chain) by obs matrix to iter by chain by obs array
#'
#' Rows of `x` with `chain_id == k` become the slice `[, k, ]` of the result,
#' in their original relative order.
#'
#' @noRd
#' @param x Matrix to convert, with one row per draw and one column per
#'   observation.
#' @param chain_id Vector of integer-valued chain ids, one per row of `x`.
#'   The ids must be exactly `1, 2, ..., n_chain`, and every chain must have
#'   the same number of rows.
#' @return An iter by chain by obs array.
#'
llmatrix_to_array <- function(x, chain_id) {
  stopifnot(is.matrix(x), all(chain_id == as.integer(chain_id)))
  lldim <- dim(x)
  n_chain <- length(unique(chain_id))
  chain_id <- as.integer(chain_id)
  chain_counts <- as.numeric(table(chain_id))
  if (length(chain_id) != lldim[1]) {
    stop("Number of rows in matrix not equal to length(chain_id).",
         call. = FALSE)
  } else if (any(chain_counts != chain_counts[1])) {
    stop("Not all chains have same number of iterations.",
         call. = FALSE)
  } else if (max(chain_id) != n_chain) {
    stop("max(chain_id) not equal to the number of chains.",
         call. = FALSE)
  } else if (min(chain_id) < 1L) {
    # With n_chain distinct integer ids and max == n_chain, also requiring
    # min >= 1 guarantees the ids are exactly 1..n_chain. Without this check,
    # ids like c(0, 0, 2, 2) would slip through and fail in the loop below
    # with an obscure "replacement has length zero" error.
    stop("chain_id must only contain the integers 1 to the number of chains.",
         call. = FALSE)
  }
  n_iter <- lldim[1] / n_chain
  n_obs <- lldim[2]
  a <- array(data = NA, dim = c(n_iter, n_chain, n_obs))
  for (c in seq_len(n_chain)) {
    a[, c, ] <- x[chain_id == c, , drop = FALSE]
  }
  return(a)
}
#' Resolve a log-likelihood function and check its formal arguments
#'
#' @noRd
#' @param x A function (or anything `match.fun()` accepts) that must have at
#'   least the formal arguments `data_i` and `draws`.
#' @return The function found by `match.fun(x)`. Throws an error if either
#'   required argument is missing.
#'
validate_llfun <- function(x) {
  llfun <- match.fun(x)
  required_args <- c("data_i", "draws")
  absent_args <- setdiff(required_args, names(formals(llfun)))
  if (length(absent_args) > 0) {
    stop(
      "Log-likelihood function must have at least the arguments ",
      "'data_i' and 'draws'",
      call. = FALSE
    )
  }
  llfun
}
#' Named lists
#'
#' Build a list whose elements are named explicitly (`name = value`) or, for
#' unnamed arguments, after the expression written in the call. For example,
#' `nlist(a, b)` is shorthand for `list(a = a, b = b)`, and `nlist(a, b = 2)`
#' is shorthand for `list(a = a, b = 2)`.
#'
#' @export
#' @keywords internal
#' @param ... Objects to include in the list.
#' @return A named list.
#' @examples
#'
#' # All variables already defined
#' a <- rnorm(100)
#' b <- mat.or.vec(10, 3)
#' nlist(a,b)
#'
#' # Define some variables in the call and take the rest from the environment
#' nlist(a, b, veggies = c("lettuce", "spinach"), fruits = c("banana", "papaya"))
#'
nlist <- function(...) {
  # Deparsed argument expressions, in call order (drop the function name).
  arg_exprs <- as.character(match.call())[-1L]
  out <- list(...)
  given <- names(out)
  if (is.null(given)) {
    names(out) <- arg_exprs
    return(out)
  }
  unnamed <- !nzchar(given)
  if (any(unnamed)) {
    names(out)[unnamed] <- arg_exprs[unnamed]
  }
  out
}
#' Decide how many cores to use, honouring the deprecated `loo.cores` option
#'
#' If the `loo.cores` option is set and differs from `cores`, the option wins
#' and a deprecation warning is raised. Otherwise `cores` is returned unchanged.
#'
#' @noRd
#' @param cores Requested number of cores.
#' @return The number of cores to use.
loo_cores <- function(cores) {
  deprecated_cores <- getOption("loo.cores", NA)
  if (is.na(deprecated_cores) || deprecated_cores == cores) {
    return(cores)
  }
  warning("'loo.cores' is deprecated, please use 'mc.cores' or pass 'cores' explicitly.",
          call. = FALSE)
  deprecated_cores
}
# nocov start
# Extra pre-release reminders picked up by devtools' release checklist.
release_questions <- function() {
  reminders <- c(
    "Have you updated references?",
    "Have you updated inst/CITATION?",
    "Have you updated the vignettes?"
  )
  reminders
}
# nocov end
#' Check whether all values of a vector are (numerically) identical
#'
#' @noRd
#' @param x A numeric vector.
#' @param tol Absolute tolerance allowed between the largest and smallest
#'   value.
#' @return `TRUE` if `max(x)` and `min(x)` are equal or differ by less than
#'   `tol`, `FALSE` otherwise, and `NA` if `x` contains missing values.
#'
is_constant <- function(x, tol = .Machine$double.eps) {
  hi <- max(x)
  lo <- min(x)
  # Compare exactly first: when every value is the same infinity (e.g. all
  # -Inf log-ratios), hi - lo is Inf - Inf = NaN, and the tolerance check
  # alone would return NA instead of TRUE.
  hi == lo || abs(hi - lo) < tol
}