From d876a9f302be56511532c05609086a61c9d8011f Mon Sep 17 00:00:00 2001 From: John Zobolas Date: Fri, 15 Nov 2024 11:24:36 +0100 Subject: [PATCH] `cb_show` + `type` refactor (#157) * allow show_cb argument to be passed on to autoplot.precrec * refactor: assert_choice for type of plots * update docs * add one test for getting hint of available typet from an autoplot * update news --- NEWS.md | 4 ++++ R/BenchmarkResult.R | 9 +++++---- R/EnsembleFSResult.R | 2 +- R/Filter.R | 2 +- R/LearnerClassif.R | 2 +- R/LearnerClassifCVGlmnet.R | 2 ++ R/LearnerClassifGlmnet.R | 1 + R/LearnerClassifRpart.R | 1 + R/LearnerClustHierarchical.R | 2 +- R/LearnerRegr.R | 2 +- R/LearnerRegrCVGlmnet.R | 1 + R/LearnerRegrRpart.R | 1 + R/LearnerSurvCoxPH.R | 1 + R/OptimInstanceBatchSingleCrit.R | 2 ++ R/PredictionClassif.R | 2 +- R/PredictionClust.R | 2 +- R/PredictionRegr.R | 2 +- R/ResampleResult.R | 9 +++++---- R/TaskClassif.R | 2 +- R/TaskClust.R | 2 +- R/TaskRegr.R | 2 +- man/autoplot.BenchmarkResult.Rd | 3 ++- man/autoplot.ResampleResult.Rd | 5 +++-- tests/testthat/test_EnsembleFSResult.R | 4 ++++ 24 files changed, 43 insertions(+), 22 deletions(-) diff --git a/NEWS.md b/NEWS.md index b08f1e5a..e2cd4525 100644 --- a/NEWS.md +++ b/NEWS.md @@ -1,5 +1,9 @@ # mlr3viz (development version) +- Allow passing parameters to `precrec::autoplot()` (eg `show_cb`) when plotting +`BenchmarkResult` and `ResampleResult` objects, using `type` = `roc` or `prc`. +- Refactor: wrong `type` in `autoplot`s now gives hints of which ones to use. + # mlr3viz 0.10.0 - Add plot for `LearnerSurvCoxPH`. diff --git a/R/BenchmarkResult.R b/R/BenchmarkResult.R index 798fe0cc..20b3a448 100644 --- a/R/BenchmarkResult.R +++ b/R/BenchmarkResult.R @@ -17,7 +17,8 @@ #' @template param_type #' @template param_measure #' @template param_theme -#' @param ... (ignored). +#' @param ... arguments passed on to [precrec::autoplot()] for `type = "roc"` or `"prc"`. +#' Useful to e.g. remove confidence bands with `show_cb = FALSE`. #' #' @return [ggplot2::ggplot()]. #' @@ -41,7 +42,7 @@ #' autoplot(object$clone(deep = TRUE)$filter(task_ids = "pima"), type = "roc") #' } autoplot.BenchmarkResult = function(object, type = "boxplot", measure = NULL, theme = theme_minimal(), ...) { - assert_string(type) + assert_choice(type, choices = c("boxplot", "roc", "prc"), null.ok = FALSE) task = object$tasks$task[[1L]] measure = mlr3::assert_measure(mlr3::as_measure(measure, task_type = task$task_type), task = task) @@ -74,7 +75,7 @@ autoplot.BenchmarkResult = function(object, type = "boxplot", measure = NULL, th }, "roc" = { - p = plot_precrec(object, curvetype = "ROC") + p = plot_precrec(object, curvetype = "ROC", ...) p$layers[[1]]$mapping = aes(color = modname, fill = modname) # fill confidence bounds p + @@ -84,7 +85,7 @@ autoplot.BenchmarkResult = function(object, type = "boxplot", measure = NULL, th }, "prc" = { - p = plot_precrec(object, curvetype = "PRC") + p = plot_precrec(object, curvetype = "PRC", ...) # fill confidence bounds p$layers[[1]]$mapping = aes(color = modname, fill = modname) p + diff --git a/R/EnsembleFSResult.R b/R/EnsembleFSResult.R index 9d4daf9c..77d0ea4f 100644 --- a/R/EnsembleFSResult.R +++ b/R/EnsembleFSResult.R @@ -77,7 +77,7 @@ autoplot.EnsembleFSResult = function( theme = theme_minimal(), ... ) { - assert_string(type) + assert_choice(type, choices = c("pareto", "performance", "n_features", "stability"), null.ok = FALSE) assert_choice(pareto_front, choices = c("stepwise", "estimated", "none")) result = object$result measure_id = object$measure diff --git a/R/Filter.R b/R/Filter.R index ac1625ba..682b6a63 100644 --- a/R/Filter.R +++ b/R/Filter.R @@ -31,7 +31,7 @@ #' autoplot(f, n = 5) #' } autoplot.Filter = function(object, type = "boxplot", n = Inf, theme = theme_minimal(), ...) { # nolint - assert_string(type) + assert_choice(type, choices = c("boxplot"), null.ok = FALSE) data = head(fortify(object), n) diff --git a/R/LearnerClassif.R b/R/LearnerClassif.R index d720a398..797427a4 100644 --- a/R/LearnerClassif.R +++ b/R/LearnerClassif.R @@ -33,7 +33,7 @@ #' } #' } autoplot.LearnerClassif = function(object, type = "prediction", task, grid_points = 100L, expand_range = 0, theme = theme_minimal(), ...) { # nolint - assert_string(type) + assert_choice(type, choices = c("prediction"), null.ok = FALSE) switch(type, "prediction" = { diff --git a/R/LearnerClassifCVGlmnet.R b/R/LearnerClassifCVGlmnet.R index d0dc7733..c95f6772 100644 --- a/R/LearnerClassifCVGlmnet.R +++ b/R/LearnerClassifCVGlmnet.R @@ -1,6 +1,8 @@ #' @rdname autoplot.LearnerClassifGlmnet #' @export autoplot.LearnerClassifCVGlmnet = function(object, type = "prediction", task = NULL, grid_points = 100L, expand_range = 0, theme = theme_minimal(), ...) { # nolint + assert_choice(type, choices = c("prediction", "ggfortify"), null.ok = FALSE) + switch(type, "prediction" = { NextMethod() diff --git a/R/LearnerClassifGlmnet.R b/R/LearnerClassifGlmnet.R index 6573f47c..2ea040e9 100644 --- a/R/LearnerClassifGlmnet.R +++ b/R/LearnerClassifGlmnet.R @@ -42,6 +42,7 @@ #' autoplot(learner, type = "ggfortify") #' } autoplot.LearnerClassifGlmnet = function(object, type = "prediction", task = NULL, grid_points = 100L, expand_range = 0, theme = theme_minimal(), ...) { # nolint + assert_choice(type, choices = c("prediction", "ggfortify"), null.ok = FALSE) assert_has_model(object) switch(type, diff --git a/R/LearnerClassifRpart.R b/R/LearnerClassifRpart.R index 1205ec4a..6c09359a 100644 --- a/R/LearnerClassifRpart.R +++ b/R/LearnerClassifRpart.R @@ -38,6 +38,7 @@ #' autoplot(learner, type = "ggparty") #' } autoplot.LearnerClassifRpart = function(object, type = "prediction", task = NULL, grid_points = 100L, expand_range = 0, theme = theme_minimal(), ...) { # nolint + assert_choice(type, choices = c("prediction", "ggparty"), null.ok = FALSE) assert_has_model(object) switch(type, diff --git a/R/LearnerClustHierarchical.R b/R/LearnerClustHierarchical.R index e84115f2..40aa5b04 100644 --- a/R/LearnerClustHierarchical.R +++ b/R/LearnerClustHierarchical.R @@ -46,7 +46,7 @@ #' autoplot(learner, type = "scree") #' } autoplot.LearnerClustHierarchical = function(object, type = "dend", task = NULL, theme = theme_minimal(), theme_dendro = TRUE, ...) { # nolint - assert_string(type) + assert_choice(type, choices = c("dend", "scree"), null.ok = FALSE) if (is.null(object$model)) { stopf("Learner '%s' must be trained first", object$id) diff --git a/R/LearnerRegr.R b/R/LearnerRegr.R index 7f55bc03..32b51869 100644 --- a/R/LearnerRegr.R +++ b/R/LearnerRegr.R @@ -33,7 +33,7 @@ #' } #' } autoplot.LearnerRegr = function(object, type = "prediction", task, grid_points = 100L, expand_range = 0, theme = theme_minimal(), ...) { # nolint - assert_string(type) + assert_choice(type, choices = c("prediction"), null.ok = FALSE) switch(type, "prediction" = { diff --git a/R/LearnerRegrCVGlmnet.R b/R/LearnerRegrCVGlmnet.R index 104d7276..0400d74e 100644 --- a/R/LearnerRegrCVGlmnet.R +++ b/R/LearnerRegrCVGlmnet.R @@ -1,6 +1,7 @@ #' @rdname autoplot.LearnerClassifGlmnet #' @export autoplot.LearnerRegrCVGlmnet = function(object, type = "prediction", task = NULL, grid_points = 100L, expand_range = 0, theme = theme_minimal(), ...) { # nolint + assert_choice(type, choices = c("prediction", "ggfortify"), null.ok = FALSE) switch(type, "prediction" = { NextMethod() diff --git a/R/LearnerRegrRpart.R b/R/LearnerRegrRpart.R index a6ef8123..f3a7b691 100644 --- a/R/LearnerRegrRpart.R +++ b/R/LearnerRegrRpart.R @@ -1,6 +1,7 @@ #' @export #' @rdname autoplot.LearnerClassifRpart autoplot.LearnerRegrRpart = function(object, type = "prediction", task = NULL, grid_points = 100L, expand_range = 0, theme = theme_minimal(), ...) { # nolint + assert_choice(type, choices = c("prediction", "ggparty"), null.ok = FALSE) assert_has_model(object) switch(type, diff --git a/R/LearnerSurvCoxPH.R b/R/LearnerSurvCoxPH.R index 4f9a57b7..222b53a4 100644 --- a/R/LearnerSurvCoxPH.R +++ b/R/LearnerSurvCoxPH.R @@ -30,6 +30,7 @@ #' } #' } autoplot.LearnerSurvCoxPH = function(object, type = "ggforest", ...) { + assert_choice(type, choices = c("ggforest"), null.ok = FALSE) assert_class(object, classes = "LearnerSurvCoxPH", null.ok = FALSE) assert_has_model(object) diff --git a/R/OptimInstanceBatchSingleCrit.R b/R/OptimInstanceBatchSingleCrit.R index 19e452ac..8ce7b594 100644 --- a/R/OptimInstanceBatchSingleCrit.R +++ b/R/OptimInstanceBatchSingleCrit.R @@ -85,6 +85,8 @@ #' print(autoplot(instance, type = "incumbent")) #' } autoplot.OptimInstanceBatchSingleCrit = function(object, type = "marginal", cols_x = NULL, trafo = FALSE, learner = mlr3::lrn("regr.ranger"), grid_resolution = 100, batch = NULL, theme = theme_minimal(), ...) { # nolint + assert_choice(type, choices = c("marginal", "performance", "parameter", "parallel", + "points", "surface", "pairs", "incumbent"), null.ok = FALSE) assert_subset(cols_x, c(object$archive$cols_x, paste0("x_domain_", object$archive$cols_x))) assert_flag(trafo) diff --git a/R/PredictionClassif.R b/R/PredictionClassif.R index 69d03f77..621b1430 100644 --- a/R/PredictionClassif.R +++ b/R/PredictionClassif.R @@ -41,7 +41,7 @@ #' } #' } autoplot.PredictionClassif = function(object, type = "stacked", measure = NULL, theme = theme_minimal(), ...) { # nolint - assert_string(type) + assert_choice(type, choices = c("stacked", "roc", "prc", "threshold"), null.ok = FALSE) switch(type, "stacked" = { diff --git a/R/PredictionClust.R b/R/PredictionClust.R index c5211978..6ebb76b1 100644 --- a/R/PredictionClust.R +++ b/R/PredictionClust.R @@ -39,7 +39,7 @@ #' autoplot(object, task) #' } autoplot.PredictionClust = function(object, task, row_ids = NULL, type = "scatter", theme = theme_minimal(), ...) { # nolint - assert_string(type) + assert_choice(type, choices = c("scatter", "sil", "pca"), null.ok = FALSE) switch(type, "scatter" = { diff --git a/R/PredictionRegr.R b/R/PredictionRegr.R index a0e589ca..73212166 100644 --- a/R/PredictionRegr.R +++ b/R/PredictionRegr.R @@ -50,7 +50,7 @@ #' } #' } autoplot.PredictionRegr = function(object, type = "xy", binwidth = NULL, theme = theme_minimal(), quantile = 1.96, ...) { - checkmate::assert_string(type) + assert_choice(type, choices = c("xy", "histogram", "residual", "confidence"), null.ok = FALSE) switch(type, "xy" = { diff --git a/R/ResampleResult.R b/R/ResampleResult.R index aec4cd9b..3f00253f 100644 --- a/R/ResampleResult.R +++ b/R/ResampleResult.R @@ -13,7 +13,7 @@ #' * `"prc"`: Precision recall curve. #' See `"roc"`. #' * `"prediction"`: Plots the learner prediction for a grid of points. -#' Needs models to be stored. Set `store_models = TRUE` for `[mlr3::resample]`. +#' Needs models to be stored. Set `store_models = TRUE` for [mlr3::resample()]. #' For classification, we support tasks with exactly two features and learners with `predict_type=` set to `"response"` or `"prob"`. #' For regression, we support tasks with one or two features. #' For tasks with one feature we can print confidence bounds if the predict type of the learner was set to `"se"`. @@ -29,7 +29,8 @@ #' @param binwidth (`integer(1)`)\cr #' Width of the bins for the histogram. #' @template param_theme -#' @param ... (ignored). +#' @param ... arguments passed on to [precrec::autoplot()] for `type = "roc"` or `"prc"`. +#' Useful to e.g. remove confidence bands with `show_cb = FALSE`. #' #' @return [ggplot2::ggplot()]. #' @@ -73,7 +74,7 @@ #' } #' } autoplot.ResampleResult = function(object, type = "boxplot", measure = NULL, predict_sets = "test", binwidth = NULL, theme = theme_minimal(), ...) { - assert_string(type) + assert_choice(type, choices = c("boxplot", "histogram", "prediction", "roc", "prc"), null.ok = FALSE) task = object$task measure = mlr3::assert_measure(mlr3::as_measure(measure, task_type = task$task_type), task = task) @@ -122,7 +123,7 @@ autoplot.ResampleResult = function(object, type = "boxplot", measure = NULL, pre }, "prc" = { - p = plot_precrec(object, curvetype = "PRC") + p = plot_precrec(object, curvetype = "PRC", ...) # fill confidence bounds p$layers[[1]]$mapping = aes(color = modname, fill = modname) p + diff --git a/R/TaskClassif.R b/R/TaskClassif.R index bfe2dcb8..b13e55ac 100644 --- a/R/TaskClassif.R +++ b/R/TaskClassif.R @@ -33,7 +33,7 @@ #' autoplot(task, type = "duo") #' } autoplot.TaskClassif = function(object, type = "target", theme = theme_minimal(), ...) { # nolint - assert_string(type) + assert_choice(type, choices = c("target", "duo", "pairs"), null.ok = FALSE) target = object$target_names diff --git a/R/TaskClust.R b/R/TaskClust.R index e1cb5404..f289a6cf 100644 --- a/R/TaskClust.R +++ b/R/TaskClust.R @@ -27,7 +27,7 @@ #' autoplot(task) #' } autoplot.TaskClust = function(object, type = "pairs", theme = theme_minimal(), ...) { # nolint - assert_string(type) + assert_choice(type, choices = c("pairs"), null.ok = FALSE) switch(type, "pairs" = { diff --git a/R/TaskRegr.R b/R/TaskRegr.R index 07a5e16e..60597b9a 100644 --- a/R/TaskRegr.R +++ b/R/TaskRegr.R @@ -30,7 +30,7 @@ #' autoplot(task, type = "pairs") #' } autoplot.TaskRegr = function(object, type = "target", theme = theme_minimal(), ...) { # nolint - assert_string(type) + assert_choice(type, choices = c("target", "pairs"), null.ok = FALSE) switch(type, "target" = { diff --git a/man/autoplot.BenchmarkResult.Rd b/man/autoplot.BenchmarkResult.Rd index 028022b2..b92cfd51 100644 --- a/man/autoplot.BenchmarkResult.Rd +++ b/man/autoplot.BenchmarkResult.Rd @@ -24,7 +24,8 @@ Performance measure to use.} \item{theme}{(\code{\link[ggplot2:theme]{ggplot2::theme()}})\cr The \code{\link[ggplot2:ggtheme]{ggplot2::theme_minimal()}} is applied by default to all plots.} -\item{...}{(ignored).} +\item{...}{arguments passed on to \code{\link[precrec:autoplot]{precrec::autoplot()}} for \code{type = "roc"} or \code{"prc"}. +Useful to e.g. remove confidence bands with \code{show_cb = FALSE}.} } \value{ \code{\link[ggplot2:ggplot]{ggplot2::ggplot()}}. diff --git a/man/autoplot.ResampleResult.Rd b/man/autoplot.ResampleResult.Rd index 985ac782..99104fe7 100644 --- a/man/autoplot.ResampleResult.Rd +++ b/man/autoplot.ResampleResult.Rd @@ -34,7 +34,8 @@ Width of the bins for the histogram.} \item{theme}{(\code{\link[ggplot2:theme]{ggplot2::theme()}})\cr The \code{\link[ggplot2:ggtheme]{ggplot2::theme_minimal()}} is applied by default to all plots.} -\item{...}{(ignored).} +\item{...}{arguments passed on to \code{\link[precrec:autoplot]{precrec::autoplot()}} for \code{type = "roc"} or \code{"prc"}. +Useful to e.g. remove confidence bands with \code{show_cb = FALSE}.} } \value{ \code{\link[ggplot2:ggplot]{ggplot2::ggplot()}}. @@ -52,7 +53,7 @@ Requires package \CRANpkg{precrec}. \item \code{"prc"}: Precision recall curve. See \code{"roc"}. \item \code{"prediction"}: Plots the learner prediction for a grid of points. -Needs models to be stored. Set \code{store_models = TRUE} for \verb{[mlr3::resample]}. +Needs models to be stored. Set \code{store_models = TRUE} for \code{\link[mlr3:resample]{mlr3::resample()}}. For classification, we support tasks with exactly two features and learners with \verb{predict_type=} set to \code{"response"} or \code{"prob"}. For regression, we support tasks with one or two features. For tasks with one feature we can print confidence bounds if the predict type of the learner was set to \code{"se"}. diff --git a/tests/testthat/test_EnsembleFSResult.R b/tests/testthat/test_EnsembleFSResult.R index 69c3b76d..62d8bb4b 100644 --- a/tests/testthat/test_EnsembleFSResult.R +++ b/tests/testthat/test_EnsembleFSResult.R @@ -18,6 +18,10 @@ test_that("autoplot ResampleResult", { ) efsr = mlr3fselect::EnsembleFSResult$new(result = result, features = paste0("V", 1:20), measure_id = "classif.ce") + + # wrong type gives hint of types a user can input + expect_error(autoplot(efsr, type = "XYZ"), regexp = "Must be element of set") + # pareto (stepwise) p = autoplot(efsr) expect_true(is.ggplot(p))