-
-
Notifications
You must be signed in to change notification settings - Fork 22
/
Copy pathloo-functions.R
101 lines (93 loc) · 2.99 KB
/
loo-functions.R
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
#' Generic functions for LOO predictions
#'
#' See the methods in the \pkg{rstanarm} package for examples.
#'
#' @name loo-prediction
#'
#' @template args-object
#' @template args-dots
#'
#' @return `loo_predict()`, `loo_epred()`, `loo_linpred()`, and `loo_pit()`
#' (probability integral transform) methods should return a vector with
#' length equal to the number of observations in the data.
#' For discrete observations, the probability integral transform is randomised
#' to ensure theoretical uniformity. Set the random seed to get reproducible
#' results with discrete data. For more details, see Czado et al. (2009).
#' `loo_predictive_interval()` methods should return a two-column matrix
#' formatted in the same way as for [predictive_interval()].
#'
#' @template seealso-rstanarm-pkg
#' @template seealso-vignettes
#' @template reference-randomised-pit
#' @rdname loo-prediction
#' @export
loo_linpred <- function(object, ...) {
# S3 generic: UseMethod() selects the "loo_linpred" method based on the
# class of `object`; `...` is passed along to that method unchanged.
UseMethod("loo_linpred")
}
#' @rdname loo-prediction
#' @export
loo_epred <- function(object, ...) {
# S3 generic: UseMethod() selects the "loo_epred" method based on the
# class of `object`; `...` is passed along to that method unchanged.
UseMethod("loo_epred")
}
#' @rdname loo-prediction
#' @export
loo_predict <- function(object, ...) {
# S3 generic: UseMethod() selects the "loo_predict" method based on the
# class of `object`; `...` is passed along to that method unchanged.
UseMethod("loo_predict")
}
#' @rdname loo-prediction
#' @export
loo_predictive_interval <- function(object, ...) {
# S3 generic: UseMethod() selects the "loo_predictive_interval" method based
# on the class of `object`; `...` is passed along to that method unchanged.
UseMethod("loo_predictive_interval")
}
#' @rdname loo-prediction
#' @export
loo_pit <- function(object, ...) {
# S3 generic: UseMethod() selects the "loo_pit" method based on the class of
# `object`; objects without a more specific method fall through to
# loo_pit.default(), which expects a matrix.
UseMethod("loo_pit")
}
#' @rdname loo-prediction
#' @export
#' @param y For the default method of `loo_pit()`, a numeric vector of `y`
#'   values with one element per column of the matrix used as `object`.
#' @param lw For the default method of `loo_pit()`, a matrix of log-weights
#'   with the same dimensions as the matrix used as `object`.
#'
loo_pit.default <- function(object, y, lw, ...) {
  # A dedicated message for the most likely mistake: a non-matrix `object`.
  if (!is.matrix(object)) {
    stop("For the default method 'object' should be a matrix.")
  }

  # Draws and observations are validated first, then the log-weights; the
  # checks run in this order and stop at the first one that fails.
  stopifnot(is.numeric(object), is.numeric(y), length(y) == ncol(object))
  stopifnot(is.matrix(lw), identical(dim(lw), dim(object)))

  .loo_pit(y = y, yrep = object, lw = lw)
}
# internal ----------------------------------------------------------------
# Compute leave-one-out probability integral transform (PIT) values.
#
# For each column j of `yrep`, the PIT is the summed weight, sum(exp(lw[, j])),
# of the draws that are strictly smaller than y[j]. If some draws are exactly
# equal to y[j], the PIT is instead drawn uniformly between that value and the
# value obtained by also adding the weight of the tied draws (randomized PIT
# for discrete y). Because of this random draw, results for discrete data are
# only reproducible under a fixed random seed.
#
# NOTE(review): presumably exp(lw) sums to 1 within each column, so that the
# PITs lie in [0, 1] up to rounding error -- confirm against callers.
#
# @param y Vector of observations, one per column of `yrep`.
# @param yrep Matrix of draws; column j holds the draws compared with y[j].
# @param lw Matrix of log-weights indexed like `yrep`. Must be non-NULL and
#   contain only finite values, otherwise an error is raised.
# @return Numeric vector with one PIT value per column of `yrep`. Values above
#   1 are truncated to 1, and a warning is issued when that happens.
.loo_pit <- function(y, yrep, lw) {
  if (is.null(lw) || !all(is.finite(lw))) {
    stop("lw needs to be not null and finite.")
  }
  pits <- vapply(seq_len(ncol(yrep)), function(j) {
    # Weight of the draws strictly below the observation.
    sel_min <- yrep[, j] < y[j]
    pit <- .exp_log_sum_exp(lw[sel_min, j])
    # Exact ties only occur for discrete outcomes.
    sel_sup <- yrep[, j] == y[j]
    if (any(sel_sup)) {
      # randomized PIT for discrete y (see, e.g., Czado, C., Gneiting, T.,
      # Held, L.: Predictive model assessment for count data.
      # Biometrics 65(4), 1254–1261 (2009).)
      pit_sup <- pit + .exp_log_sum_exp(lw[sel_sup, j])
      pit <- stats::runif(1, pit, pit_sup)
    }
    pit
  }, FUN.VALUE = 1)
  if (any(pits > 1)) {
    # Pass the message parts straight to warning(): wrapping them in cat()
    # would print to stdout and hand warning() a NULL, i.e. an empty message.
    warning(
      "Some PIT values larger than 1! Largest: ",
      max(pits),
      "\nRounding PIT > 1 to 1."
    )
  }
  pmin(1, pits)
}
# Sum of exp(x), evaluated in log-sum-exp form: the maximum of `x` is
# subtracted before exponentiating and added back on the log scale.
#
# @param x Numeric vector of log-scale values; may have length zero.
# @return A single number equal to sum(exp(x)); a zero-length `x` yields 0.
.exp_log_sum_exp <- function(x) {
  # max() warns (and returns -Inf) for zero-length input; presumably an empty
  # selection is an expected case here, hence the warning is silenced.
  shift <- suppressWarnings(max(x))
  log_total <- shift + log(sum(exp(x - shift)))
  exp(log_total)
}