From 572d3de0d2bc3357ca6cb82b26c8ac4eae03d686 Mon Sep 17 00:00:00 2001 From: Sebastian Fischer Date: Thu, 18 Jul 2024 13:32:18 +0200 Subject: [PATCH] feat: autoplot for confidence intervals --- DESCRIPTION | 6 ++-- NEWS.md | 1 + R/BenchmarkResult.R | 38 ++++++++++++++++++++++++ man/autoplot.BenchmarkResult.Rd | 1 + man/autoplot.EnsembleFSResult.Rd | 4 +-- man/autoplot.LearnerClustHierarchical.Rd | 2 +- man/autoplot.LearnerSurvCoxPH.Rd | 4 +-- man/mlr3viz-package.Rd | 4 +-- tests/testthat/test_BenchmarkResult.R | 16 ++++++++++ 9 files changed, 67 insertions(+), 9 deletions(-) diff --git a/DESCRIPTION b/DESCRIPTION index 10b02e33..8b59d9ad 100644 --- a/DESCRIPTION +++ b/DESCRIPTION @@ -49,6 +49,7 @@ Suggests: mlr3cluster, mlr3filters, mlr3fselect (>= 1.0.0), + mlr3inference, mlr3learners, mlr3tuning (>= 1.0.0), paradox, @@ -64,7 +65,8 @@ Suggests: survminer, mlr3proba (>= 0.6.3) Remotes: - mlr-org/mlr3proba + mlr-org/mlr3proba, + mlr-org/mlr3inference Additional_repositories: https://mlr-org.r-universe.dev Config/testthat/edition: 3 @@ -72,7 +74,7 @@ Config/testthat/parallel: true Encoding: UTF-8 NeedsCompilation: no Roxygen: list(markdown = TRUE) -RoxygenNote: 7.3.1 +RoxygenNote: 7.3.2 Collate: 'BenchmarkResult.R' 'Filter.R' diff --git a/NEWS.md b/NEWS.md index 2e2d754e..c01caeb4 100644 --- a/NEWS.md +++ b/NEWS.md @@ -1,6 +1,7 @@ # mlr3viz (development version) - Add plot for `LearnerSurvCoxPH`. +- Add plot for confidence intervals (`mlr3inference`) # mlr3viz 0.9.0 diff --git a/R/BenchmarkResult.R b/R/BenchmarkResult.R index 944eb47a..27ad1e9c 100644 --- a/R/BenchmarkResult.R +++ b/R/BenchmarkResult.R @@ -12,6 +12,7 @@ #' Requires package \CRANpkg{precrec}. #' * `"prc"`: Precision recall curve. #' See `"roc"`. +#' * `"ci"`: Plot confidence intervals. Pass a `msr("ci", ...)` from the `mlr3inference` package as argument `measure`. #' #' @param object ([mlr3::BenchmarkResult]). #' @template param_type @@ -45,6 +46,43 @@ autoplot.BenchmarkResult = function(object, type = "boxplot", measure = NULL, th task = object$tasks$task[[1L]] measure = mlr3::assert_measure(mlr3::as_measure(measure, task_type = task$task_type), task = task) + + if (identical(type, "ci")) { + mlr3misc::require_namespaces("mlr3inference") + + assert_class(measure, "MeasureAbstractCi") + mid = measure$id + + tbl = object$aggregate(measure) + + tmp = map(object$resamplings$resampling, function(x) { + list(class(x), x$param_set$values) + }) + + if (length(unique(tmp)) != 1) { + stopf("Plot of type 'ci' requires exactly one resampling method") + } + + # static checker + .data = NULL + task_id = NULL + p = ggplot(tbl, aes(x = .data[["learner_id"]], y = .data[[mid]])) + + geom_point() + + geom_errorbar(aes(ymin = .data[[paste0(mid, ".lower")]], ymax = .data[[paste0(mid, ".upper")]]), width = 0.2) + + facet_wrap(vars(task_id), scales = "free_y") + + labs( + title = sprintf("Confidence Intervals for alpha = %s", measure$param_set$values$alpha), + x = "Learner", + y = paste0(measure$measure$id) + ) + + theme + + theme( + axis.text.x = element_text(angle = 45, hjust = 1), + axis.title.x = element_blank() + ) + return(p) + } + measure_id = measure$id tab = fortify(object, measure = measure) tab$nr = sprintf("%09d", tab$nr) diff --git a/man/autoplot.BenchmarkResult.Rd b/man/autoplot.BenchmarkResult.Rd index 028022b2..7e5776b2 100644 --- a/man/autoplot.BenchmarkResult.Rd +++ b/man/autoplot.BenchmarkResult.Rd @@ -41,6 +41,7 @@ Note that you can subset any \link[mlr3:BenchmarkResult]{mlr3::BenchmarkResult} Requires package \CRANpkg{precrec}. \item \code{"prc"}: Precision recall curve. See \code{"roc"}. +\item \code{"ci"}: Plot confidence intervals. Pass a \code{msr("ci", ...)} from the \code{mlr3inference} package as argument \code{measure}. } } \examples{ diff --git a/man/autoplot.EnsembleFSResult.Rd b/man/autoplot.EnsembleFSResult.Rd index c42360bf..7620031b 100644 --- a/man/autoplot.EnsembleFSResult.Rd +++ b/man/autoplot.EnsembleFSResult.Rd @@ -15,7 +15,7 @@ ) } \arguments{ -\item{object}{(\link[mlr3fselect:ensemble_fs_result]{mlr3fselect::EnsembleFSResult}).} +\item{object}{(\link[mlr3fselect:EnsembleFSResult]{mlr3fselect::EnsembleFSResult}).} \item{type}{(character(1)):\cr Type of the plot. See description.} @@ -41,7 +41,7 @@ The \code{\link[ggplot2:ggtheme]{ggplot2::theme_minimal()}} is applied by defaul \code{\link[ggplot2:ggplot]{ggplot2::ggplot()}}. } \description{ -Visualizations for \link[mlr3fselect:ensemble_fs_result]{EnsembleFSResult}. +Visualizations for \link[mlr3fselect:EnsembleFSResult]{EnsembleFSResult}. The argument \code{type} determines the type of plot generated. The available options are: \itemize{ diff --git a/man/autoplot.LearnerClustHierarchical.Rd b/man/autoplot.LearnerClustHierarchical.Rd index 0e6b885b..6fe617b2 100644 --- a/man/autoplot.LearnerClustHierarchical.Rd +++ b/man/autoplot.LearnerClustHierarchical.Rd @@ -14,7 +14,7 @@ ) } \arguments{ -\item{object}{(\link[mlr3cluster:mlr_learners_clust.agnes]{mlr3cluster::LearnerClustAgnes} | \link[mlr3cluster:mlr_learners_clust.diana]{mlr3cluster::LearnerClustDiana} | \link[mlr3cluster:mlr_learners_clust.hclust]{mlr3cluster::LearnerClustHclust}).} +\item{object}{(\link[mlr3cluster:LearnerClustAgnes]{mlr3cluster::LearnerClustAgnes} | \link[mlr3cluster:LearnerClustDiana]{mlr3cluster::LearnerClustDiana} | \link[mlr3cluster:LearnerClustHclust]{mlr3cluster::LearnerClustHclust}).} \item{type}{(character(1)):\cr Type of the plot. See description.} diff --git a/man/autoplot.LearnerSurvCoxPH.Rd b/man/autoplot.LearnerSurvCoxPH.Rd index 2b642ec7..cf1f4c83 100644 --- a/man/autoplot.LearnerSurvCoxPH.Rd +++ b/man/autoplot.LearnerSurvCoxPH.Rd @@ -7,7 +7,7 @@ \method{autoplot}{LearnerSurvCoxPH}(object, type = "ggforest", ...) } \arguments{ -\item{object}{(\link[mlr3proba:mlr_learners_surv.coxph]{mlr3proba::LearnerSurvCoxPH}).} +\item{object}{(\link[mlr3proba:LearnerSurvCoxPH]{mlr3proba::LearnerSurvCoxPH}).} \item{type}{(character(1)):\cr Type of the plot. See description.} @@ -18,7 +18,7 @@ Type of the plot. See description.} \code{\link[ggplot2:ggplot]{ggplot2::ggplot()}}. } \description{ -Visualizations for \link[mlr3proba:mlr_learners_surv.coxph]{mlr3proba::LearnerSurvCoxPH}. +Visualizations for \link[mlr3proba:LearnerSurvCoxPH]{mlr3proba::LearnerSurvCoxPH}. The argument \code{type} controls what kind of plot is drawn. The only possible choice right now is \code{"ggforest"} (default) which is a diff --git a/man/mlr3viz-package.Rd b/man/mlr3viz-package.Rd index 51f8644e..032e390b 100644 --- a/man/mlr3viz-package.Rd +++ b/man/mlr3viz-package.Rd @@ -20,13 +20,13 @@ Useful links: } \author{ -\strong{Maintainer}: Michel Lang \email{michellang@gmail.com} (\href{https://orcid.org/0000-0001-9754-0393}{ORCID}) +\strong{Maintainer}: Marc Becker \email{marcbecker@posteo.de} (\href{https://orcid.org/0000-0002-8115-0400}{ORCID}) Authors: \itemize{ + \item Michel Lang \email{michellang@gmail.com} (\href{https://orcid.org/0000-0001-9754-0393}{ORCID}) \item Patrick Schratz \email{patrick.schratz@gmail.com} (\href{https://orcid.org/0000-0003-0748-6624}{ORCID}) \item Raphael Sonabend \email{raphael.sonabend.15@ucl.ac.uk} (\href{https://orcid.org/0000-0001-9225-4654}{ORCID}) - \item Marc Becker \email{marcbecker@posteo.de} (\href{https://orcid.org/0000-0002-8115-0400}{ORCID}) \item Jakob Richter \email{jakob1richter@gmail.com} (\href{https://orcid.org/0000-0003-4481-5554}{ORCID}) \item John Zobolas \email{bblodfon@gmail.com} (\href{https://orcid.org/0000-0002-3609-8674}{ORCID}) } diff --git a/tests/testthat/test_BenchmarkResult.R b/tests/testthat/test_BenchmarkResult.R index 47d29a27..d80f0442 100644 --- a/tests/testthat/test_BenchmarkResult.R +++ b/tests/testthat/test_BenchmarkResult.R @@ -48,3 +48,19 @@ test_that("holdout roc plot (#54)", { expect_doppelganger("bmr_holdout_roc", p) }) + +skip_if_not_installed("mlr3inference") +skip_if_not_installed("rpart") + +test_that("CI plot", { + bmr = benchmark(benchmark_grid(tsks(c("mtcars", "boston_housing")), + lrns(c("regr.featureless", "regr.rpart")), rsmp("holdout"))) + + p = autoplot(bmr, "ci", msr("ci", "regr.mse")) + expect_true(is.ggplot(p)) + expect_doppelganger("bmr_holdout_ci", p) + + bmr = benchmark(benchmark_grid(tsk("iris"), lrn("classif.rpart"), + rsmps(c("holdout", "cv")))) + expect_error(autoplot(bmr, "ci", msr("ci", "classif.acc")), "one resampling method") +})