-
Notifications
You must be signed in to change notification settings - Fork 29
Open
Description
Hi Simon — I wanted to label the `autoplot(..., "weights")` plot with the names of the engines. `stacks:::top_coefs()` was very handy for getting the names!
I was curious, though: why does the weights plot show the coefficients for the second factor level rather than the first?
(Reprex based on the vignette example.)
library(tidymodels)
library(stacks)
library(bonsai)
# Load the example data, drop two columns, and collapse `reflex` into a
# two-level factor: "low" versus everything else ("Other").
data("tree_frogs")
tree_frogs <- tree_frogs |>
  select(-clutch, -latency) |>
  mutate(
    reflex = forcats::fct_other(reflex, keep = "low")
  )

# The first (event) level is "low".
levels(tree_frogs$reflex)
#> [1] "low" "Other"
# Seeded train/test split, then 5-fold cross-validation folds built from
# the training portion only.
set.seed(1)
tree_frogs_split <- tree_frogs |> initial_split()
tree_frogs_train <- tree_frogs_split |> training()
tree_frogs_test <- tree_frogs_split |> testing()
folds <- tree_frogs_train |> vfold_cv(v = 5)
# Preprocessing recipe: dummy-encode nominal predictors, then drop any
# zero-variance predictors. `reflex` is the outcome (left of `~`), so
# all_nominal_predictors() never selects it; an extra `-reflex` selector
# would be a no-op and is omitted.
tree_frogs_rec <-
  recipe(reflex ~ ., data = tree_frogs_train) |>
  step_dummy(all_nominal_predictors()) |>
  step_zv(all_predictors())

# Shared workflow; each candidate model spec is added to it below.
tree_frogs_wflow <-
  workflow() |>
  add_recipe(tree_frogs_rec)

# Grid-tuning control settings from stacks, passed to tune_grid() so the
# results can later be added to a data stack with add_candidates().
ctrl_grid <- control_stack_grid()
# Candidate set 1: a ranger random forest with 5 trees, tuning `mtry` and
# `min_n` over a 10-point grid on the CV folds.
rand_forest_spec <-
  rand_forest(mtry = tune(), min_n = tune(), trees = 5) |>
  set_mode("classification") |>
  set_engine("ranger")

rand_forest_wflow <- add_model(tree_frogs_wflow, rand_forest_spec)

rand_forest_res <- rand_forest_wflow |>
  tune_grid(
    resamples = folds,
    grid = 10,
    control = ctrl_grid
  )
#> i Creating pre-processing data to finalize unknown parameter: mtry
# Candidate set 2: an aorsf (oblique) random forest with 500 trees, tuning
# `mtry` and `min_n` over a 10-point grid on the same CV folds.
aorsf_spec <-
  rand_forest(mtry = tune(), min_n = tune(), trees = 500) |>
  set_mode("classification") |>
  set_engine("aorsf")

aorsf_wflow <- add_model(tree_frogs_wflow, aorsf_spec)

aorsf_res <- aorsf_wflow |>
  tune_grid(
    resamples = folds,
    grid = 10,
    control = ctrl_grid
  )
#> i Creating pre-processing data to finalize unknown parameter: mtry
# Assemble the model stack. The two tuning results are added under the
# names "ranger" and "aorsf". blend_predictions() then fits the blending
# model, and fit_members() fits the members that the blend retains.
tree_frogs_model_st <-
  stacks() |>
  add_candidates(rand_forest_res, name = "ranger") |>
  add_candidates(aorsf_res, name = "aorsf") |>
  blend_predictions() |>
  fit_members()
# Pull each member's blending weight so the weights plot can be labelled
# with the engine names.
# NOTE(review): top_coefs() is an unexported stacks internal (hence
# `:::`); its name and output columns may change between stacks versions.
members <-
  stacks:::top_coefs(
    tree_frogs_model_st,
    penalty = tree_frogs_model_st$penalty$penalty,
    n = Inf
  ) |>
  # Keep the text between the first "_" and the last "_1" of each member
  # name. A name with no "_1" after an underscore becomes NA.
  mutate(member = stringr::str_extract(member, "(?<=_).*(?=_1)"))

# Question: the weights plot does not show the first event level "low";
# the coefficients appear to be for the second level ("Other") instead.
autoplot(tree_frogs_model_st, "weights") +
  geom_label(
    # Row 1 of `members` gets the highest y position, so labels run
    # top-down. rev(seq_len(n)) is used instead of n:1 so that zero
    # members gives an empty vector rather than c(0, 1).
    aes(
      x = weight,
      y = rev(seq_len(nrow(members))),
      label = member,
      fill = NULL
    ),
    show.legend = FALSE,
    hjust = "inward",
    data = members
  )
# Created on 2024-09-18 with reprex v2.1.1
Metadata
Metadata
Assignees
Labels
No labels
