Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

AIC and BIC #54

Open
wants to merge 13 commits into
base: main
Choose a base branch
from
1 change: 1 addition & 0 deletions DESCRIPTION
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@ Imports:
rstan (>= 2.26.0),
rstantools (>= 2.3.0),
stats,
stringr,
tibble,
tidyr (>= 1.3.0)
LinkingTo:
Expand Down
26 changes: 19 additions & 7 deletions R/model-evaluation.R
Original file line number Diff line number Diff line change
Expand Up @@ -117,16 +117,22 @@ NULL

#' @export
#' @rdname model_evaluation
add_criterion <- function(x, criterion = c("loo", "waic"), overwrite = FALSE,
save = TRUE, ..., r_eff = NA) {
add_criterion <- function(x, criterion = c("loo", "waic"),
overwrite = FALSE, save = TRUE, ..., r_eff = NA) {
model <- check_model(x, required_class = "measrfit", name = "x")
if (model$method != "mcmc") {
if (any(model$method != "mcmc" & criterion %in% c("loo", "waic"))) {
rlang::abort("error_bad_method",
message = glue::glue("Model criteria are only available for ",
"models estimated with ",
message = glue::glue("LOO and WAIC model criteria are only ",
"available for models estimated with ",
"`method = \"mcmc\"`."))
} else if (any(model$method != "optim" & criterion %in% c("aic", "bic"))) {
rlang::abort("error_bad_method",
message = glue::glue("AIC and BIC model criteria are only ",
"available for models estimated with ",
"`method = \"optim\"`."))
}
criterion <- rlang::arg_match(criterion, values = c("loo", "waic"),
criterion <- rlang::arg_match(criterion,
values = c("loo", "waic", "aic", "bic"),
multiple = TRUE)
overwrite <- check_logical(overwrite, name = "overwrite")
save <- check_logical(save, name = "force_save")
Expand All @@ -140,7 +146,7 @@ add_criterion <- function(x, criterion = c("loo", "waic"), overwrite = FALSE,
}
all_criteria <- c(new_criteria, redo_criteria)

if (length(all_criteria) > 0) {
if (length(all_criteria) > 0 & (model$method == "mcmc")) {
log_lik_array <- loglik_array(model)
}

Expand All @@ -150,6 +156,12 @@ add_criterion <- function(x, criterion = c("loo", "waic"), overwrite = FALSE,
if ("waic" %in% all_criteria) {
model$criteria$waic <- waic(log_lik_array)
}
if ("aic" %in% all_criteria) {
model$criteria$aic <- aic(model)
}
if ("bic" %in% all_criteria) {
model$criteria$bic <- bic(model)
}

# re-save model object (if applicable)
if (!is.null(model$file) && length(all_criteria) > 0 && save) {
Expand Down
34 changes: 34 additions & 0 deletions R/utils-model-evaluation.R
Original file line number Diff line number Diff line change
Expand Up @@ -62,3 +62,37 @@ add_ppmc <- function(model, run_ppmc) {

return(model)
}

aic <- function(model) {
logLik <- model$model$value

num_params <- model$model$par %>%
tibble::as_tibble() %>%
dplyr::mutate(param = names(model$model$par)) %>%
dplyr::filter(!stringr::str_detect(.data$param, "pi")) %>%
dplyr::filter(!stringr::str_detect(.data$param, "log_Vc")) %>%
nrow() - 1

aic <- (-2 * logLik) + (2 * num_params)

return(aic)
}

bic <- function(model) {
logLik <- model$model$value

num_params <- model$model$par %>%
tibble::as_tibble() %>%
dplyr::mutate(param = names(model$model$par)) %>%
dplyr::filter(!stringr::str_detect(.data$param, "pi")) %>%
dplyr::filter(!stringr::str_detect(.data$param, "log_Vc")) %>%
nrow() - 1

N <- model$data$data %>%
dplyr::distinct(.data$resp_id) %>%
nrow()

bic <- (-2 * logLik) + (log(N) * num_params)

return(bic)
}
11 changes: 11 additions & 0 deletions tests/testthat/test-ecpe.R
Original file line number Diff line number Diff line change
Expand Up @@ -293,3 +293,14 @@ test_that("mcmc requirements error", {
expect_s3_class(err, "error_bad_method")
expect_match(err$message, "`method = \"mcmc\"`")
})

test_that("optim requirements error", {
skip_on_cran()

mcmc_mod <- cmds_ecpe_lcdm
mcmc_mod$method <- "mcmc"

err <- rlang::catch_cnd(add_criterion(mcmc_mod, criterion = c("aic", "bic")))
expect_s3_class(err, "error_bad_method")
expect_match(err$message, "`method = \"optim\"`")
})
3 changes: 1 addition & 2 deletions tests/testthat/test-mcmc.R
Original file line number Diff line number Diff line change
Expand Up @@ -116,8 +116,7 @@ test_that("loo and waic can be added to model", {
expect_equal(names(loo_model$criteria), "loo")
expect_s3_class(loo_model$criteria$loo, "psis_loo")

lw_model <- add_criterion(loo_model, criterion = c("loo", "waic"),
overwrite = TRUE)
lw_model <- add_criterion(loo_model, overwrite = TRUE)
expect_equal(names(lw_model$criteria), c("loo", "waic"))
expect_s3_class(lw_model$criteria$loo, "psis_loo")
expect_s3_class(lw_model$criteria$waic, "waic")
Expand Down
40 changes: 40 additions & 0 deletions tests/testthat/test-model-evaluation.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
test_that("add criterion error messages work", {
err <- rlang::catch_cnd(add_criterion("test"))
expect_s3_class(err, "error_bad_argument")
expect_match(err$message, "must be an object with class measrfit")


err <- rlang::catch_cnd(add_criterion(rstn_dino, criterion = "waic"))
expect_s3_class(err, "error_bad_method")
expect_match(err$message, "LOO and WAIC model criteria are only available")

err <- rlang::catch_cnd(add_criterion(rstn_dino, criterion = "waic"))
expect_s3_class(err, "error_bad_method")
expect_match(err$message, "LOO and WAIC model criteria are only available")

test_dino <- rstn_dino
test_dino$method <- "mcmc"
err <- rlang::catch_cnd(add_criterion(test_dino, criterion = "aic"))
expect_s3_class(err, "error_bad_method")
expect_match(err$message, "AIC and BIC model criteria are only available")

err <- rlang::catch_cnd(add_criterion(test_dino, criterion = "bic"))
expect_s3_class(err, "error_bad_method")
expect_match(err$message, "AIC and BIC model criteria are only available")
})

test_that("AIC works", {
rstn_dino <- add_criterion(rstn_dino, criterion = "aic")

exp_aic <- 37151.96

expect_equal(rstn_dino$criteria$aic, exp_aic, tolerance = .0001)
})

test_that("BIC works", {
rstn_dino <- add_criterion(rstn_dino, criterion = "bic")

exp_bic <- 37647.64

expect_equal(rstn_dino$criteria$bic, exp_bic, tolerance = .0001)
})
20 changes: 20 additions & 0 deletions tests/testthat/test-utils-model-evaluation.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
test_that("aic works", {
num_params <- 101
logLik <- -18474.98

exp_aic <- (-2 * logLik) + (2 * num_params)
aic_val <- aic(rstn_dino)

expect_equal(exp_aic, aic_val)
})

test_that("bic works", {
num_params <- 101
N <- 1000
logLik <- -18474.98

exp_bic <- (-2 * logLik) + (log(N) * num_params)
bic_val <- bic(rstn_dino)

expect_equal(exp_bic, bic_val)
})