Skip to content

Error when stacking classification models with `pr_auc` #225

@cgoo4

Description

@cgoo4

Hi — stacks is proving very effective at improving model performance for me (using the output from `finetune::tune_sim_anneal()`). Thank you!

I am occasionally running into a problem, though, when stacking classification models with the `pr_auc` metric.

Here is a somewhat contrived example that attempts to recreate the error:

(I also seem to get the error when `num_members` > 0.)

# Reprex setup: packages, data, train/test split and resampling folds.
library(tidymodels)
library(stacks)

data("tree_frogs")

# Drop `clutch` and `latency`, and overwrite `reflex` with a contrived
# two-class outcome: only the first 19 rows (row_number() < 20) are "low",
# every other row is "other".
# NOTE(review): this deliberately makes "low" a small class; the warning
# printed by blend_predictions() below about a class with fewer than 8
# observations appears to stem from it — confirm.
tree_frogs <- tree_frogs |> 
  select(-c(clutch, latency)) |> 
  mutate(reflex = if_else(row_number() < 20, "low", "other"))

# Fix the RNG state so the split and the folds below are reproducible.
# Do not insert or reorder random calls between here and vfold_cv(), or a
# different split (and possibly no reproduction) results.
set.seed(1)

tree_frogs_split <- initial_split(tree_frogs)
tree_frogs_train <- training(tree_frogs_split)
tree_frogs_test  <- testing(tree_frogs_split)

# 5-fold cross-validation on the training set; both tune_grid() runs use it.
folds <- vfold_cv(tree_frogs_train, v = 5)

# Shared preprocessing: dummy-encode the nominal predictors, then drop any
# zero-variance predictors. `all_nominal_predictors()` never selects the
# outcome, so an explicit `-reflex` exclusion is unnecessary.
tree_frogs_rec <- 
  recipe(reflex ~ ., data = tree_frogs_train) |> 
  step_dummy(all_nominal_predictors()) |> 
  step_zv(all_predictors())

# Base workflow holding only the recipe; each candidate model adds its own
# model spec (and may swap the recipe).
tree_frogs_wflow <- 
  workflow() |> 
  add_recipe(tree_frogs_rec)

# Grid control object from stacks, passed to both tune_grid() calls so their
# results can be fed to add_candidates().
ctrl_grid <- control_stack_grid()

# Random forest candidate: tune `mtry` and `min_n`, with the number of trees
# fixed at 500. Uses the native pipe consistently with the rest of the script.
rand_forest_spec <- 
  rand_forest(
    mtry = tune(),
    min_n = tune(),
    trees = 500
  ) |>
  set_mode("classification") |> 
  set_engine("ranger")

rand_forest_wflow <-
  tree_frogs_wflow |> 
  add_model(rand_forest_spec)

# Let tune_grid() generate the candidates (grid = 10) and evaluate them on
# the shared folds. `mtry` has a data-dependent range, hence the message below.
rand_forest_res <- 
  tune_grid(
    object = rand_forest_wflow, 
    resamples = folds, 
    grid = 10,
    control = ctrl_grid
  )
#> i Creating pre-processing data to finalize unknown parameter: mtry

# Neural-network candidate ---------------------------------------------------
# Start from the shared recipe and additionally normalize every predictor.
nnet_rec <- step_normalize(tree_frogs_rec, all_predictors())

# Tune the number of hidden units, the penalty and the number of epochs.
nnet_spec <- mlp(
  hidden_units = tune(),
  penalty = tune(),
  epochs = tune()
) |>
  set_mode("classification") |>
  set_engine("nnet")

# Replace the base recipe with the normalizing one, then attach the model.
nnet_wflow <- tree_frogs_wflow |>
  update_recipe(nnet_rec) |>
  add_model(nnet_spec)

# Same resamples, grid size and stacking-aware control as the random forest.
nnet_res <- tune_grid(
  nnet_wflow,
  resamples = folds,
  grid = 10,
  control = ctrl_grid
)

# Build the data stack from both tuning results, blend over the penalties
# 0, 0.1, ..., 1 with mixture = 1 (choosing by PR AUC), then fit the members
# the blend retains. The stray trailing comma that used to follow `penalty`
# passed an empty argument to blend_predictions() and has been removed.
tree_frogs_model_st <- 
  stacks() |> 
  add_candidates(rand_forest_res) |> 
  add_candidates(nnet_res) |> 
  blend_predictions(
    metric = metric_set(pr_auc),
    mixture = 1,
    penalty = seq(0, 1, 0.1)
  ) |> 
  fit_members()
#> → A | warning: one multinomial or binomial class has fewer than 8  observations; dangerous ground
#> There were issues with some computations   A: x1
#> There were issues with some computations   A: x3
#> 

# Inspect the blending results. The log-10 warning is likely due to the
# penalty grid including 0, which has no finite log.
autoplot(tree_frogs_model_st)
#> Warning in ggplot2::scale_x_log10(): log-10 transformation introduced infinite values.
#> log-10 transformation introduced infinite values.

# The blend kept no member models.
length(tree_frogs_model_st$member_fits)
#> [1] 0

# Predicting class probabilities on the test set then fails inside
# tidyr::pivot_wider() — this is the reported bug.
tree_frogs_model_st |> 
  augment(tree_frogs_test, type = "prob")
#> Error in `tidyr::pivot_wider()`:
#> ! Can't select columns past the end.
#> ℹ Location 3 doesn't exist.
#> ℹ There are only 2 columns.

Created on 2024-08-17 with reprex v2.1.1

Metadata

Metadata

Assignees

No one assigned

    Labels

    No labels
    No labels

    Type

    No type

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions