Skip to content

Commit

Permalink
cb_show + type refactor (#157)
Browse files Browse the repository at this point in the history
* allow show_cb argument to be passed on to autoplot.precrec

* refactor: assert_choice for type of plots

* update docs

* add one test for getting hint of available typet from an autoplot

* update news
  • Loading branch information
bblodfon authored Nov 15, 2024
1 parent 63ee081 commit d876a9f
Show file tree
Hide file tree
Showing 24 changed files with 43 additions and 22 deletions.
4 changes: 4 additions & 0 deletions NEWS.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,9 @@
# mlr3viz (development version)

- Allow passing parameters to `precrec::autoplot()` (eg `show_cb`) when plotting
`BenchmarkResult` and `ResampleResult` objects, using `type` = `roc` or `prc`.
- Refactor: wrong `type` in `autoplot`s now gives hints of which ones to use.

# mlr3viz 0.10.0

- Add plot for `LearnerSurvCoxPH`.
Expand Down
9 changes: 5 additions & 4 deletions R/BenchmarkResult.R
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,8 @@
#' @template param_type
#' @template param_measure
#' @template param_theme
#' @param ... (ignored).
#' @param ... arguments passed on to [precrec::autoplot()] for `type = "roc"` or `"prc"`.
#' Useful to e.g. remove confidence bands with `show_cb = FALSE`.
#'
#' @return [ggplot2::ggplot()].
#'
Expand All @@ -41,7 +42,7 @@
#' autoplot(object$clone(deep = TRUE)$filter(task_ids = "pima"), type = "roc")
#' }
autoplot.BenchmarkResult = function(object, type = "boxplot", measure = NULL, theme = theme_minimal(), ...) {
assert_string(type)
assert_choice(type, choices = c("boxplot", "roc", "prc"), null.ok = FALSE)

task = object$tasks$task[[1L]]
measure = mlr3::assert_measure(mlr3::as_measure(measure, task_type = task$task_type), task = task)
Expand Down Expand Up @@ -74,7 +75,7 @@ autoplot.BenchmarkResult = function(object, type = "boxplot", measure = NULL, th
},

"roc" = {
p = plot_precrec(object, curvetype = "ROC")
p = plot_precrec(object, curvetype = "ROC", ...)
p$layers[[1]]$mapping = aes(color = modname, fill = modname)
# fill confidence bounds
p +
Expand All @@ -84,7 +85,7 @@ autoplot.BenchmarkResult = function(object, type = "boxplot", measure = NULL, th
},

"prc" = {
p = plot_precrec(object, curvetype = "PRC")
p = plot_precrec(object, curvetype = "PRC", ...)
# fill confidence bounds
p$layers[[1]]$mapping = aes(color = modname, fill = modname)
p +
Expand Down
2 changes: 1 addition & 1 deletion R/EnsembleFSResult.R
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ autoplot.EnsembleFSResult = function(
theme = theme_minimal(),
...
) {
assert_string(type)
assert_choice(type, choices = c("pareto", "performance", "n_features", "stability"), null.ok = FALSE)
assert_choice(pareto_front, choices = c("stepwise", "estimated", "none"))
result = object$result
measure_id = object$measure
Expand Down
2 changes: 1 addition & 1 deletion R/Filter.R
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@
#' autoplot(f, n = 5)
#' }
autoplot.Filter = function(object, type = "boxplot", n = Inf, theme = theme_minimal(), ...) { # nolint
assert_string(type)
assert_choice(type, choices = c("boxplot"), null.ok = FALSE)

data = head(fortify(object), n)

Expand Down
2 changes: 1 addition & 1 deletion R/LearnerClassif.R
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@
#' }
#' }
autoplot.LearnerClassif = function(object, type = "prediction", task, grid_points = 100L, expand_range = 0, theme = theme_minimal(), ...) { # nolint
assert_string(type)
assert_choice(type, choices = c("prediction"), null.ok = FALSE)

switch(type,
"prediction" = {
Expand Down
2 changes: 2 additions & 0 deletions R/LearnerClassifCVGlmnet.R
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
#' @rdname autoplot.LearnerClassifGlmnet
#' @export
autoplot.LearnerClassifCVGlmnet = function(object, type = "prediction", task = NULL, grid_points = 100L, expand_range = 0, theme = theme_minimal(), ...) { # nolint
assert_choice(type, choices = c("prediction", "ggfortify"), null.ok = FALSE)

switch(type,
"prediction" = {
NextMethod()
Expand Down
1 change: 1 addition & 0 deletions R/LearnerClassifGlmnet.R
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@
#' autoplot(learner, type = "ggfortify")
#' }
autoplot.LearnerClassifGlmnet = function(object, type = "prediction", task = NULL, grid_points = 100L, expand_range = 0, theme = theme_minimal(), ...) { # nolint
assert_choice(type, choices = c("prediction", "ggfortify"), null.ok = FALSE)
assert_has_model(object)

switch(type,
Expand Down
1 change: 1 addition & 0 deletions R/LearnerClassifRpart.R
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@
#' autoplot(learner, type = "ggparty")
#' }
autoplot.LearnerClassifRpart = function(object, type = "prediction", task = NULL, grid_points = 100L, expand_range = 0, theme = theme_minimal(), ...) { # nolint
assert_choice(type, choices = c("prediction", "ggparty"), null.ok = FALSE)
assert_has_model(object)

switch(type,
Expand Down
2 changes: 1 addition & 1 deletion R/LearnerClustHierarchical.R
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@
#' autoplot(learner, type = "scree")
#' }
autoplot.LearnerClustHierarchical = function(object, type = "dend", task = NULL, theme = theme_minimal(), theme_dendro = TRUE, ...) { # nolint
assert_string(type)
assert_choice(type, choices = c("dend", "scree"), null.ok = FALSE)

if (is.null(object$model)) {
stopf("Learner '%s' must be trained first", object$id)
Expand Down
2 changes: 1 addition & 1 deletion R/LearnerRegr.R
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@
#' }
#' }
autoplot.LearnerRegr = function(object, type = "prediction", task, grid_points = 100L, expand_range = 0, theme = theme_minimal(), ...) { # nolint
assert_string(type)
assert_choice(type, choices = c("prediction"), null.ok = FALSE)

switch(type,
"prediction" = {
Expand Down
1 change: 1 addition & 0 deletions R/LearnerRegrCVGlmnet.R
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
#' @rdname autoplot.LearnerClassifGlmnet
#' @export
autoplot.LearnerRegrCVGlmnet = function(object, type = "prediction", task = NULL, grid_points = 100L, expand_range = 0, theme = theme_minimal(), ...) { # nolint
assert_choice(type, choices = c("prediction", "ggfortify"), null.ok = FALSE)
switch(type,
"prediction" = {
NextMethod()
Expand Down
1 change: 1 addition & 0 deletions R/LearnerRegrRpart.R
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
#' @export
#' @rdname autoplot.LearnerClassifRpart
autoplot.LearnerRegrRpart = function(object, type = "prediction", task = NULL, grid_points = 100L, expand_range = 0, theme = theme_minimal(), ...) { # nolint
assert_choice(type, choices = c("prediction", "ggparty"), null.ok = FALSE)
assert_has_model(object)

switch(type,
Expand Down
1 change: 1 addition & 0 deletions R/LearnerSurvCoxPH.R
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
#' }
#' }
autoplot.LearnerSurvCoxPH = function(object, type = "ggforest", ...) {
assert_choice(type, choices = c("ggforest"), null.ok = FALSE)
assert_class(object, classes = "LearnerSurvCoxPH", null.ok = FALSE)
assert_has_model(object)

Expand Down
2 changes: 2 additions & 0 deletions R/OptimInstanceBatchSingleCrit.R
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,8 @@
#' print(autoplot(instance, type = "incumbent"))
#' }
autoplot.OptimInstanceBatchSingleCrit = function(object, type = "marginal", cols_x = NULL, trafo = FALSE, learner = mlr3::lrn("regr.ranger"), grid_resolution = 100, batch = NULL, theme = theme_minimal(), ...) { # nolint
assert_choice(type, choices = c("marginal", "performance", "parameter", "parallel",
"points", "surface", "pairs", "incumbent"), null.ok = FALSE)
assert_subset(cols_x, c(object$archive$cols_x, paste0("x_domain_", object$archive$cols_x)))
assert_flag(trafo)

Expand Down
2 changes: 1 addition & 1 deletion R/PredictionClassif.R
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@
#' }
#' }
autoplot.PredictionClassif = function(object, type = "stacked", measure = NULL, theme = theme_minimal(), ...) { # nolint
assert_string(type)
assert_choice(type, choices = c("stacked", "roc", "prc", "threshold"), null.ok = FALSE)

switch(type,
"stacked" = {
Expand Down
2 changes: 1 addition & 1 deletion R/PredictionClust.R
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@
#' autoplot(object, task)
#' }
autoplot.PredictionClust = function(object, task, row_ids = NULL, type = "scatter", theme = theme_minimal(), ...) { # nolint
assert_string(type)
assert_choice(type, choices = c("scatter", "sil", "pca"), null.ok = FALSE)

switch(type,
"scatter" = {
Expand Down
2 changes: 1 addition & 1 deletion R/PredictionRegr.R
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@
#' }
#' }
autoplot.PredictionRegr = function(object, type = "xy", binwidth = NULL, theme = theme_minimal(), quantile = 1.96, ...) {
checkmate::assert_string(type)
assert_choice(type, choices = c("xy", "histogram", "residual", "confidence"), null.ok = FALSE)

switch(type,
"xy" = {
Expand Down
9 changes: 5 additions & 4 deletions R/ResampleResult.R
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
#' * `"prc"`: Precision recall curve.
#' See `"roc"`.
#' * `"prediction"`: Plots the learner prediction for a grid of points.
#' Needs models to be stored. Set `store_models = TRUE` for `[mlr3::resample]`.
#' Needs models to be stored. Set `store_models = TRUE` for [mlr3::resample()].
#' For classification, we support tasks with exactly two features and learners with `predict_type=` set to `"response"` or `"prob"`.
#' For regression, we support tasks with one or two features.
#' For tasks with one feature we can print confidence bounds if the predict type of the learner was set to `"se"`.
Expand All @@ -29,7 +29,8 @@
#' @param binwidth (`integer(1)`)\cr
#' Width of the bins for the histogram.
#' @template param_theme
#' @param ... (ignored).
#' @param ... arguments passed on to [precrec::autoplot()] for `type = "roc"` or `"prc"`.
#' Useful to e.g. remove confidence bands with `show_cb = FALSE`.
#'
#' @return [ggplot2::ggplot()].
#'
Expand Down Expand Up @@ -73,7 +74,7 @@
#' }
#' }
autoplot.ResampleResult = function(object, type = "boxplot", measure = NULL, predict_sets = "test", binwidth = NULL, theme = theme_minimal(), ...) {
assert_string(type)
assert_choice(type, choices = c("boxplot", "histogram", "prediction", "roc", "prc"), null.ok = FALSE)

task = object$task
measure = mlr3::assert_measure(mlr3::as_measure(measure, task_type = task$task_type), task = task)
Expand Down Expand Up @@ -122,7 +123,7 @@ autoplot.ResampleResult = function(object, type = "boxplot", measure = NULL, pre
},

"prc" = {
p = plot_precrec(object, curvetype = "PRC")
p = plot_precrec(object, curvetype = "PRC", ...)
# fill confidence bounds
p$layers[[1]]$mapping = aes(color = modname, fill = modname)
p +
Expand Down
2 changes: 1 addition & 1 deletion R/TaskClassif.R
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@
#' autoplot(task, type = "duo")
#' }
autoplot.TaskClassif = function(object, type = "target", theme = theme_minimal(), ...) { # nolint
assert_string(type)
assert_choice(type, choices = c("target", "duo", "pairs"), null.ok = FALSE)

target = object$target_names

Expand Down
2 changes: 1 addition & 1 deletion R/TaskClust.R
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
#' autoplot(task)
#' }
autoplot.TaskClust = function(object, type = "pairs", theme = theme_minimal(), ...) { # nolint
assert_string(type)
assert_choice(type, choices = c("pairs"), null.ok = FALSE)

switch(type,
"pairs" = {
Expand Down
2 changes: 1 addition & 1 deletion R/TaskRegr.R
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@
#' autoplot(task, type = "pairs")
#' }
autoplot.TaskRegr = function(object, type = "target", theme = theme_minimal(), ...) { # nolint
assert_string(type)
assert_choice(type, choices = c("target", "pairs"), null.ok = FALSE)

switch(type,
"target" = {
Expand Down
3 changes: 2 additions & 1 deletion man/autoplot.BenchmarkResult.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

5 changes: 3 additions & 2 deletions man/autoplot.ResampleResult.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

4 changes: 4 additions & 0 deletions tests/testthat/test_EnsembleFSResult.R
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,10 @@ test_that("autoplot ResampleResult", {
)

efsr = mlr3fselect::EnsembleFSResult$new(result = result, features = paste0("V", 1:20), measure_id = "classif.ce")

# wrong type gives hint of types a user can input
expect_error(autoplot(efsr, type = "XYZ"), regexp = "Must be element of set")

# pareto (stepwise)
p = autoplot(efsr)
expect_true(is.ggplot(p))
Expand Down

0 comments on commit d876a9f

Please sign in to comment.