Skip to content

Commit 5164dd3

Browse files
committed
fix typos and package version
1 parent d45ae7c commit 5164dd3

File tree

5 files changed

+49
-45
lines changed

5 files changed

+49
-45
lines changed

DESCRIPTION

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
Package: Survinng
22
Title: Gradient-Based Feature Attribution for Survival Neural Networks
3-
Version: 0.0.1
3+
Version: 0.1.0
44
Authors@R: c(
55
person("Niklas", "Koenen", , "[email protected]", role = c("aut", "cre"),
66
comment = c(ORCID = "0000-0002-4623-8271")),

NEWS.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
# Survinng 0.0.1
1+
# Survinng 0.1.0
22

33
* Initial release as part of the ICML'25 paper *Gradient-based
44
Explanations for Deep Learning Survival Models*

README.md

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -116,6 +116,11 @@ shap <- surv_gradSHAP(explainer)
116116
plot(shap)
117117
```
118118

119+
## 🖥 Other Examples and Articles
120+
121+
- Simulation: Time-independent effects (`survivalmodels`) [→ article](https://bips-hb.github.io/Survinng/articles/Sim_time_independent.html)
122+
- Simulation: Time-dependent effects (`survivalmodels`) [→ article](https://bips-hb.github.io/Survinng/articles/Sim_time_dependent.html)
123+
119124
## 📚 Citation
120125

121126
If you use this package in your research, please cite it as follows:

vignettes/articles/Sim_time_dependent.Rmd

Lines changed: 19 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@ library(viridis)
3131
library(here)
3232
```
3333

34+
# Preprocessing
3435

3536
## Generate the Data
3637

@@ -73,21 +74,11 @@ ext_deephit <- readRDS(here("vignettes/articles/Sim_time_dependent/extracted_mod
7374
```
7475

7576

76-
## Create Explainer
77-
78-
The `explain()` function creates an explainer object for the survival models. The `data` argument specifies the dataset used for explanation, and the `model` argument specifies the model to be explained. The `target` argument indicates the type of prediction to be explained (e.g., "survival", "risk", "cumulative hazard").
79-
80-
```{r td explainer}
81-
exp_deephit <- Survinng::explain(ext_deephit[[1]], data = test)
82-
exp_coxtime <- Survinng::explain(ext_coxtime[[1]], data = test)
83-
exp_deepsurv <- Survinng::explain(ext_deepsurv[[1]], data = test)
84-
```
85-
8677
## Performance
8778

8879
The performance of the models is evaluated using the C-Index and Integrated Brier Score (IBS). The C-Index measures the concordance between predicted and observed survival times, while the IBS quantifies the accuracy of survival predictions.
8980

90-
|Model | C-Index | IBS
81+
|Model | C-Index | IBS |
9182
|:---------:|:--------:|:--------:|
9283
|CoxTime | 0.845570 | 0.058430 |
9384
|DeepSurv | 0.859624 | 0.060411 |
@@ -98,6 +89,16 @@ knitr::include_graphics(here('vignettes/articles/Sim_time_dependent/sim_td_brier
9889
```
9990

10091

92+
## Create Explainer
93+
94+
The `explain()` function creates an explainer object for the survival models. The `data` argument specifies the dataset used for explanation, and the `model` argument specifies the model to be explained. The `target` argument indicates the type of prediction to be explained (e.g., "survival", "risk", "cumulative hazard").
95+
96+
```{r td explainer}
97+
exp_deephit <- Survinng::explain(ext_deephit[[1]], data = test)
98+
exp_coxtime <- Survinng::explain(ext_coxtime[[1]], data = test)
99+
exp_deepsurv <- Survinng::explain(ext_deepsurv[[1]], data = test)
100+
```
101+
101102

102103
## Kaplan-Meier Survival Curves
103104

@@ -131,7 +132,7 @@ km_plot$plot <- km_plot$plot +
131132
km_plot
132133
```
133134

134-
### Survival Prediction
135+
## Survival Prediction
135136

136137
The survival predictions for the test dataset are computed using the `predict()` function. The `type` argument specifies the type of prediction to be made (e.g., "survival", "risk", "cumulative hazard"). The survival predictions are then plotted for a set of instances of interest.
137138

@@ -156,11 +157,11 @@ surv_plot <- cowplot::plot_grid(
156157
surv_plot
157158
```
158159

159-
## Explainable AI
160+
# Explainable AI
160161

161162
The following sections demonstrate the application of various gradient-based explanation methods to the survival models. The methods include Grad(t), SmoothGrad(t), IntGrad(t), and GradSHAP(t), corresponding to the plots shown in the main body of the paper.
162163

163-
### Grad(t) (Sensitivity)
164+
## Grad(t) (Sensitivity)
164165

165166
Here we compute the gradient of the survival predictions with respect to the input features. The `surv_grad()` function computes the gradients for the specified instances.
166167

@@ -174,7 +175,7 @@ grad_plot <- cowplot::plot_grid(
174175
grad_plot
175176
```
176177

177-
### SmoothGrad(t) (Sensitivity)
178+
## SmoothGrad(t) (Sensitivity)
178179

179180
SmoothGrad(t) is a method that adds noise to the input features and computes the average gradient over multiple noisy samples. This approach helps to reduce the noise in the gradient estimates and provides a clearer picture of the feature importance.
180181

@@ -198,13 +199,13 @@ smoothgrad_plot
198199

199200
The relevance curves derived from output-sensitive methods effectively reveal the time-dependent effect of $x_1$ on the survival predictions, by indicating a positive effect at earlier times and a negative effect later on. This time-dependent effect is accurately captured by CoxTime and DeepHit, but not by DeepSurv, which is inherently constrained by the PH assumption and thus unable to model time-dependence.
200201

201-
### IntGrad(t)
202+
## IntGrad(t)
202203

203204
IntGrad(t) is a method that computes the integral of the gradients along a straight line path from a reference point to the input instance. This method provides a more comprehensive view of the feature importance by considering the cumulative effect of the features over time.
204205

205206
In addition to time-dependence in feature effects, difference-to-reference methods (i.e., IntGrad(t) and GradSHAP(t)) provide insights into the relative scale, direction, and magnitude of feature effects by comparing predictions to a meaningful reference.
206207

207-
#### Zero baseline (should be proportional to Gradient x I(t))
208+
### Zero baseline
208209

209210
The zero baseline is a reference point where all features are set to zero.
210211

@@ -259,7 +260,7 @@ intgrad0_plot_force
259260
```
260261

261262

262-
### GradShap(t)
263+
## GradSHAP(t)
263264

264265
GradSHAP(t) is a method that computes the SHAP values for survival predictions. It provides a measure of the contribution of each feature to the survival predictions, taking into account the time-dependent effects.
265266

vignettes/articles/Sim_time_independent.Rmd

Lines changed: 23 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -30,8 +30,7 @@ library(viridis)
3030
library(here)
3131
```
3232

33-
34-
### Time-independent effects
33+
# Preprocessing
3534

3635
## Generate the data
3736

@@ -69,24 +68,11 @@ ext_deepsurv <- readRDS(here("vignettes/articles/Sim_time_independent/ext_deepsu
6968
ext_deephit <- readRDS(here("vignettes/articles/Sim_time_independent/ext_deephit.rds"))
7069
```
7170

72-
73-
74-
## Create Explainer
75-
76-
The `explain()` function creates an explainer object for the survival models. The `data` argument specifies the dataset used for explanation, and the `model` argument specifies the model to be explained. The `target` argument indicates the type of prediction to be explained (e.g., "survival", "risk", "cumulative hazard").
77-
78-
79-
```{r tid explainer}
80-
exp_deephit <- Survinng::explain(ext_deephit[[1]], data = test)
81-
exp_coxtime <- Survinng::explain(ext_coxtime[[1]], data = test)
82-
exp_deepsurv <- Survinng::explain(ext_deepsurv[[1]], data = test)
83-
```
84-
8571
## Performance
8672

8773
The performance of the models is evaluated using the C-Index and Integrated Brier Score (IBS). The C-Index measures the concordance between predicted and observed survival times, while the IBS quantifies the accuracy of survival predictions.
8874

89-
|Model | C-Index | IBS
75+
|Model | C-Index | IBS |
9076
|:---------:|:--------:|:--------:|
9177
|CoxTime | 0.809372 | 0.099053 |
9278
|DeepSurv | 0.809121 | 0.099031 |
@@ -97,7 +83,19 @@ knitr::include_graphics(here('vignettes/articles/Sim_time_independent/sim_tid_br
9783
```
9884

9985

100-
### Survival Prediction
86+
87+
## Create Explainer
88+
89+
The `explain()` function creates an explainer object for the survival models. The `data` argument specifies the dataset used for explanation, and the `model` argument specifies the model to be explained. The `target` argument indicates the type of prediction to be explained (e.g., "survival", "risk", "cumulative hazard").
90+
91+
92+
```{r tid explainer}
93+
exp_deephit <- Survinng::explain(ext_deephit[[1]], data = test)
94+
exp_coxtime <- Survinng::explain(ext_coxtime[[1]], data = test)
95+
exp_deepsurv <- Survinng::explain(ext_deepsurv[[1]], data = test)
96+
```
97+
98+
## Survival Prediction
10199

102100
The survival predictions for the test dataset are computed using the `predict()` function. The `type` argument specifies the type of prediction to be made (e.g., "survival", "risk", "cumulative hazard"). The survival predictions are then plotted for a set of instances of interest.
103101

@@ -123,12 +121,12 @@ surv_plot
123121
```
124122

125123

126-
## Explainable AI
124+
# Explainable AI
127125

128126
The following sections demonstrate the application of various gradient-based explanation methods to the survival models. The methods include Grad(t), SmoothGrad(t), G x I(t), SmoothGrad x I(t), IntGrad(t), and GradSHAP(t). Each method provides insights into the contributions of the covariates to the survival predictions.
129127

130128

131-
### Grad(t) (Sensitivity)
129+
## Grad(t) (Sensitivity)
132130

133131
Here we compute the gradient of the survival predictions with respect to the input features. The `surv_grad()` function computes the gradients for the specified instances.
134132

@@ -142,7 +140,7 @@ grad_plot <- cowplot::plot_grid(
142140
grad_plot
143141
```
144142

145-
### SmoothGrad(t) (Sensitivity)
143+
## SmoothGrad(t) (Sensitivity)
146144

147145
SmoothGrad(t) is a method that adds noise to the input features and computes the average gradient over multiple noisy samples. This approach helps to reduce the noise in the gradient estimates and provides a clearer picture of the feature importance.
148146

@@ -162,7 +160,7 @@ smoothgrad_plot
162160
```
163161

164162

165-
### Grad x I(t)
163+
## Grad x I(t)
166164

167165
Grad x I(t) is a method that computes the gradient of the survival predictions with respect to the input features and multiplies it by the survival predictions themselves. This approach provides insights into the true local effects of the covariates on the survival prediction.
168166

@@ -190,7 +188,7 @@ grad_gradin_plot <- cowplot::plot_grid(
190188
grad_gradin_plot
191189
```
192190

193-
### SmoothGrad x I(t)
191+
## SmoothGrad x I(t)
194192

195193
SmoothGrad x I(t) is a method that adds noise to the input features and computes the average gradient over multiple noisy samples, multiplied by the survival predictions. This approach helps to reduce the noise in the gradient estimates and provides a clearer picture of the feature importance.
196194

@@ -212,11 +210,11 @@ smoothgradin_plot <- cowplot::plot_grid(
212210
smoothgradin_plot
213211
```
214212

215-
### IntGrad(t)
213+
## IntGrad(t)
216214

217215
IntGrad(t) is a method that computes the integral of the gradients along a straight line path from a reference point to the input instance. This method provides a more comprehensive view of the feature importance by considering the cumulative effect of the features over time.
218216

219-
## Zero baseline (should be proportional to Gradient x I(t))
217+
### Zero baseline
220218

221219
The zero baseline is a reference point where all features are set to zero.
222220

@@ -271,7 +269,7 @@ intgrad0_plot_force
271269
```
272270

273271

274-
### GradShap
272+
## GradSHAP(t)
275273

276274
GradSHAP(t) is a method that computes the SHAP values for survival predictions. It provides a measure of the contribution of each feature to the survival predictions, taking into account the time-dependent effects.
277275

0 commit comments

Comments
 (0)