diff --git a/DESCRIPTION b/DESCRIPTION
index 8bb5eaa2..2c75ab69 100644
--- a/DESCRIPTION
+++ b/DESCRIPTION
@@ -65,7 +65,7 @@ Remotes:
mlr-org/mlr3,
mlr-org/bbotk,
mlr-org/mlr3tuning,
- mlr-org/mlr3fselect
+ mlr-org/mlr3fselect@efs_updates
Config/testthat/edition: 3
Config/testthat/parallel: true
Encoding: UTF-8
diff --git a/R/EnsembleFSResult.R b/R/EnsembleFSResult.R
index 44f62d08..9d4daf9c 100644
--- a/R/EnsembleFSResult.R
+++ b/R/EnsembleFSResult.R
@@ -31,6 +31,7 @@
#' @template param_theme
#' @param stability_args (`list`)\cr
#' Additional arguments passed to the stability measure function.
+#' @param ... (ignored).
#'
#' @return [ggplot2::ggplot()].
#' @examples
@@ -67,93 +68,97 @@
#' }
#' }
#' @export
-autoplot.EnsembleFSResult = function(object, type = "pareto",
- pareto_front = "stepwise",
- stability_measure = "jaccard",
- stability_args = NULL,
- theme = theme_minimal()) {
+autoplot.EnsembleFSResult = function(
+ object,
+ type = "pareto",
+ pareto_front = "stepwise",
+ stability_measure = "jaccard",
+ stability_args = NULL,
+ theme = theme_minimal(),
+ ...
+ ) {
assert_string(type)
assert_choice(pareto_front, choices = c("stepwise", "estimated", "none"))
result = object$result
measure_id = object$measure
switch(type,
- "pareto" = {
- p = ggplot(result, mapping = aes(
- x = .data[["n_features"]],
- y = .data[[measure_id]],
- color = learner_id)) +
- geom_point() +
- scale_color_viridis_d("Learner ID", end = 0.8, alpha = 0.8) +
- xlab("Number of Features") +
- ylab(measure_id) +
- theme
+ "pareto" = {
+ p = ggplot(result, mapping = aes(
+ x = .data[["n_features"]],
+ y = .data[[measure_id]],
+ color = .data[["learner_id"]])) +
+ geom_point() +
+ scale_color_viridis_d("Learner ID", end = 0.8, alpha = 0.8) +
+ xlab("Number of Features") +
+ ylab(measure_id) +
+ theme
- if (pareto_front == "stepwise") {
- pf = object$pareto_front(type = "empirical")
- pf_step = stepwise_pf(pf)
- p = p +
- geom_line(data = pf_step, mapping = aes(
- x = .data[["n_features"]],
- y = .data[[measure_id]]),
- color = "black", linewidth = 0.7)
- } else if (pareto_front == "estimated") {
- pfe = object$pareto_front(type = "estimated")
- p = p +
- geom_line(data = pfe, mapping = aes(
- x = .data[["n_features"]],
- y = .data[[measure_id]]),
- color = "black", linetype = "dashed", linewidth = 0.7)
- }
+ if (pareto_front == "stepwise") {
+ pf = object$pareto_front(type = "empirical")
+ pf_step = stepwise_pf(pf)
+ p = p +
+ geom_line(data = pf_step, mapping = aes(
+ x = .data[["n_features"]],
+ y = .data[[measure_id]]),
+ color = "black", linewidth = 0.7)
+ } else if (pareto_front == "estimated") {
+ pfe = object$pareto_front(type = "estimated")
+ p = p +
+ geom_line(data = pfe, mapping = aes(
+ x = .data[["n_features"]],
+ y = .data[[measure_id]]),
+ color = "black", linetype = "dashed", linewidth = 0.7)
+ }
- p
- },
+ p
+ },
- "performance" = {
- ggplot(result, aes(
- x = .data[["learner_id"]],
- y = .data[[measure_id]],
- fill = learner_id)) +
- geom_boxplot(show.legend = FALSE) +
- scale_fill_viridis_d(end = 0.8, alpha = 0.8) +
- ylab(measure_id) +
- theme +
- theme(axis.title.x = element_blank())
- },
+ "performance" = {
+ ggplot(result, aes(
+ x = .data[["learner_id"]],
+ y = .data[[measure_id]],
+ fill = .data[["learner_id"]])) +
+ geom_boxplot(show.legend = FALSE) +
+ scale_fill_viridis_d(end = 0.8, alpha = 0.8) +
+ ylab(measure_id) +
+ theme +
+ theme(axis.title.x = element_blank())
+ },
- "n_features" = {
- ggplot(result, aes(
- x = .data[["learner_id"]],
- y = .data[["n_features"]],
- fill = learner_id)) +
- geom_boxplot(show.legend = FALSE) +
- scale_fill_viridis_d(end = 0.8, alpha = 0.8) +
- ylab("Number of Features") +
- theme +
- theme(axis.title.x = element_blank())
- },
+ "n_features" = {
+ ggplot(result, aes(
+ x = .data[["learner_id"]],
+ y = .data[["n_features"]],
+ fill = .data[["learner_id"]]))+
+ geom_boxplot(show.legend = FALSE) +
+ scale_fill_viridis_d(end = 0.8, alpha = 0.8) +
+ ylab("Number of Features") +
+ theme +
+ theme(axis.title.x = element_blank())
+ },
- "stability" = {
- # get stability per learner
- stab_res = object$stability(
- stability_measure = stability_measure,
- stability_args = stability_args,
- global = FALSE,
- reset_cache = FALSE)
- data = data.table(learner_id = names(stab_res), value = stab_res)
+ "stability" = {
+ # get stability per learner
+ stab_res = object$stability(
+ stability_measure = stability_measure,
+ stability_args = stability_args,
+ global = FALSE,
+ reset_cache = FALSE)
+ data = data.table(learner_id = names(stab_res), value = stab_res)
- ggplot(data, mapping = aes(
- x = .data[["learner_id"]],
- y = .data[["value"]],
- fill = .data[["learner_id"]])) +
- geom_bar(stat = "identity", alpha = 0.8, show.legend = FALSE) +
- scale_fill_viridis_d(end = 0.8, alpha = 0.8) +
- ylab(stability_measure) +
- theme +
- theme(axis.title.x = element_blank())
- },
+ ggplot(data, mapping = aes(
+ x = .data[["learner_id"]],
+ y = .data[["value"]],
+ fill = .data[["learner_id"]])) +
+ geom_bar(stat = "identity", alpha = 0.8, show.legend = FALSE) +
+ scale_fill_viridis_d(end = 0.8, alpha = 0.8) +
+ ylab(stability_measure) +
+ theme +
+ theme(axis.title.x = element_blank())
+ },
- stopf("Unknown plot type '%s'", type)
+ stopf("Unknown plot type '%s'", type)
)
}
diff --git a/man/autoplot.EnsembleFSResult.Rd b/man/autoplot.EnsembleFSResult.Rd
new file mode 100644
index 00000000..c42360bf
--- /dev/null
+++ b/man/autoplot.EnsembleFSResult.Rd
@@ -0,0 +1,96 @@
+% Generated by roxygen2: do not edit by hand
+% Please edit documentation in R/EnsembleFSResult.R
+\name{autoplot.EnsembleFSResult}
+\alias{autoplot.EnsembleFSResult}
+\title{Plots for Ensemble Feature Selection Results}
+\usage{
+\method{autoplot}{EnsembleFSResult}(
+ object,
+ type = "pareto",
+ pareto_front = "stepwise",
+ stability_measure = "jaccard",
+ stability_args = NULL,
+ theme = theme_minimal(),
+ ...
+)
+}
+\arguments{
+\item{object}{(\link[mlr3fselect:ensemble_fs_result]{mlr3fselect::EnsembleFSResult}).}
+
+\item{type}{(character(1)):\cr
+Type of the plot. See description.}
+
+\item{pareto_front}{(\code{character(1)})\cr
+Type of pareto front to plot. Can be \code{"stepwise"} (default), \code{"estimated"}
+or \code{"none"}.}
+
+\item{stability_measure}{(\code{character(1)})\cr
+The stability measure to be used in case \code{type = "stability"}.
+One of the measures returned by \code{\link[stabm:listStabilityMeasures]{stabm::listStabilityMeasures()}} in lower case.
+Default is \code{"jaccard"}.}
+
+\item{stability_args}{(\code{list})\cr
+Additional arguments passed to the stability measure function.}
+
+\item{theme}{(\code{\link[ggplot2:theme]{ggplot2::theme()}})\cr
+The \code{\link[ggplot2:ggtheme]{ggplot2::theme_minimal()}} is applied by default to all plots.}
+
+\item{...}{(ignored).}
+}
+\value{
+\code{\link[ggplot2:ggplot]{ggplot2::ggplot()}}.
+}
+\description{
+Visualizations for \link[mlr3fselect:ensemble_fs_result]{EnsembleFSResult}.
+The argument \code{type} determines the type of plot generated.
+The available options are:
+\itemize{
+\item \code{"pareto"} (default): Scatterplot of performance versus the number of
+features, possibly including the \strong{Pareto front}, which allows users to
+decide how much performance they are willing to trade off for a more sparse
+model.
+\item \code{"performance"}: Boxplot of performance across the different learners
+used in the ensemble feature selection process.
+Each box represents the distribution of scores across different resampling
+iterations for a particular learner.
+\item \verb{"n_features}: Boxplot of the number of features selected by each learner
+in the different resampling iterations.
+\item \code{"stability"}: Barplot of stability score for each learner used in the
+ensemble feature selection. This plot shows how similar are the output feature
+sets from each learner across the different resamplings.
+}
+}
+\examples{
+\donttest{
+if (requireNamespace("mlr3")) {
+ library(mlr3)
+ library(mlr3fselect)
+
+ set.seed (42)
+ efsr = ensemble_fselect(
+ fselector = fs("random_search"),
+ task = tsk("sonar"),
+ learners = lrns(c("classif.rpart", "classif.featureless")),
+ init_resampling = rsmp("subsampling", repeats = 5),
+ inner_resampling = rsmp("cv", folds = 3),
+ measure = msr("classif.ce"),
+ terminator = trm("evals", n_evals = 5)
+ )
+
+ # Pareto front (default, stepwise)
+ autoplot(efsr)
+
+ # Pareto front (estimated)
+ autoplot(efsr, pareto_front = "estimated")
+
+ # Performance
+ autoplot(efsr, type = "performance")
+
+ # Number of features
+ autoplot(efsr, type = "n_features")
+
+ # stability
+ autoplot(efsr, type = "stability")
+}
+}
+}
diff --git a/tests/testthat/_snaps/EnsembleFSResult/pareto-estimated.svg b/tests/testthat/_snaps/EnsembleFSResult/pareto-estimated.svg
new file mode 100644
index 00000000..2c84aef6
--- /dev/null
+++ b/tests/testthat/_snaps/EnsembleFSResult/pareto-estimated.svg
@@ -0,0 +1,77 @@
+
+
diff --git a/tests/testthat/_snaps/EnsembleFSResult/pareto-n-features.svg b/tests/testthat/_snaps/EnsembleFSResult/pareto-n-features.svg
new file mode 100644
index 00000000..2792961c
--- /dev/null
+++ b/tests/testthat/_snaps/EnsembleFSResult/pareto-n-features.svg
@@ -0,0 +1,65 @@
+
+
diff --git a/tests/testthat/_snaps/EnsembleFSResult/pareto-performance.svg b/tests/testthat/_snaps/EnsembleFSResult/pareto-performance.svg
new file mode 100644
index 00000000..de80e5cd
--- /dev/null
+++ b/tests/testthat/_snaps/EnsembleFSResult/pareto-performance.svg
@@ -0,0 +1,62 @@
+
+
diff --git a/tests/testthat/_snaps/EnsembleFSResult/pareto-stability.svg b/tests/testthat/_snaps/EnsembleFSResult/pareto-stability.svg
new file mode 100644
index 00000000..e34ae39d
--- /dev/null
+++ b/tests/testthat/_snaps/EnsembleFSResult/pareto-stability.svg
@@ -0,0 +1,53 @@
+
+
diff --git a/tests/testthat/_snaps/EnsembleFSResult/pareto-stepwise.svg b/tests/testthat/_snaps/EnsembleFSResult/pareto-stepwise.svg
new file mode 100644
index 00000000..bcbf74b8
--- /dev/null
+++ b/tests/testthat/_snaps/EnsembleFSResult/pareto-stepwise.svg
@@ -0,0 +1,76 @@
+
+