```{r explain-setup, include = FALSE}
knitr::opts_chunk$set(fig.path = "figures/")
library(tidymodels)
library(forcats)
tidymodels_prefer()
conflicted::conflict_prefer("vi", "vip")
conflicted::conflict_prefer("explain", "lime")
source("ames_snippets.R")
ames_train <- ames_train %>%
  mutate_if(is.integer, as.numeric)
rf_model <-
  rand_forest(trees = 1000) %>%
  set_engine("ranger") %>%
  set_mode("regression")
rf_wflow <-
  workflow() %>%
  add_formula(
    Sale_Price ~ Neighborhood + Gr_Liv_Area + Year_Built + Bldg_Type +
      Latitude + Longitude) %>%
  add_model(rf_model)
rf_fit <- rf_wflow %>% fit(data = ames_train)
lm_fit <- lm_wflow %>% fit(data = ames_train)
```
# Explaining Models and Predictions {#explain}
In Section \@ref(model-types), we outlined a taxonomy of models and suggested that models typically are built as one or more of descriptive, inferential, or predictive. We suggested that model performance, as measured by appropriate metrics (like RMSE for regression or area under the ROC curve for classification), can be important for all modeling applications. Similarly, model explanations, answering _why_ a model makes the predictions it does, can be important whether the purpose of your model is largely descriptive, to test a hypothesis, or to make a prediction. Answering the question "why?" allows modeling practitioners to understand which features were important in predictions and even how model predictions would change under different values for the features. This chapter covers how to ask a model why it makes the predictions it does.
For some models, like linear regression, it is usually clear how to explain why the model makes its predictions. The structure of a linear model contains coefficients for each predictor that are typically straightforward to interpret. For other models, like random forests that can capture nonlinear behavior by design, it is less transparent how to explain the model's predictions from only the structure of the model itself. Instead, we can apply model explainer algorithms to generate understanding of predictions.
:::rmdnote
There are two types of model explanations, *global* and *local*. Global model explanations provide an overall understanding aggregated over a whole set of observations; local model explanations provide information about a prediction for a single observation.
:::
## Software for Model Explanations
The tidymodels framework does not itself contain software for model explanations. Instead, models trained and evaluated with tidymodels can be explained with other, supplementary software in R packages such as [`r pkg(lime)`](https://lime.data-imaginist.com/), [`r pkg(vip)`](https://koalaverse.github.io/vip/), and [`r pkg(DALEX)`](https://dalex.drwhy.ai/). We often choose:
- `r pkg(vip)` functions when we want to use _model-based_ methods that take advantage of model structure (and are often faster)
- `r pkg(DALEX)` functions when we want to use _model-agnostic_ methods that can be applied to any model
In Chapters \@ref(resampling) and \@ref(compare), we trained and compared several models to predict the price of homes in Ames, IA, including a linear model with interactions and a random forest model, with results shown in Figure \@ref(fig:explain-obs-pred).
```{r explain-obs-pred}
#| echo = FALSE,
#| fig.height = 4,
#| fig.cap = "Comparing predicted prices for a linear model with interactions and a random forest model",
#| fig.alt = "Comparing predicted prices for a linear model with interactions and a random forest model. The random forest results in more accurate predictions."
bind_rows(
  augment(lm_fit, ames_train) %>% mutate(model = "lm + interactions"),
  augment(rf_fit, ames_train) %>% mutate(model = "random forest")
) %>%
  ggplot(aes(Sale_Price, .pred)) +
  geom_abline(color = "gray50", lty = 2) +
  geom_point(alpha = 0.2, show.legend = FALSE) +
  facet_wrap(vars(model)) +
  labs(x = "true price", y = "predicted price")
```
Let's build model-agnostic explainers for both of these models to find out why they make these predictions. We can use the `r pkg(DALEXtra)` add-on package for `r pkg(DALEX)`, which provides support for tidymodels. @Biecek2021 provide a thorough exploration of how to use `r pkg(DALEX)` for model explanations; this chapter only summarizes some important approaches, specific to tidymodels. To compute any kind of model explanation, global or local, using `r pkg(DALEX)`, we first prepare the appropriate data and then create an _explainer_ for each model:
```{r explain-explainers, message=FALSE}
library(DALEXtra)
vip_features <- c("Neighborhood", "Gr_Liv_Area", "Year_Built",
                  "Bldg_Type", "Latitude", "Longitude")
vip_train <-
  ames_train %>%
  select(all_of(vip_features))
explainer_lm <-
  explain_tidymodels(
    lm_fit,
    data = vip_train,
    y = ames_train$Sale_Price,
    label = "lm + interactions",
    verbose = FALSE
  )
explainer_rf <-
  explain_tidymodels(
    rf_fit,
    data = vip_train,
    y = ames_train$Sale_Price,
    label = "random forest",
    verbose = FALSE
  )
```
:::rmdwarning
A linear model is typically straightforward to interpret and explain; you may not often find yourself using separate model explanation algorithms for a linear model. However, it can sometimes be difficult to understand or explain the predictions of even a linear model once it has splines and interaction terms!
:::
Significant feature engineering transformations give us some options (and sometimes some ambiguity) when we explain a model. We can quantify global or local model explanations in terms of either:
- *original, basic predictors* as they existed without significant feature engineering transformations, or
- *derived features*, such as those created via dimensionality reduction (Chapter \@ref(dimensionality)) or interactions and spline terms, as in this example.
## Local Explanations
Local model explanations provide information about a prediction for a single observation. For example, let's consider an older duplex in the North Ames neighborhood (Section \@ref(exploring-features-of-homes-in-ames)):
```{r explain-duplex}
duplex <- vip_train[120,]
duplex
```
There are multiple possible approaches to understanding why a model predicts a given price for this duplex. One is a break-down explanation, implemented with the `r pkg(DALEX)` function `predict_parts()`; it computes how the contributions attributed to individual features move the prediction for a particular observation, like our duplex, away from the model's mean prediction. For the linear model, the duplex status (`Bldg_Type = 3`),[^dalexlevels] size, longitude, and age contribute the most to driving the price down from the intercept:
```{r explain-duplex-lm-breakdown}
lm_breakdown <- predict_parts(explainer = explainer_lm, new_observation = duplex)
lm_breakdown
```
[^dalexlevels]: Notice that this package for model explanations focuses on the *level* of categorical predictors in this type of output, like `Bldg_Type = 3` for duplex and `Neighborhood = 1` for North Ames.
Since this linear model was trained using spline terms for latitude and longitude, the contribution to price for `Longitude` shown here combines the effects of all of its individual spline terms. The contribution is in terms of the original `Longitude` feature, not the derived spline features.
The most important features are slightly different for the random forest model, with the size, age, and duplex status being most important:
```{r explain-duplex-rf-breakdown}
rf_breakdown <- predict_parts(explainer = explainer_rf, new_observation = duplex)
rf_breakdown
```
:::rmdwarning
Model break-down explanations like these depend on the _order_ of the features.
:::
If we choose the `order` for the random forest model explanation to be the same as the default for the linear model (chosen via a heuristic), we can change the relative importance of the features:
```{r explain-duplex-rf-breakdown-reorder}
predict_parts(
  explainer = explainer_rf,
  new_observation = duplex,
  order = lm_breakdown$variable_name
)
```
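To see why the ordering matters, here is a minimal base R sketch of the break-down idea on a toy model. It is not the `r pkg(DALEX)` implementation: the model `toy_predict()`, the data `toy_data`, and the helper `break_down()` are hypothetical names introduced only for illustration. The helper fixes one feature at a time to the observation's value and records how the mean prediction over the data changes.

```r
# Hypothetical toy model with an interaction between x1 and x2,
# standing in for a fitted model's prediction function.
toy_predict <- function(df) 10 + 2 * df$x1 + 3 * df$x1 * df$x2

# Small background data set used to average over features not yet fixed.
toy_data <- expand.grid(x1 = c(0, 1, 2), x2 = c(0, 1))

# Fix features at the observation's values, in the given `order`,
# and record the change in the mean prediction at each step.
break_down <- function(predict_fun, data, obs, order) {
  previous_mean <- mean(predict_fun(data))
  contributions <- setNames(numeric(length(order)), order)
  for (feature in order) {
    data[[feature]] <- obs[[feature]]
    new_mean <- mean(predict_fun(data))
    contributions[[feature]] <- new_mean - previous_mean
    previous_mean <- new_mean
  }
  contributions
}

obs <- data.frame(x1 = 2, x2 = 1)

bd_x1_first <- break_down(toy_predict, toy_data, obs, c("x1", "x2"))
bd_x2_first <- break_down(toy_predict, toy_data, obs, c("x2", "x1"))

bd_x1_first
bd_x2_first
```

Both orderings add up to the same total (the observation's prediction minus the mean prediction), but because of the interaction term, the share attributed to each feature differs between them.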
We can use the fact that these break-down explanations change based on order to compute the most important features over all (or many) possible orderings. This is the idea behind Shapley additive explanations (SHAP) [@Lundberg2017], where the average contributions of features are computed under different combinations or "coalitions" of feature orderings. Let's compute SHAP attributions for our duplex, using `B = 20` random orderings:
```{r explain-duplex-rf-shap-calc}
set.seed(1801)
shap_duplex <-
  predict_parts(
    explainer = explainer_rf,
    new_observation = duplex,
    type = "shap",
    B = 20
  )
```
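The averaging step can be sketched in base R on a toy model. This is an illustration under stated assumptions, not what `predict_parts()` does internally: `toy_predict()`, `toy_data`, `break_down()`, and `all_orderings()` are hypothetical names, and the sketch enumerates every ordering of three features instead of sampling `B` random orderings.

```r
# Hypothetical toy model and background data (not the Ames models).
toy_predict <- function(df) 10 + 2 * df$x1 + 3 * df$x1 * df$x2 + df$x3
toy_data <- expand.grid(x1 = c(0, 1, 2), x2 = c(0, 1), x3 = c(-1, 1))
obs <- data.frame(x1 = 2, x2 = 1, x3 = 1)

# Break-down contributions for a single ordering of the features.
break_down <- function(predict_fun, data, obs, order) {
  previous_mean <- mean(predict_fun(data))
  contributions <- setNames(numeric(length(order)), order)
  for (feature in order) {
    data[[feature]] <- obs[[feature]]
    new_mean <- mean(predict_fun(data))
    contributions[[feature]] <- new_mean - previous_mean
    previous_mean <- new_mean
  }
  contributions
}

# All orderings of a character vector (only practical for a few features).
all_orderings <- function(x) {
  if (length(x) <= 1) return(list(x))
  out <- list()
  for (i in seq_along(x)) {
    for (rest in all_orderings(x[-i])) {
      out[[length(out) + 1]] <- c(x[i], rest)
    }
  }
  out
}

features <- names(obs)

# One column per ordering, one row per feature.
per_ordering <- sapply(
  all_orderings(features),
  function(ord) break_down(toy_predict, toy_data, obs, ord)[features]
)

# Average over orderings: one Shapley-style attribution per feature.
shap_toy <- rowMeans(per_ordering)
shap_toy
```

The spread of each row of `per_ordering` is what the box plots in a SHAP figure summarize, and `shap_toy` corresponds to the bars. With many features, enumerating all orderings is infeasible, which is why a random sample of orderings is used in practice.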
We could use the default plot method from `r pkg(DALEX)` by calling `plot(shap_duplex)`, or we can access the underlying data and create a custom plot. The box plots in Figure \@ref(fig:duplex-rf-shap) display the distribution of contributions across all the orderings we tried, and the bars display the average attribution for each feature:
```{r explain-duplex-rf-shap, eval=FALSE}
library(forcats)
shap_duplex %>%
  group_by(variable) %>%
  mutate(mean_val = mean(contribution)) %>%
  ungroup() %>%
  mutate(variable = fct_reorder(variable, abs(mean_val))) %>%
  ggplot(aes(contribution, variable, fill = mean_val > 0)) +
  geom_col(data = ~distinct(., variable, mean_val),
           aes(mean_val, variable),
           alpha = 0.5) +
  geom_boxplot(width = 0.5) +
  theme(legend.position = "none") +
  scale_fill_viridis_d() +
  labs(y = NULL)
```
```{r duplex-rf-shap, ref.label = "explain-duplex-rf-shap"}
#| echo = FALSE,
#| fig.cap = "Shapley additive explanations from the random forest model for a duplex property",
#| fig.alt = "Shapley additive explanations from the random forest model for a duplex property. Year built and gross living area have the largest contributions."
```
What about a different observation in our data set? Let's look at a larger, newer one-family home in the Gilbert neighborhood:
```{r explain-gilbert}
big_house <- vip_train[1269,]
big_house
```
We can compute SHAP average attributions for this house in the same way:
```{r explain-gilbert-shap-calc}
set.seed(1802)
shap_house <-
  predict_parts(
    explainer = explainer_rf,
    new_observation = big_house,
    type = "shap",
    B = 20
  )
```
The results are shown in Figure \@ref(fig:gilbert-shap); unlike the duplex, the size and age of this house contribute to its price being higher.
```{r gilbert-shap}
#| echo = FALSE,
#| fig.cap = "Shapley additive explanations from the random forest model for a one-family home in Gilbert",
#| fig.alt = "Shapley additive explanations from the random forest model for a one-family home in Gilbert. Gross living area and year built have the largest contributions but in the opposite direction from the previous explanation."
shap_house %>%
  group_by(variable) %>%
  mutate(mean_val = mean(contribution)) %>%
  ungroup() %>%
  mutate(variable = fct_reorder(variable, abs(mean_val))) %>%
  ggplot(aes(contribution, variable, fill = mean_val > 0)) +
  geom_col(data = ~distinct(., variable, mean_val),
           aes(mean_val, variable),
           alpha = 0.5) +
  geom_boxplot(width = 0.5) +
  scale_fill_viridis_d() +
  theme(legend.position = "none") +
  labs(y = NULL)
```
## Global Explanations
Global model explanations, also called global feature importance or variable importance, help us understand which features are most important in driving the predictions of the linear and random forest models overall, aggregated over the whole training set. While the previous section addressed which variables or features are most important in predicting sale price for an individual home, global feature importance addresses the most important variables for a model in aggregate.
:::rmdnote
One way to compute variable importance is to _permute_ the features [@breiman2001random]. We can permute or shuffle the values of a feature, predict from the model, and then measure how much worse the model fits the data compared to before shuffling.
:::
If shuffling a column causes a large degradation in model performance, that feature is important; if shuffling a column's values doesn't make much difference to how the model performs, it must not be an important variable. This approach can be applied to any kind of model (it is _model agnostic_), and the results are straightforward to understand.
Using `r pkg(DALEX)`, we compute this kind of variable importance via the `model_parts()` function.
```{r explain-global-calcs}
set.seed(1803)
vip_lm <- model_parts(explainer_lm, loss_function = loss_root_mean_square)
set.seed(1804)
vip_rf <- model_parts(explainer_rf, loss_function = loss_root_mean_square)
```
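To make the permutation idea concrete, here is a minimal base R sketch that does not use `r pkg(DALEX)`. It assumes a stand-in linear model fit to the built-in `mtcars` data instead of the Ames models; the helpers `rmse()` and `permuted_rmse()` are hypothetical names introduced for illustration.

```r
set.seed(123)

# Stand-in model: a linear regression on the built-in mtcars data.
fit <- lm(mpg ~ wt + hp + qsec, data = mtcars)
features <- c("wt", "hp", "qsec")

rmse <- function(truth, estimate) sqrt(mean((truth - estimate)^2))

# Loss of the model on the unshuffled data.
baseline_rmse <- rmse(mtcars$mpg, predict(fit, mtcars))

# Shuffle one feature several times and average the RMSE after shuffling.
permuted_rmse <- function(feature, n_permutations = 25) {
  losses <- replicate(n_permutations, {
    shuffled <- mtcars
    shuffled[[feature]] <- sample(shuffled[[feature]])
    rmse(mtcars$mpg, predict(fit, shuffled))
  })
  mean(losses)
}

importance <- sapply(features, permuted_rmse)

# Larger increases over the baseline indicate more important features.
sort(importance - baseline_rmse, decreasing = TRUE)
```

Repeating the shuffle several times, as this sketch does, gives a distribution of losses per feature, which is what the box plots of permutation importance display.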
Again, we could use the default plot method from `r pkg(DALEX)` by calling `plot(vip_lm, vip_rf)`, but the underlying data is available for exploration, analysis, and plotting. Let's create a function for plotting:
```{r explain-global-func}
ggplot_imp <- function(...) {
  obj <- list(...)
  metric_name <- attr(obj[[1]], "loss_name")
  metric_lab <- paste(metric_name,
                      "after permutations\n(higher indicates more important)")
  # Combine all results; drop the `_baseline_` rows
  full_vip <- bind_rows(obj) %>%
    filter(variable != "_baseline_")
  # Loss of the full (unpermuted) model, drawn as a dashed reference line
  perm_vals <- full_vip %>%
    filter(variable == "_full_model_") %>%
    group_by(label) %>%
    summarise(dropout_loss = mean(dropout_loss))
  p <- full_vip %>%
    filter(variable != "_full_model_") %>%
    mutate(variable = fct_reorder(variable, dropout_loss)) %>%
    ggplot(aes(dropout_loss, variable))
  # Facet by model when more than one set of results is supplied
  if(length(obj) > 1) {
    p <- p +
      facet_wrap(vars(label)) +
      geom_vline(data = perm_vals, aes(xintercept = dropout_loss, color = label),
                 linewidth = 1.4, lty = 2, alpha = 0.7) +
      geom_boxplot(aes(color = label, fill = label), alpha = 0.2)
  } else {
    p <- p +
      geom_vline(data = perm_vals, aes(xintercept = dropout_loss),
                 linewidth = 1.4, lty = 2, alpha = 0.7) +
      geom_boxplot(fill = "#91CBD765", alpha = 0.4)
  }
  p +
    theme(legend.position = "none") +
    labs(x = metric_lab,
         y = NULL, fill = NULL, color = NULL)
}
```
Using `ggplot_imp(vip_lm, vip_rf)` produces Figure \@ref(fig:global-rf).
```{r global-rf}
#| echo = FALSE,
#| fig.width = 8,
#| fig.height = 4.5,
#| fig.cap = "Global explainer for the random forest and linear regression models",
#| fig.alt = "Global explainer for the random forest and linear regression models. For both models, gross living area and year built have the largest contributions, but the linear model uses the neighborhood predictor to a large degree."
ggplot_imp(vip_lm, vip_rf)
```
The dashed line in each panel of Figure \@ref(fig:global-rf) shows the RMSE for the full model, either the linear model or the random forest model. Features farther to the right are more important, because permuting them results in higher RMSE. There is quite a lot of interesting information to learn from this plot; for example, neighborhood is quite important in the linear model with interactions/splines but the second least important feature for the random forest model.
## Building Global Explanations from Local Explanations
So far in this chapter, we have focused on local model explanations for a single observation (via Shapley additive explanations) and global model explanations for a data set as a whole (via permuting features). It is also possible to build global model explanations by aggregating local model explanations, as with _partial dependence profiles_.
:::rmdnote
Partial dependence profiles show how the expected value of a model prediction, like the predicted price of a home in Ames, changes as a function of a feature, like the age or gross living area.
:::
One way to build such a profile is by aggregating or averaging profiles for individual observations. A profile showing how an individual observation's prediction changes as a function of a given feature is called an ICE (individual conditional expectation) profile or a CP (*ceteris paribus*) profile. We can compute such individual profiles (for 500 of the observations in our training set) and then aggregate them using the `r pkg(DALEX)` function `model_profile()`:
```{r explain-year-built-pdp-calc}
set.seed(1805)
pdp_age <- model_profile(explainer_rf, N = 500, variables = "Year_Built")
```
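The aggregation from individual profiles to a partial dependence profile can be sketched in base R. This is an illustration under stated assumptions, not the `model_profile()` implementation: it uses a stand-in linear model on the built-in `mtcars` data, and the names `hp_grid`, `ice`, and `pdp` are introduced only for this example.

```r
# Stand-in model with a nonlinear term, fit to the built-in mtcars data.
fit <- lm(mpg ~ poly(hp, 2) + wt, data = mtcars)

# Grid of values for the feature of interest.
hp_grid <- seq(min(mtcars$hp), max(mtcars$hp), length.out = 20)

# ICE / ceteris paribus profiles: set hp to each grid value for every
# observation while keeping that observation's other features as they are.
ice <- sapply(hp_grid, function(hp_value) {
  modified <- mtcars
  modified$hp <- hp_value
  predict(fit, newdata = modified)
})
# `ice` has one row per observation and one column per grid value.

# Partial dependence profile: the average of the individual profiles.
pdp <- colMeans(ice)

# Individual profiles in gray, their average on top.
matplot(hp_grid, t(ice), type = "l", lty = 1, col = "gray70",
        xlab = "hp", ylab = "predicted mpg")
lines(hp_grid, pdp, lwd = 3, col = "midnightblue")
```

Because this stand-in model has no interaction between `hp` and `wt`, the individual profiles are parallel; a model such as a random forest can produce individual profiles with different shapes, which the average alone would hide.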
Let's create another function for plotting the underlying data in this object:
```{r explain-pdp-func}
ggplot_pdp <- function(obj, x) {
  p <-
    # Aggregated (partial dependence) profiles
    as_tibble(obj$agr_profiles) %>%
    # Strip the model label prefix so that only the group name remains
    mutate(`_label_` = stringr::str_remove(`_label_`, "^[^_]*_")) %>%
    ggplot(aes(`_x_`, `_yhat_`)) +
    # Individual (ceteris paribus) profiles in the background
    geom_line(data = as_tibble(obj$cp_profiles),
              aes(x = {{ x }}, group = `_ids_`),
              linewidth = 0.5, alpha = 0.05, color = "gray50")
  # Color the aggregated profiles only when there is more than one group
  num_colors <- n_distinct(obj$agr_profiles$`_label_`)
  if (num_colors > 1) {
    p <- p + geom_line(aes(color = `_label_`), linewidth = 1.2, alpha = 0.8)
  } else {
    p <- p + geom_line(color = "midnightblue", linewidth = 1.2, alpha = 0.8)
  }
  p
}
```
Using this function generates Figure \@ref(fig:year-built), where we can see the nonlinear behavior of the random forest model.
```{r explain-year-built, eval = FALSE}
ggplot_pdp(pdp_age, Year_Built) +
  labs(x = "Year built",
       y = "Sale Price (log)",
       color = NULL)
```
```{r year-built, ref.label = "explain-year-built"}
#| echo = FALSE,
#| fig.cap = "Partial dependence profiles for the random forest model focusing on the year built predictor",
#| fig.alt = "Partial dependence profiles for the random forest model focusing on the year built predictor. The profiles are mostly relatively constant but then increase linearly around 1950."
```
Sale price for houses built in different years is mostly flat, with a modest rise after about 1960. Partial dependence profiles can be computed for any other feature in the model, and also for groups in the data, such as `Bldg_Type`. Let's use 1,000 observations for these profiles.
```{r explain-build-type-pdp, eval = FALSE}
set.seed(1806)
pdp_liv <- model_profile(explainer_rf, N = 1000,
                         variables = "Gr_Liv_Area",
                         groups = "Bldg_Type")
ggplot_pdp(pdp_liv, Gr_Liv_Area) +
  scale_x_log10() +
  scale_color_brewer(palette = "Dark2") +
  labs(x = "Gross living area",
       y = "Sale Price (log)",
       color = NULL)
```
This code produces Figure \@ref(fig:building-type-profiles), where we see that sale price increases the most between about 1,000 and 3,000 square feet of living area, and that different home types (like single family homes or different types of townhouses) mostly exhibit similar increasing trends in price with more living space.
```{r building-type-profiles, ref.label = "explain-build-type-pdp"}
#| echo = FALSE,
#| fig.cap = "Partial dependence profiles for the random forest model focusing on building types and gross living area",
#| fig.alt = "Partial dependence profiles for the random forest model focusing on building types and gross living area. The building type profiles are, for the most part, parallel over gross living area."
```
We have the option of using `plot(pdp_liv)` for default `r pkg(DALEX)` plots, but since we are making plots with the underlying data here, we can even facet by one of the features to visualize whether the predictions change differently and to highlight the imbalance in these subgroups (as shown in Figure \@ref(fig:building-type-facets)).
```{r explain-build-type-facet, eval = FALSE}
as_tibble(pdp_liv$agr_profiles) %>%
  mutate(Bldg_Type = stringr::str_remove(`_label_`, "random forest_")) %>%
  ggplot(aes(`_x_`, `_yhat_`, color = Bldg_Type)) +
  geom_line(data = as_tibble(pdp_liv$cp_profiles),
            aes(x = Gr_Liv_Area, group = `_ids_`),
            linewidth = 0.5, alpha = 0.1, color = "gray50") +
  geom_line(linewidth = 1.2, alpha = 0.8, show.legend = FALSE) +
  scale_x_log10() +
  facet_wrap(~Bldg_Type) +
  scale_color_brewer(palette = "Dark2") +
  labs(x = "Gross living area",
       y = "Sale Price (log)",
       color = NULL)
```
```{r building-type-facets, ref.label = "explain-build-type-facet"}
#| echo = FALSE,
#| fig.cap = "Partial dependence profiles for the random forest model focusing on building types and gross living area using facets",
#| fig.alt = "Partial dependence profiles for the random forest model focusing on building types and gross living area using facets. The building type profiles are, for the most part, parallel over gross living area."
```
There is no one correct approach for building model explanations, and the options outlined in this chapter are not exhaustive. We have highlighted good options for explanations at both the individual and global level, as well as how to bridge from one to the other, and we point you to @Biecek2021 and @Molnar2021 for further reading.
## Back to Beans!
In Chapter \@ref(dimensionality), we discussed how to use dimensionality reduction as a feature engineering or preprocessing step when modeling high-dimensional data. For our example data set of dry bean morphology measures predicting bean type, we saw great results from partial least squares (PLS) dimensionality reduction combined with a regularized discriminant analysis model. Which of those morphological characteristics were _most_ important in the bean type predictions? We can use the same approach outlined throughout this chapter to create a model-agnostic explainer and compute, say, global model explanations via `model_parts()`:
```{r explain-bean-data, echo=FALSE}
library(beans)
set.seed(1601)
bean_split <- initial_split(beans, strata = class, prop = 3/4)
bean_train <- training(bean_split)
bean_test <- testing(bean_split)
load("RData/rda_fit.RData")
```
```{r explain-bea-vip}
set.seed(1807)
vip_beans <-
  explain_tidymodels(
    rda_wflow_fit,
    data = bean_train %>% select(-class),
    y = bean_train$class,
    label = "RDA",
    verbose = FALSE
  ) %>%
  model_parts()
```
Using our previously defined importance plotting function, `ggplot_imp(vip_beans)` produces Figure \@ref(fig:bean-explainer).
```{r bean-explainer}
#| echo = FALSE,
#| fig.width = 8,
#| fig.cap = "Global explainer for the regularized discriminant analysis model on the beans data",
#| fig.alt = "Global explainer for the regularized discriminant analysis model on the beans data. Almost all predictors have a significant contribution, with shape factors one and four contributing the most."
ggplot_imp(vip_beans)
```
:::rmdwarning
The measures of global feature importance that we see in Figure \@ref(fig:bean-explainer) incorporate the effects of all of the PLS components, but in terms of the original variables.
:::
Figure \@ref(fig:bean-explainer) shows us that shape factors are among the most important features for predicting bean type, especially shape factor 4, a measure of solidity that takes into account the area $A$, major axis $L$, and minor axis $l$:
$$\text{SF4} = \frac{A}{\pi(L/2)(l/2)}$$
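As a quick numeric check of this formula, the following sketch uses a hypothetical helper, `shape_factor_4()`, written only for this illustration (it is not part of the `r pkg(beans)` package). Since a perfect ellipse with axes $L$ and $l$ has area $\pi(L/2)(l/2)$, it has an SF4 of exactly 1, and a shape that fills less of that ellipse has a smaller value.

```r
# Hypothetical helper implementing the SF4 formula above.
shape_factor_4 <- function(area, major_axis, minor_axis) {
  area / (pi * (major_axis / 2) * (minor_axis / 2))
}

# A perfect ellipse with axes 10 and 6 has area pi * 5 * 3.
ellipse_area <- pi * 5 * 3
sf4_ellipse <- shape_factor_4(ellipse_area, major_axis = 10, minor_axis = 6)

# A shape with the same axes that fills only 95% of that ellipse.
sf4_dented <- shape_factor_4(0.95 * ellipse_area, major_axis = 10, minor_axis = 6)

c(ellipse = sf4_ellipse, dented = sf4_dented)
```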
We can see from Figure \@ref(fig:bean-explainer) that shape factor 1 (the ratio of the major axis to the area), the minor axis length, and roundness are the next most important bean characteristics for predicting bean variety.
## Chapter Summary {#explain-summary}
For some types of models, the answer to why a model made a certain prediction is straightforward, but for other types of models, we must use separate explainer algorithms to understand what features are relatively most important for predictions. You can generate two main kinds of model explanations from a trained model. Global explanations provide information aggregated over an entire data set, while local explanations provide understanding about a model's predictions for a single observation.
Packages such as `r pkg(DALEX)` and its supporting package `r pkg(DALEXtra)`, `r pkg(vip)`, and `r pkg(lime)` can be integrated into a tidymodels analysis to provide these model explainers. Model explanations are just one piece of understanding whether your model is appropriate and effective, along with estimates of model performance; Chapter \@ref(trust) further explores the quality and trustworthiness of predictions.