You signed in with another tab or window. Reload to refresh your session.You signed out in another tab or window. Reload to refresh your session.You switched accounts on another tab or window. Reload to refresh your session.Dismiss alert
The `explain()` function creates an explainer object for the survival models. The `data` argument specifies the dataset used for explanation, and the `model` argument specifies the model to be explained. The `target` argument indicates the type of prediction to be explained (e.g., "survival", "risk", "cumulative hazard").
79
-
80
-
```{r td explainer}
81
-
exp_deephit <- Survinng::explain(ext_deephit[[1]], data = test)
82
-
exp_coxtime <- Survinng::explain(ext_coxtime[[1]], data = test)
83
-
exp_deepsurv <- Survinng::explain(ext_deepsurv[[1]], data = test)
84
-
```
85
-
86
77
## Performance
87
78
88
79
The performance of the models is evaluated using the C-Index and Integrated Brier Score (IBS). The C-Index measures the concordance between predicted and observed survival times, while the IBS quantifies the accuracy of survival predictions.
The `explain()` function creates an explainer object for the survival models. The `data` argument specifies the dataset used for explanation, and the `model` argument specifies the model to be explained. The `target` argument indicates the type of prediction to be explained (e.g., "survival", "risk", "cumulative hazard").
95
+
96
+
```{r td explainer}
97
+
exp_deephit <- Survinng::explain(ext_deephit[[1]], data = test)
98
+
exp_coxtime <- Survinng::explain(ext_coxtime[[1]], data = test)
99
+
exp_deepsurv <- Survinng::explain(ext_deepsurv[[1]], data = test)
The survival predictions for the test dataset are computed using the `predict()` function. The `type` argument specifies the type of prediction to be made (e.g., "survival", "risk", "cumulative hazard"). The survival predictions are then plotted for a set of instances of interest.
The following sections demonstrate the application of various gradient-based explanation methods to the survival models. The methods include Grad(t), SmoothGrad(t), IntGrad(t), and GradSHAP(t), corresponding to the plots shown in the main body of the paper.
162
163
163
-
###Grad(t) (Sensitivity)
164
+
## Grad(t) (Sensitivity)
164
165
165
166
Here we compute the gradient of the survival predictions with respect to the input features. The `surv_grad()` function computes the gradients for the specified instances.
SmoothGrad(t) is a method that adds noise to the input features and computes the average gradient over multiple noisy samples. This approach helps to reduce the noise in the gradient estimates and provides a clearer picture of the feature importance.
180
181
@@ -198,13 +199,13 @@ smoothgrad_plot
198
199
199
200
The relevance curves derived from output-sensitive methods effectively reveal the time-dependent effect of $x_1$ on the survival predictions, by indicating a positive effect at earlier times and a negative effect later on. This time-dependent effect is accurately captured by CoxTime and DeepHit, but not by DeepSurv, which is inherently constrained by the PH assumption and thus unable to model time-dependence.
200
201
201
-
###IntGrad(t)
202
+
## IntGrad(t)
202
203
203
204
IntGrad(t) is a method that computes the integral of the gradients along a straight line path from a reference point to the input instance. This method provides a more comprehensive view of the feature importance by considering the cumulative effect of the features over time.
204
205
205
206
In addition to time-dependence in feature effects, difference-to-reference methods (i.e., IntGrad(t) and GradSHAP(t)) provide insights into the relative scale, direction, and magnitude of feature effects by comparing predictions to a meaningful reference.
206
207
207
-
####Zero baseline (should be proportional to Gradient x I(t))
208
+
### Zero baseline
208
209
209
210
The zero baseline is a reference point where all features are set to zero.
210
211
@@ -259,7 +260,7 @@ intgrad0_plot_force
259
260
```
260
261
261
262
262
-
### GradShap(t)
263
+
##GradSHAP(t)
263
264
264
265
GradSHAP(t) is a method that computes the SHAP values for survival predictions. It provides a measure of the contribution of each feature to the survival predictions, taking into account the time-dependent effects.
The `explain()` function creates an explainer object for the survival models. The `data` argument specifies the dataset used for explanation, and the `model` argument specifies the model to be explained. The `target` argument indicates the type of prediction to be explained (e.g., "survival", "risk", "cumulative hazard").
77
-
78
-
79
-
```{r tid explainer}
80
-
exp_deephit <- Survinng::explain(ext_deephit[[1]], data = test)
81
-
exp_coxtime <- Survinng::explain(ext_coxtime[[1]], data = test)
82
-
exp_deepsurv <- Survinng::explain(ext_deepsurv[[1]], data = test)
83
-
```
84
-
85
71
## Performance
86
72
87
73
The performance of the models is evaluated using the C-Index and Integrated Brier Score (IBS). The C-Index measures the concordance between predicted and observed survival times, while the IBS quantifies the accuracy of survival predictions.
The `explain()` function creates an explainer object for the survival models. The `data` argument specifies the dataset used for explanation, and the `model` argument specifies the model to be explained. The `target` argument indicates the type of prediction to be explained (e.g., "survival", "risk", "cumulative hazard").
90
+
91
+
92
+
```{r tid explainer}
93
+
exp_deephit <- Survinng::explain(ext_deephit[[1]], data = test)
94
+
exp_coxtime <- Survinng::explain(ext_coxtime[[1]], data = test)
95
+
exp_deepsurv <- Survinng::explain(ext_deepsurv[[1]], data = test)
96
+
```
97
+
98
+
## Survival Prediction
101
99
102
100
The survival predictions for the test dataset are computed using the `predict()` function. The `type` argument specifies the type of prediction to be made (e.g., "survival", "risk", "cumulative hazard"). The survival predictions are then plotted for a set of instances of interest.
103
101
@@ -123,12 +121,12 @@ surv_plot
123
121
```
124
122
125
123
126
-
##Explainable AI
124
+
# Explainable AI
127
125
128
126
The following sections demonstrate the application of various gradient-based explanation methods to the survival models. The methods include Grad(t), SmoothGrad(t), G x I(t), SmoothGrad x I(t), IntGrad(t), and GradSHAP(t). Each method provides insights into the contributions of the covariates to the survival predictions.
129
127
130
128
131
-
###Grad(t) (Sensitivity)
129
+
## Grad(t) (Sensitivity)
132
130
133
131
Here we compute the gradient of the survival predictions with respect to the input features. The `surv_grad()` function computes the gradients for the specified instances.
SmoothGrad(t) is a method that adds noise to the input features and computes the average gradient over multiple noisy samples. This approach helps to reduce the noise in the gradient estimates and provides a clearer picture of the feature importance.
148
146
@@ -162,7 +160,7 @@ smoothgrad_plot
162
160
```
163
161
164
162
165
-
###Grad x I(t)
163
+
## Grad x I(t)
166
164
167
165
Grad x I(t) is a method that computes the gradient of the survival predictions with respect to the input features and multiplies it by the survival predictions themselves. This approach provides insights into the true local effects of the covariates on the survival prediction.
SmoothGrad x I(t) is a method that adds noise to the input features and computes the average gradient over multiple noisy samples, multiplied by the survival predictions. This approach helps to reduce the noise in the gradient estimates and provides a clearer picture of the feature importance.
IntGrad(t) is a method that computes the integral of the gradients along a straight line path from a reference point to the input instance. This method provides a more comprehensive view of the feature importance by considering the cumulative effect of the features over time.
218
216
219
-
## Zero baseline (should be proportional to Gradient x I(t))
217
+
###Zero baseline
220
218
221
219
The zero baseline is a reference point where all features are set to zero.
222
220
@@ -271,7 +269,7 @@ intgrad0_plot_force
271
269
```
272
270
273
271
274
-
### GradShap
272
+
##GradSHAP(t)
275
273
276
274
GradSHAP(t) is a method that computes the SHAP values for survival predictions. It provides a measure of the contribution of each feature to the survival predictions, taking into account the time-dependent effects.
0 commit comments