Skip to content

Second event level in weights autoplot #228

@cgoo4

Description

@cgoo4

Simon hi - I wanted to label autoplot("weights") with the names of the engines. Use of stacks:::top_coefs() was very handy for getting the names!

I was curious, though: why does the weights plot show the coefficients for the second factor level rather than the first?

(Reprex based on the vignette example.)

library(tidymodels)
library(stacks)
library(bonsai)

# Load the `tree_frogs` example data set into the session.
data("tree_frogs")

# Drop `clutch` and `latency`, then reduce `reflex` to a two-level factor:
# "low" is kept and every other level is lumped into "Other".
tree_frogs <- tree_frogs |> 
  select(-c(clutch, latency)) |> 
  mutate(reflex = forcats::fct_other(reflex, keep = "low"))

# First event level = "low"
levels(tree_frogs$reflex)
#> [1] "low"   "Other"

# Seed the RNG so the random split and fold assignment below are repeatable.
set.seed(1)

# Split into training and test sets using `initial_split()`'s defaults.
tree_frogs_split <- initial_split(tree_frogs)
tree_frogs_train <- training(tree_frogs_split)
tree_frogs_test  <- testing(tree_frogs_split)

# 5-fold cross-validation resamples of the training set; both tuning runs
# below resample over these same folds.
folds <- vfold_cv(tree_frogs_train, v = 5)

# Preprocessing recipe: `reflex` is the outcome and every other column is a
# predictor. Nominal predictors are dummy-encoded (with `reflex` explicitly
# excluded from that step), then zero-variance predictors are removed.
tree_frogs_rec <- 
  recipe(reflex ~ ., data = tree_frogs_train) |> 
  step_dummy(all_nominal_predictors(), -reflex) |> 
  step_zv(all_predictors())

# Base workflow carrying only the recipe; each model spec is added to a copy
# of it further down.
tree_frogs_wflow <- 
  workflow() |> 
  add_recipe(tree_frogs_rec)

# Control object passed to both `tune_grid()` calls.
# NOTE(review): presumably this saves the predictions/workflow that
# `add_candidates()` needs later -- confirm against the stacks documentation.
ctrl_grid <- control_stack_grid()

# Random forest classifier on the "ranger" engine. `mtry` and `min_n` are
# marked for tuning; the number of trees is fixed at 5.
rand_forest_spec <- 
  rand_forest(
    mtry = tune(),
    min_n = tune(),
    trees = 5
  ) |> 
  set_mode("classification") |> 
  set_engine("ranger")

# Attach the model spec to the shared recipe workflow.
rand_forest_wflow <-
  tree_frogs_wflow |> 
  add_model(rand_forest_spec)

# Tune over the shared folds with `grid = 10`, using the stacks control
# object so the results can be offered as stack candidates.
rand_forest_res <- 
  tune_grid(
    object = rand_forest_wflow, 
    resamples = folds, 
    grid = 10,
    control = ctrl_grid
  )
#> i Creating pre-processing data to finalize unknown parameter: mtry

# Second random forest classifier, this time on the "aorsf" engine with 500
# trees; `mtry` and `min_n` are again marked for tuning.
# NOTE(review): the "aorsf" engine is presumably registered by the bonsai
# package loaded at the top of the script -- confirm.
aorsf_spec <- 
  rand_forest(
    mtry = tune(),
    min_n = tune(),
    trees = 500
  ) |> 
  set_mode("classification") |> 
  set_engine("aorsf")

# Attach the model spec to the shared recipe workflow.
aorsf_wflow <- 
  tree_frogs_wflow |> 
  add_model(aorsf_spec)

# Tune over the same folds and with the same control object as the ranger
# model, so both result sets can be stacked together.
aorsf_res <-
  tune_grid(
    object = aorsf_wflow, 
    resamples = folds, 
    grid = 10,
    control = ctrl_grid
  )
#> i Creating pre-processing data to finalize unknown parameter: mtry

# Build the ensemble: start an empty stack, add both tuning results as
# candidates (named "ranger" and "aorsf"), blend their predictions to obtain
# the stacking weights, then fit the member models.
tree_frogs_model_st <- 
  stacks() |> 
  add_candidates(rand_forest_res, "ranger") |> 
  add_candidates(aorsf_res, "aorsf") |> 
  blend_predictions() |> 
  fit_members()

# To label the engines
# `top_coefs()` is an unexported stacks helper (hence `:::`), so it may change
# between releases. It is called with the penalty stored on the fitted stack
# and `n = Inf` (presumably: no cap on the number of members returned).
# The member names are then trimmed to the text that follows an underscore and
# precedes "_1".
# NOTE(review): this assumes member names look like "<prefix>_<name>_1..." --
# names that do not contain "_1" come back as NA; confirm against the output.
members <- 
  stacks:::top_coefs(
    tree_frogs_model_st, 
    penalty = tree_frogs_model_st$penalty$penalty, 
    n = Inf
    ) |> 
  mutate(member = stringr::str_extract(member, "(?<=_).*(?=_1)"))

# Doesn't show the first event level "low"?
# One label per row of `members`, placed at y = n, n - 1, ..., 1.
# NOTE(review): this relies on the rows of `members` being in the same
# top-to-bottom order as the bars drawn by autoplot() -- confirm.
# `rev(seq_len(n))` gives the same positions as `n:1` for n >= 1, but is empty
# (rather than c(0, 1)) when `members` has no rows.
autoplot(tree_frogs_model_st, "weights") + 
  geom_label(
    aes(weight, rev(seq_len(nrow(members))), label = member, fill = NULL),
    show.legend = FALSE, hjust = "inward",
    data = members
  )

Created on 2024-09-18 with reprex v2.1.1

Metadata

Metadata

Assignees

No one assigned

    Labels

    No labels
    No labels

    Type

    No type

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions