From d3dab3d72f9f00e8de0ff52ed5065deec6621927 Mon Sep 17 00:00:00 2001 From: JeffreyCHoover Date: Mon, 13 Jan 2025 12:22:48 -0500 Subject: [PATCH 01/13] add AIC & BIC with error messaging --- R/model-evaluation.R | 15 ++++++++++----- 1 file changed, 10 insertions(+), 5 deletions(-) diff --git a/R/model-evaluation.R b/R/model-evaluation.R index 51ffeeb..4e90594 100644 --- a/R/model-evaluation.R +++ b/R/model-evaluation.R @@ -117,14 +117,19 @@ NULL #' @export #' @rdname model_evaluation -add_criterion <- function(x, criterion = c("loo", "waic"), overwrite = FALSE, - save = TRUE, ..., r_eff = NA) { +add_criterion <- function(x, criterion = c("loo", "waic", "aic", "bic"), + overwrite = FALSE, save = TRUE, ..., r_eff = NA) { model <- check_model(x, required_class = "measrfit", name = "x") - if (model$method != "mcmc") { + if (model$method != "mcmc" & criterion %in% c("loo", "waic")) { rlang::abort("error_bad_method", - message = glue::glue("Model criteria are only available for ", - "models estimated with ", + message = glue::glue("LOO and WAIC model criteria are only ", + "available for models estimated with ", "`method = \"mcmc\"`.")) + } else if (model$method != "optim" & criterion %in% c("aic", "bic")) { + rlang::abort("error_bad_method", + message = glue::glue("AIC and BIC model criteria are only ", + "available for models estimated with ", + "`method = \"optim\"`.")) } criterion <- rlang::arg_match(criterion, values = c("loo", "waic"), multiple = TRUE) From e240e9434cdaf10c44a44b9237dbf8b421fc81d4 Mon Sep 17 00:00:00 2001 From: JeffreyCHoover Date: Mon, 13 Jan 2025 12:33:15 -0500 Subject: [PATCH 02/13] adding aic & bic --- R/model-evaluation.R | 13 +++++++++++-- 1 file changed, 11 insertions(+), 2 deletions(-) diff --git a/R/model-evaluation.R b/R/model-evaluation.R index 4e90594..2152dfa 100644 --- a/R/model-evaluation.R +++ b/R/model-evaluation.R @@ -131,7 +131,8 @@ add_criterion <- function(x, criterion = c("loo", "waic", "aic", "bic"), "available for models estimated with ", "`method = \"optim\"`.")) } - criterion <- rlang::arg_match(criterion, values = c("loo", "waic"), + criterion <- rlang::arg_match(criterion, + values = c("loo", "waic", "aic", "bic"), multiple = TRUE) overwrite <- check_logical(overwrite, name = "overwrite") save <- check_logical(save, name = "force_save") @@ -145,8 +146,10 @@ add_criterion <- function(x, criterion = c("loo", "waic", "aic", "bic"), } all_criteria <- c(new_criteria, redo_criteria) - if (length(all_criteria) > 0) { + if (length(all_criteria) > 0 & (model$method == "mcmc")) { log_lik_array <- loglik_array(model) + } else if (length(all_criteria) > 0 & (model$method == "optim")) { + log_lik_array <- model$model$value } if ("loo" %in% all_criteria) { @@ -155,6 +158,12 @@ add_criterion <- function(x, criterion = c("loo", "waic", "aic", "bic"), if ("waic" %in% all_criteria) { model$criteria$waic <- waic(log_lik_array) } + if ("aic" %in% all_criteria) { + model$criteria$aic <- aic(log_lik_array) + } + if ("=bic" %in% all_criteria) { + model$criteria$bic <- bic(log_lik_array) + } # re-save model object (if applicable) if (!is.null(model$file) && length(all_criteria) > 0 && save) { From 666100419f1f85c6b083acb647b5a87c38eb7961 Mon Sep 17 00:00:00 2001 From: JeffreyCHoover Date: Mon, 13 Jan 2025 12:48:20 -0500 Subject: [PATCH 03/13] aic & bic util functions --- R/utils-model-evaluation.R | 34 ++++++++++++++++++++++++++++++++++ 1 file changed, 34 insertions(+) diff --git a/R/utils-model-evaluation.R b/R/utils-model-evaluation.R index 9944ca1..95bcaec 100644 --- a/R/utils-model-evaluation.R +++ b/R/utils-model-evaluation.R @@ -62,3 +62,37 @@ add_ppmc <- function(model, run_ppmc) { return(model) } + +aic <- function(model) { + logLik <- model$model$value + + num_params <- model$model$par %>% + tibble::as_tibble() %>% + dplyr::mutate(param = names(model$model$par)) %>% + dplyr::filter(!stringr::str_detect(param, "pi")) %>% + dplyr::filter(!stringr::str_detect(param, "log_Vc")) %>% + nrow() - 1 + + aic <- (-2 * logLik) + (2 * num_params) + + return(aic) +} + +bic <- function(model) { + logLik <- model$model$value + + num_params <- model$model$par %>% + tibble::as_tibble() %>% + dplyr::mutate(param = names(model$model$par)) %>% + dplyr::filter(!stringr::str_detect(param, "pi")) %>% + dplyr::filter(!stringr::str_detect(param, "log_Vc")) %>% + nrow() - 1 + + N <- model$data$data %>% + dplyr::distinct(resp_id) %>% + nrow() + + bic <- (-2 * logLik) + (log(N) * num_params) + + return(bic) +} From 934712a64ff113a879336f1dcfb5e5b43284cdcb Mon Sep 17 00:00:00 2001 From: JeffreyCHoover Date: Mon, 13 Jan 2025 12:49:45 -0500 Subject: [PATCH 04/13] passing model into aic & bic functions --- R/model-evaluation.R | 8 +++----- 1 file changed, 3 insertions(+), 5 deletions(-) diff --git a/R/model-evaluation.R b/R/model-evaluation.R index 2152dfa..0019751 100644 --- a/R/model-evaluation.R +++ b/R/model-evaluation.R @@ -148,8 +148,6 @@ add_criterion <- function(x, criterion = c("loo", "waic", "aic", "bic"), if (length(all_criteria) > 0 & (model$method == "mcmc")) { log_lik_array <- loglik_array(model) - } else if (length(all_criteria) > 0 & (model$method == "optim")) { - log_lik_array <- model$model$value } if ("loo" %in% all_criteria) { @@ -159,10 +157,10 @@ add_criterion <- function(x, criterion = c("loo", "waic", "aic", "bic"), model$criteria$waic <- waic(log_lik_array) } if ("aic" %in% all_criteria) { - model$criteria$aic <- aic(log_lik_array) + model$criteria$aic <- aic(model) } - if ("=bic" %in% all_criteria) { - model$criteria$bic <- bic(log_lik_array) + if ("bic" %in% all_criteria) { + model$criteria$bic <- bic(model) } # re-save model object (if applicable) From 1913a6143dadf591e19d6dd24ea4f57385ee18cf Mon Sep 17 00:00:00 2001 From: JeffreyCHoover Date: Mon, 13 Jan 2025 14:37:17 -0500 Subject: [PATCH 05/13] update unit tests for updated criterion error messages --- tests/testthat/test-ecpe.R | 21 ++++++++++++++++++++- 1 file changed, 20 insertions(+), 1 deletion(-) diff --git a/tests/testthat/test-ecpe.R b/tests/testthat/test-ecpe.R index 43a1514..98d6430 100644 --- a/tests/testthat/test-ecpe.R +++ b/tests/testthat/test-ecpe.R @@ -289,7 +289,26 @@ test_that("mcmc requirements error", { expect_s3_class(err, "error_bad_method") expect_match(err$message, "`method = \"mcmc\"`") - err <- rlang::catch_cnd(add_criterion(cmds_ecpe_lcdm)) + err <- rlang::catch_cnd(add_criterion(cmds_ecpe_lcdm, "loo")) expect_s3_class(err, "error_bad_method") expect_match(err$message, "`method = \"mcmc\"`") + + err <- rlang::catch_cnd(add_criterion(cmds_ecpe_lcdm, "waic")) + expect_s3_class(err, "error_bad_method") + expect_match(err$message, "`method = \"mcmc\"`") +}) + +test_that("optim requirements error", { + skip_on_cran() + + mcmc_mod <- cmds_ecpe_lcdm + mcmc_mod$method <- "mcmc" + + err <- rlang::catch_cnd(add_criterion(mcmc_mod, "aic")) + expect_s3_class(err, "error_bad_method") + expect_match(err$message, "`method = \"optim\"`") + + err <- rlang::catch_cnd(add_criterion(mcmc_mod, "bic")) + expect_s3_class(err, "error_bad_method") + expect_match(err$message, "`method = \"optim\"`") }) From 131415c2527fa41797243bec5c04b9eac7e431f9 Mon Sep 17 00:00:00 2001 From: JeffreyCHoover Date: Mon, 13 Jan 2025 14:37:27 -0500 Subject: [PATCH 06/13] roxygenize --- man/model_evaluation.Rd | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/man/model_evaluation.Rd b/man/model_evaluation.Rd index 73860d0..f5b6e93 100644 --- a/man/model_evaluation.Rd +++ b/man/model_evaluation.Rd @@ -10,7 +10,7 @@ \usage{ add_criterion( x, - criterion = c("loo", "waic"), + criterion = c("loo", "waic", "aic", "bic"), overwrite = FALSE, save = TRUE, ..., From 8cb8c5cc3a7f89fb5f2600bd199e322b8151cb1a Mon Sep 17 00:00:00 2001 From: JeffreyCHoover Date: Mon, 13 Jan 2025 14:37:36 -0500 Subject: [PATCH 07/13] anchor variables --- R/utils-model-evaluation.R | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/R/utils-model-evaluation.R b/R/utils-model-evaluation.R index 95bcaec..89c0d11 100644 --- a/R/utils-model-evaluation.R +++ b/R/utils-model-evaluation.R @@ -69,8 +69,8 @@ aic <- function(model) { num_params <- model$model$par %>% tibble::as_tibble() %>% dplyr::mutate(param = names(model$model$par)) %>% - dplyr::filter(!stringr::str_detect(param, "pi")) %>% - dplyr::filter(!stringr::str_detect(param, "log_Vc")) %>% + dplyr::filter(!stringr::str_detect(.data$param, "pi")) %>% + dplyr::filter(!stringr::str_detect(.data$param, "log_Vc")) %>% nrow() - 1 aic <- (-2 * logLik) + (2 * num_params) @@ -84,12 +84,12 @@ bic <- function(model) { num_params <- model$model$par %>% tibble::as_tibble() %>% dplyr::mutate(param = names(model$model$par)) %>% - dplyr::filter(!stringr::str_detect(param, "pi")) %>% - dplyr::filter(!stringr::str_detect(param, "log_Vc")) %>% + dplyr::filter(!stringr::str_detect(.data$param, "pi")) %>% + dplyr::filter(!stringr::str_detect(.data$param, "log_Vc")) %>% nrow() - 1 N <- model$data$data %>% - dplyr::distinct(resp_id) %>% + dplyr::distinct(.data$resp_id) %>% nrow() bic <- (-2 * logLik) + (log(N) * num_params) From 330067d9b97de1b7f8904c99fdc100f781f04bdd Mon Sep 17 00:00:00 2001 From: JeffreyCHoover Date: Mon, 13 Jan 2025 14:37:46 -0500 Subject: [PATCH 08/13] add stringr as dependency --- DESCRIPTION | 1 + 1 file changed, 1 insertion(+) diff --git a/DESCRIPTION b/DESCRIPTION index 04cbe94..1171ba2 100644 --- a/DESCRIPTION +++ b/DESCRIPTION @@ -44,6 +44,7 @@ Imports: rstan (>= 2.26.0), rstantools (>= 2.3.0), stats, + stringr, tibble, tidyr (>= 1.3.0) LinkingTo: From ebe5fe26114c854b520781030d1002414e9db27e Mon Sep 17 00:00:00 2001 From: JeffreyCHoover Date: Mon, 13 Jan 2025 15:11:19 -0500 Subject: [PATCH 09/13] roxygenize --- man/model_evaluation.Rd | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/man/model_evaluation.Rd b/man/model_evaluation.Rd index f5b6e93..73860d0 100644 --- a/man/model_evaluation.Rd +++ b/man/model_evaluation.Rd @@ -10,7 +10,7 @@ \usage{ add_criterion( x, - criterion = c("loo", "waic", "aic", "bic"), + criterion = c("loo", "waic"), overwrite = FALSE, save = TRUE, ..., From f7ac91b719f763ee3a2fa6af80a44717f488998e Mon Sep 17 00:00:00 2001 From: JeffreyCHoover Date: Mon, 13 Jan 2025 15:11:37 -0500 Subject: [PATCH 10/13] update defaults; test for any violation w/ err message --- R/model-evaluation.R | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/R/model-evaluation.R b/R/model-evaluation.R index 0019751..4236ddf 100644 --- a/R/model-evaluation.R +++ b/R/model-evaluation.R @@ -117,15 +117,15 @@ NULL #' @export #' @rdname model_evaluation -add_criterion <- function(x, criterion = c("loo", "waic", "aic", "bic"), +add_criterion <- function(x, criterion = c("loo", "waic"), overwrite = FALSE, save = TRUE, ..., r_eff = NA) { model <- check_model(x, required_class = "measrfit", name = "x") - if (model$method != "mcmc" & criterion %in% c("loo", "waic")) { + if (any(model$method != "mcmc" & criterion %in% c("loo", "waic"))) { rlang::abort("error_bad_method", message = glue::glue("LOO and WAIC model criteria are only ", "available for models estimated with ", "`method = \"mcmc\"`.")) - } else if (model$method != "optim" & criterion %in% c("aic", "bic")) { + } else if (any(model$method != "optim" & criterion %in% c("aic", "bic"))) { rlang::abort("error_bad_method", message = glue::glue("AIC and BIC model criteria are only ", "available for models estimated with ", From b7d666cc9ba8fb02ec81743f2dc4803db65dda30 Mon Sep 17 00:00:00 2001 From: JeffreyCHoover Date: Mon, 13 Jan 2025 15:11:45 -0500 Subject: [PATCH 11/13] update unit tests --- tests/testthat/test-ecpe.R | 12 ++---------- tests/testthat/test-mcmc.R | 3 +-- 2 files changed, 3 insertions(+), 12 deletions(-) diff --git a/tests/testthat/test-ecpe.R b/tests/testthat/test-ecpe.R index 98d6430..e6cbf38 100644 --- a/tests/testthat/test-ecpe.R +++ b/tests/testthat/test-ecpe.R @@ -289,11 +289,7 @@ test_that("mcmc requirements error", { expect_s3_class(err, "error_bad_method") expect_match(err$message, "`method = \"mcmc\"`") - err <- rlang::catch_cnd(add_criterion(cmds_ecpe_lcdm, "loo")) - expect_s3_class(err, "error_bad_method") - expect_match(err$message, "`method = \"mcmc\"`") - - err <- rlang::catch_cnd(add_criterion(cmds_ecpe_lcdm, "waic")) + err <- rlang::catch_cnd(add_criterion(cmds_ecpe_lcdm)) expect_s3_class(err, "error_bad_method") expect_match(err$message, "`method = \"mcmc\"`") }) @@ -304,11 +300,7 @@ test_that("optim requirements error", { mcmc_mod <- cmds_ecpe_lcdm mcmc_mod$method <- "mcmc" - err <- rlang::catch_cnd(add_criterion(mcmc_mod, "aic")) - expect_s3_class(err, "error_bad_method") - expect_match(err$message, "`method = \"optim\"`") - - err <- rlang::catch_cnd(add_criterion(mcmc_mod, "bic")) + err <- rlang::catch_cnd(add_criterion(mcmc_mod, criterion = c("aic", "bic"))) expect_s3_class(err, "error_bad_method") expect_match(err$message, "`method = \"optim\"`") }) diff --git a/tests/testthat/test-mcmc.R b/tests/testthat/test-mcmc.R index 1d28534..39724c6 100644 --- a/tests/testthat/test-mcmc.R +++ b/tests/testthat/test-mcmc.R @@ -116,8 +116,7 @@ test_that("loo and waic can be added to model", { expect_equal(names(loo_model$criteria), "loo") expect_s3_class(loo_model$criteria$loo, "psis_loo") - lw_model <- add_criterion(loo_model, criterion = c("loo", "waic"), - overwrite = TRUE) + lw_model <- add_criterion(loo_model, overwrite = TRUE) expect_equal(names(lw_model$criteria), c("loo", "waic")) expect_s3_class(lw_model$criteria$loo, "psis_loo") expect_s3_class(lw_model$criteria$waic, "waic") From 540365b00315514873d6cccebb5a2e473cfe1f9d Mon Sep 17 00:00:00 2001 From: JeffreyCHoover Date: Wed, 15 Jan 2025 08:08:10 -0500 Subject: [PATCH 12/13] unit tests for aic & bic --- tests/testthat/test-utils-model-evaluation.R | 20 ++++++++++++++++++++ 1 file changed, 20 insertions(+) create mode 100644 tests/testthat/test-utils-model-evaluation.R diff --git a/tests/testthat/test-utils-model-evaluation.R b/tests/testthat/test-utils-model-evaluation.R new file mode 100644 index 0000000..c93d0ea --- /dev/null +++ b/tests/testthat/test-utils-model-evaluation.R @@ -0,0 +1,20 @@ +test_that("aic works", { + num_params <- 101 + logLik <- -18474.98 + + exp_aic <- (-2 * logLik) + (2 * num_params) + aic_val <- aic(rstn_dino) + + expect_equal(exp_aic, aic_val) +}) + +test_that("bic works", { + num_params <- 101 + N <- 1000 + logLik <- -18474.98 + + exp_bic <- (-2 * logLik) + (log(N) * num_params) + bic_val <- bic(rstn_dino) + + expect_equal(exp_bic, bic_val) +}) From c97c070cde64d0b9d203f522713d592a46e981aa Mon Sep 17 00:00:00 2001 From: JeffreyCHoover Date: Tue, 21 Jan 2025 15:33:24 -0500 Subject: [PATCH 13/13] adding unit tests to improve test coverage --- tests/testthat/test-model-evaluation.R | 40 ++++++++++++++++++++++++++ 1 file changed, 40 insertions(+) create mode 100644 tests/testthat/test-model-evaluation.R diff --git a/tests/testthat/test-model-evaluation.R b/tests/testthat/test-model-evaluation.R new file mode 100644 index 0000000..bb7a54a --- /dev/null +++ b/tests/testthat/test-model-evaluation.R @@ -0,0 +1,40 @@ +test_that("add criterion error messages work", { + err <- rlang::catch_cnd(add_criterion("test")) + expect_s3_class(err, "error_bad_argument") + expect_match(err$message, "must be an object with class measrfit") + + + err <- rlang::catch_cnd(add_criterion(rstn_dino, criterion = "waic")) + expect_s3_class(err, "error_bad_method") + expect_match(err$message, "LOO and WAIC model criteria are only available") + + err <- rlang::catch_cnd(add_criterion(rstn_dino, criterion = "waic")) + expect_s3_class(err, "error_bad_method") + expect_match(err$message, "LOO and WAIC model criteria are only available") + + test_dino <- rstn_dino + test_dino$method <- "mcmc" + err <- rlang::catch_cnd(add_criterion(test_dino, criterion = "aic")) + expect_s3_class(err, "error_bad_method") + expect_match(err$message, "AIC and BIC model criteria are only available") + + err <- rlang::catch_cnd(add_criterion(test_dino, criterion = "bic")) + expect_s3_class(err, "error_bad_method") + expect_match(err$message, "AIC and BIC model criteria are only available") +}) + +test_that("AIC works", { + rstn_dino <- add_criterion(rstn_dino, criterion = "aic") + + exp_aic <- 37151.96 + + expect_equal(rstn_dino$criteria$aic, exp_aic, tolerance = .0001) +}) + +test_that("BIC works", { + rstn_dino <- add_criterion(rstn_dino, criterion = "bic") + + exp_bic <- 37647.64 + + expect_equal(rstn_dino$criteria$bic, exp_bic, tolerance = .0001) +})