Skip to content

Commit

Permalink
fix: parameters
Browse files Browse the repository at this point in the history
  • Loading branch information
be-marc committed Jun 21, 2024
1 parent dabd06d commit 15f5433
Show file tree
Hide file tree
Showing 8 changed files with 509 additions and 75 deletions.
2 changes: 1 addition & 1 deletion DESCRIPTION
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ Remotes:
mlr-org/mlr3,
mlr-org/bbotk,
mlr-org/mlr3tuning,
mlr-org/mlr3fselect
mlr-org/mlr3fselect@efs_updates
Config/testthat/edition: 3
Config/testthat/parallel: true
Encoding: UTF-8
Expand Down
153 changes: 79 additions & 74 deletions R/EnsembleFSResult.R
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
#' @template param_theme
#' @param stability_args (`list`)\cr
#' Additional arguments passed to the stability measure function.
#' @param ... (ignored).
#'
#' @return [ggplot2::ggplot()].
#' @examples
Expand Down Expand Up @@ -67,93 +68,97 @@
#' }
#' }
#' @export
autoplot.EnsembleFSResult = function(object, type = "pareto",
pareto_front = "stepwise",
stability_measure = "jaccard",
stability_args = NULL,
theme = theme_minimal()) {
autoplot.EnsembleFSResult = function(
object,
type = "pareto",
pareto_front = "stepwise",
stability_measure = "jaccard",
stability_args = NULL,
theme = theme_minimal(),
...
) {
assert_string(type)
assert_choice(pareto_front, choices = c("stepwise", "estimated", "none"))
result = object$result
measure_id = object$measure

switch(type,
"pareto" = {
p = ggplot(result, mapping = aes(
x = .data[["n_features"]],
y = .data[[measure_id]],
color = learner_id)) +
geom_point() +
scale_color_viridis_d("Learner ID", end = 0.8, alpha = 0.8) +
xlab("Number of Features") +
ylab(measure_id) +
theme
"pareto" = {
p = ggplot(result, mapping = aes(
x = .data[["n_features"]],
y = .data[[measure_id]],
color = .data[["learner_id"]])) +
geom_point() +
scale_color_viridis_d("Learner ID", end = 0.8, alpha = 0.8) +
xlab("Number of Features") +
ylab(measure_id) +
theme

if (pareto_front == "stepwise") {
pf = object$pareto_front(type = "empirical")
pf_step = stepwise_pf(pf)
p = p +
geom_line(data = pf_step, mapping = aes(
x = .data[["n_features"]],
y = .data[[measure_id]]),
color = "black", linewidth = 0.7)
} else if (pareto_front == "estimated") {
pfe = object$pareto_front(type = "estimated")
p = p +
geom_line(data = pfe, mapping = aes(
x = .data[["n_features"]],
y = .data[[measure_id]]),
color = "black", linetype = "dashed", linewidth = 0.7)
}
if (pareto_front == "stepwise") {
pf = object$pareto_front(type = "empirical")
pf_step = stepwise_pf(pf)
p = p +
geom_line(data = pf_step, mapping = aes(
x = .data[["n_features"]],
y = .data[[measure_id]]),
color = "black", linewidth = 0.7)
} else if (pareto_front == "estimated") {
pfe = object$pareto_front(type = "estimated")
p = p +
geom_line(data = pfe, mapping = aes(
x = .data[["n_features"]],
y = .data[[measure_id]]),
color = "black", linetype = "dashed", linewidth = 0.7)
}

p
},
p
},

"performance" = {
ggplot(result, aes(
x = .data[["learner_id"]],
y = .data[[measure_id]],
fill = learner_id)) +
geom_boxplot(show.legend = FALSE) +
scale_fill_viridis_d(end = 0.8, alpha = 0.8) +
ylab(measure_id) +
theme +
theme(axis.title.x = element_blank())
},
"performance" = {
ggplot(result, aes(
x = .data[["learner_id"]],
y = .data[[measure_id]],
fill = .data[["learner_id"]])) +
geom_boxplot(show.legend = FALSE) +
scale_fill_viridis_d(end = 0.8, alpha = 0.8) +
ylab(measure_id) +
theme +
theme(axis.title.x = element_blank())
},

"n_features" = {
ggplot(result, aes(
x = .data[["learner_id"]],
y = .data[["n_features"]],
fill = learner_id)) +
geom_boxplot(show.legend = FALSE) +
scale_fill_viridis_d(end = 0.8, alpha = 0.8) +
ylab("Number of Features") +
theme +
theme(axis.title.x = element_blank())
},
"n_features" = {
ggplot(result, aes(
x = .data[["learner_id"]],
y = .data[["n_features"]],
fill = .data[["learner_id"]]))+
geom_boxplot(show.legend = FALSE) +
scale_fill_viridis_d(end = 0.8, alpha = 0.8) +
ylab("Number of Features") +
theme +
theme(axis.title.x = element_blank())
},

"stability" = {
# get stability per learner
stab_res = object$stability(
stability_measure = stability_measure,
stability_args = stability_args,
global = FALSE,
reset_cache = FALSE)
data = data.table(learner_id = names(stab_res), value = stab_res)
"stability" = {
# get stability per learner
stab_res = object$stability(
stability_measure = stability_measure,
stability_args = stability_args,
global = FALSE,
reset_cache = FALSE)
data = data.table(learner_id = names(stab_res), value = stab_res)

ggplot(data, mapping = aes(
x = .data[["learner_id"]],
y = .data[["value"]],
fill = .data[["learner_id"]])) +
geom_bar(stat = "identity", alpha = 0.8, show.legend = FALSE) +
scale_fill_viridis_d(end = 0.8, alpha = 0.8) +
ylab(stability_measure) +
theme +
theme(axis.title.x = element_blank())
},
ggplot(data, mapping = aes(
x = .data[["learner_id"]],
y = .data[["value"]],
fill = .data[["learner_id"]])) +
geom_bar(stat = "identity", alpha = 0.8, show.legend = FALSE) +
scale_fill_viridis_d(end = 0.8, alpha = 0.8) +
ylab(stability_measure) +
theme +
theme(axis.title.x = element_blank())
},

stopf("Unknown plot type '%s'", type)
stopf("Unknown plot type '%s'", type)
)
}

Expand Down
96 changes: 96 additions & 0 deletions man/autoplot.EnsembleFSResult.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

77 changes: 77 additions & 0 deletions tests/testthat/_snaps/EnsembleFSResult/pareto-estimated.svg
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading

0 comments on commit 15f5433

Please sign in to comment.