|
| 1 | +# Set up ----------------------------------------------------------------------- |
| 2 | + |
| 3 | +library(rstan) |
| 4 | +library(edstan) |
| 5 | +library(loo) |
| 6 | +library(reshape2) |
| 7 | +library(doParallel) |
| 8 | +options(mc.cores = 5) |
| 9 | +options(loo.cores = 5) |
| 10 | + |
| 11 | + |
| 12 | +# Functions -------------------------------------------------------------------- |
| 13 | + |
| 14 | +# Replacement for rstan::get_posterior_means() that returns object with same |
| 15 | +# structure as rstan::extract() |
| 16 | +# stan_fit: A fitted Stan model |
| 17 | +better_posterior_means <- function(stan_fit) { |
| 18 | + draws <- extract(stan_fit, stan_fit@model_pars) |
| 19 | + f <- function(x) { |
| 20 | + dims <- dim(x) |
| 21 | + n_dims <- length(dims) |
| 22 | + if (n_dims == 1) { |
| 23 | + mean(x) |
| 24 | + } else { |
| 25 | + m <- apply(x, 2:n_dims, mean) |
| 26 | + array(m, dim = c(1, dims[-1])) |
| 27 | + } |
| 28 | + } |
| 29 | + lapply(draws, f) |
| 30 | +} |
| 31 | + |
| 32 | + |
| 33 | +# Function to obtain marginal likelihoods with parallel processing. |
| 34 | +# stan_fit: Fitted Stan model |
| 35 | +# data_list: Data list used in fitting model |
| 36 | +# MFUN: Function to calculate marginal likelihood for cluster at a node |
| 37 | +# location. This is application specific. |
| 38 | +# resid_name: Name of residual in Stan program to integrate out |
| 39 | +# sd_name: Name of SD for residual in Stan program |
| 40 | +# n_nodes: Number of adaptive quadrature nodes to use |
| 41 | +# best_only: Whether to evaluate marginal likelihood only at posterior means |
| 42 | +mll_parallel <- function(stan_fit, data_list, MFUN, resid_name, sd_name, n_nodes, |
| 43 | + best_only = FALSE) { |
| 44 | + |
| 45 | + library(foreach) |
| 46 | + library(statmod) # For gauss.quad.prob() |
| 47 | + library(matrixStats) # For logSumExp() |
| 48 | + |
| 49 | + draws <- extract(stan_fit, stan_fit@model_pars) |
| 50 | + n_iter <- ifelse(best_only, 0, nrow(draws$lp__)) |
| 51 | + post_means <- better_posterior_means(stan_fit) |
| 52 | + |
| 53 | + # Seperate out draws for residuals and their SD |
| 54 | + resid <- apply(draws[[resid_name]], 2, mean) |
| 55 | + stddev <- apply(draws[[resid_name]], 2, sd) |
| 56 | + |
| 57 | + # Get standard quadrature points |
| 58 | + std_quad <- gauss.quad.prob(n_nodes, "normal", mu = 0, sigma = 1) |
| 59 | + std_log_weights <- log(std_quad$weights) |
| 60 | + |
| 61 | + # Extra iteration is to evaluate marginal log-likelihood at parameter means. |
| 62 | + ll <- foreach(i = 1:(n_iter + 1), .combine = rbind, |
| 63 | + .packages = "matrixStats") %dopar% { |
| 64 | + |
| 65 | + ll_j <- matrix(NA, nrow = 1, ncol = ncol(draws[[resid_name]])) |
| 66 | + |
| 67 | + for(j in 1:ncol(ll_j)) { |
| 68 | + |
| 69 | + # Set up adaptive quadrature using SD for residuals either from draws or |
| 70 | + # posterior mean (for best_ll). |
| 71 | + sd_i <- ifelse(i <= n_iter, draws[[sd_name]][i], post_means[[sd_name]]) |
| 72 | + adapt_nodes <- resid[j] + stddev[j] * std_quad$nodes |
| 73 | + log_weights <- log(sqrt(2*pi)) + log(stddev[j]) + std_quad$nodes^2/2 + |
| 74 | + dnorm(adapt_nodes, sd = sd_i, log = TRUE) + std_log_weights |
| 75 | + |
| 76 | + # Evaluate mll with adaptive quadrature. If at n_iter + 1, evaluate |
| 77 | + # marginal likelihood at posterior means. |
| 78 | + if(i <= n_iter) { |
| 79 | + loglik_by_node <- sapply(adapt_nodes, FUN = MFUN, r = j, iter = i, |
| 80 | + data_list = data_list, draws = draws) |
| 81 | + weighted_loglik_by_node <- loglik_by_node + log_weights |
| 82 | + ll_j[1,j] <- logSumExp(weighted_loglik_by_node) |
| 83 | + } else { |
| 84 | + loglik_by_node <- sapply(adapt_nodes, FUN = MFUN, r = j, iter = 1, |
| 85 | + data_list = data_list, draws = post_means) |
| 86 | + weighted_loglik_by_node <- loglik_by_node + log_weights |
| 87 | + ll_j[1,j] <- logSumExp(weighted_loglik_by_node) |
| 88 | + } |
| 89 | + |
| 90 | + } |
| 91 | + |
| 92 | + ll_j |
| 93 | + |
| 94 | + } |
| 95 | + |
| 96 | + if(best_only) { |
| 97 | + return(ll[nrow(ll), ]) |
| 98 | + } else { |
| 99 | + return(list(ll = ll[-nrow(ll), ], best_ll = ll[nrow(ll), ])) |
| 100 | + } |
| 101 | + |
| 102 | +} |
| 103 | + |
| 104 | + |
| 105 | +# Function to calculate likelihood for a cluster for an adaptive quad node |
| 106 | +# specific to the IRT example. Similar functions would be written for other |
| 107 | +# applications and passed to mll_parallel(). |
| 108 | +# node: node location |
| 109 | +# r: index for cluster |
| 110 | +# iter: mcmc iteration |
| 111 | +# data_list: data used to fit Stan model |
| 112 | +# draws: mcmc draws from fitted Stan model |
| 113 | +f_marginal <- function(node, r, iter, data_list, draws) { |
| 114 | + y <- data_list$y[data_list$jj == r] |
| 115 | + theta_fix <- draws$theta_fix[iter, r] |
| 116 | + delta <- draws$delta[iter, data_list$ii[data_list$jj == r]] |
| 117 | + p <- boot::inv.logit(theta_fix + node - delta) |
| 118 | + sum(dbinom(y, 1, p, log = TRUE)) |
| 119 | +} |
| 120 | + |
| 121 | + |
| 122 | +# Function to calculate DIC |
| 123 | +# ll_obj: Object returned by mll_parallel() |
| 124 | +dic <- function(ll_obj) { |
| 125 | + full_ll <- apply(ll_obj$ll, 1, sum) |
| 126 | + full_best <- sum(ll_obj$best_ll) |
| 127 | + mean_lpd <- mean(full_ll) |
| 128 | + pdic <- 2 * (full_best - mean_lpd) |
| 129 | + elpd_dic <- full_best - pdic |
| 130 | + c(elpd_dic = elpd_dic, p_dic = pdic, dic = -2*elpd_dic, |
| 131 | + best_lpd = full_best, mean_lpd = mean_lpd) |
| 132 | +} |
| 133 | + |
| 134 | + |
| 135 | +# Example analysis ------------------------------------------------------------- |
| 136 | + |
| 137 | +# Assemble example dataset |
| 138 | +dl <- irt_data(y = aggression$dich, jj = aggression$person, |
| 139 | + ii = aggression$item, covariates = aggression, |
| 140 | + formula = ~ 1 + male + anger) |
| 141 | + |
| 142 | +# Fit model |
| 143 | +fit <- stan("rasch_edstan_modified.stan", data = dl, iter = 500, chains = 5) |
| 144 | + |
| 145 | +# Obtain marginal likelihoods |
| 146 | +cl <- makeCluster(5) |
| 147 | +registerDoParallel(cl) |
| 148 | +ll_marg <- mll_parallel(fit, dl, f_marginal, "zeta", "sigma", 11) |
| 149 | +stopCluster(cl) |
| 150 | + |
| 151 | +# Obtain marginal information criteria |
| 152 | +dic(ll_marg) |
| 153 | +waic(ll_marg$ll) |
| 154 | +loo(ll_marg$ll) |
| 155 | + |
| 156 | + |
0 commit comments