Skip to content

Commit fa13a88

Browse files
committed
add real data utils
1 parent 95786c0 commit fa13a88

File tree

1 file changed

+367
-0
lines changed
  • vignettes/articles/real_data

1 file changed

+367
-0
lines changed
Lines changed: 367 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,367 @@
1+
2+
3+
################################################################################
4+
# R torch Helper functions for the multi-modal model
5+
################################################################################
6+
7+
# Define R torch version of the multi-modal model ------------------------------
8+
9+
# Define multi modal model (in torch we need to re-build the model because
10+
# only the state dict is saved)
11+
MultiModalModel <- nn_module(
12+
"MultiModalModel",
13+
initialize = function(net_images, tabular_features, n_out = 1, n_img_out = 64,
14+
include_time = FALSE, out_bias = FALSE) {
15+
self$net_images <- net_images
16+
17+
input_dim <- n_img_out + length(tabular_features) + ifelse(include_time, 1, 0)
18+
self$fc1 <- nn_linear(input_dim, 256)
19+
self$fc2 <- nn_linear(256, 128)
20+
self$relu <- nn_relu()
21+
self$drop_0_3 <- nn_dropout(0.3)
22+
self$drop_0_4 <- nn_dropout(0.3)
23+
self$out <- nn_linear(128, n_out, bias = out_bias)
24+
},
25+
26+
forward = function(input, time = NULL) {
27+
img <- input[[1]]
28+
tab <- input[[2]]
29+
img <- self$net_images(img)
30+
img <- self$drop_0_4(img)
31+
32+
if (!is.null(time)) {
33+
x <- torch_cat(list(img, tab, time), dim = 2)
34+
} else {
35+
x <- torch_cat(list(img, tab), dim = 2)
36+
}
37+
38+
x <- self$drop_0_3(x)
39+
x <- self$relu(self$fc1(x))
40+
x <- self$drop_0_3(x)
41+
x <- self$relu(self$fc2(x))
42+
x <- self$out(x)
43+
44+
x
45+
}
46+
)
47+
48+
################################################################################
49+
# SurvSHAP utility functions
50+
################################################################################
51+
52+
# Preprocess function for the torch-based multi-modal model
53+
predict_survival_function <- function(model, newdata, times) {
54+
input <- torch_tensor(as.matrix(newdata))
55+
input_img <- input[, seq(1, dim(input)[2] - 4)]$view(c(-1, 3, 32, 32))
56+
input_tab <- input[, seq(dim(input)[2] - 3, dim(input)[2])]
57+
input <- list(input_img, input_tab)
58+
input <- model$preprocess_fun(input)
59+
out <- model(input, target = "survival")[[1]]
60+
out <- out$squeeze(dim = seq_len(out$dim())[-c(1, out$dim())])
61+
as.array(out)
62+
}
63+
64+
# Predict function for the torch-based multi-modal model
65+
predict_function <- function(model, newdata) {
66+
input <- torch_tensor(as.matrix(newdata))
67+
input_img <- input[, seq(1, dim(input)[2] - 4)]$view(c(-1, 3, 32, 32))
68+
input_tab <- input[, seq(dim(input)[2] - 3, dim(input)[2])]
69+
input <- list(input_img, input_tab)
70+
input <- model$preprocess_fun(input)
71+
out <- model(input, target = "survival")[[1]]
72+
out <- out$squeeze(dim = seq_len(out$dim())[-c(1, out$dim())])
73+
as.array(out)
74+
}
75+
76+
################################################################################
77+
# Plotting functions
78+
################################################################################
79+
80+
# Create force plot
81+
plot_force_real <- function(x, num_samples = 10, zero_feature = "Sex (male)", label = "") {
82+
# Convert input to a data frame and calculate derived columns
83+
dat <- as.data.frame(x)
84+
dat$id <- as.factor(dat$id)
85+
dat$feature <- factor(
86+
dat$feature,
87+
levels = c(
88+
"IDH Mutation (mutant)",
89+
"Image (sum)",
90+
"1p19q Codeletion (no)",
91+
"Age (43)",
92+
"Sex (male)"
93+
)
94+
)
95+
96+
# Process data to compute sum of attributions
97+
dat <- dat %>%
98+
group_by(id, time) %>%
99+
mutate(sum = sum(value))
100+
101+
# Sample time points for visualization
102+
t_interest <- sort(unique(dat$time))
103+
target_points <- seq(min(t_interest), max(t_interest), length.out = num_samples)
104+
selected_points <- sapply(target_points, function(x)
105+
t_interest[which.min(abs(t_interest - x))])
106+
dat_small <- dat[dat$time %in% selected_points, ]
107+
108+
109+
# Create position variable for plotting attribution values
110+
dat_small <- dat_small %>%
111+
group_by(id, time) %>%
112+
mutate(
113+
pos = case_when(
114+
feature == "Sex (male)" ~ NA,
115+
feature == "Age (43)" ~ value + value[feature == "Sex (male)"],
116+
feature == "1p19q Codeletion (no)" ~ value,
117+
feature == "IDH Mutation (mutant)" ~ value + value[feature == "Age (43)"] + value[feature == "Sex (male)"],
118+
#feature == "IDH Mutation (mutant)" &
119+
# sign(value[feature == "Image (sum)"]) < 0 ~ value + value[feature == "Age (43)"] + value[feature == "Sex (male)"],
120+
#feature == "Image (sum)" &
121+
# sign(value) < 0 ~ value,
122+
feature == "Image (sum)" ~ value + value[feature == "1p19q Codeletion (no)"],
123+
TRUE ~ NA_real_
124+
)
125+
) %>%
126+
ungroup()
127+
128+
# Additional position variable for the arrows
129+
dat_small$pos_a <- ifelse(dat_small$pos > 0, dat_small$pos + 0.03, dat_small$pos)
130+
dat_small$pos_a <- ifelse(dat_small$pos_a < 0, dat_small$pos_a - 0.03, dat_small$pos_a)
131+
132+
# Adjust label position manually
133+
# dat_small[(round(dat_small$time, 2) == 15.21) &
134+
# (dat_small$feature == "Image (sum)"), "pos"] <- -0.011
135+
dat_small[(round(dat_small$value, 2) == 0.01), "pos"] <- 0.018
136+
137+
# Plot
138+
p <- ggplot() +
139+
geom_bar(
140+
data = dat_small,
141+
mapping = aes(
142+
x = .data$time,
143+
y = .data$value,
144+
fill = .data$feature,
145+
color = .data$feature
146+
),
147+
stat = "identity",
148+
position = "stack"
149+
) +
150+
scale_color_viridis_d(name = "Feature", guide = guide_legend(override.aes = list(linewidth = 7))) +
151+
scale_fill_viridis_d(alpha = 0.4, name = "Feature") +
152+
geom_segment(
153+
data = dat_small[(dat_small$feature != zero_feature) &
154+
(round(dat_small$value, 2) != 0), ],
155+
mapping = aes(
156+
x = .data$time,
157+
xend = .data$time,
158+
y = .data$pos_a,
159+
yend = .data$pos_a + (.data$value) * 0.01,
160+
color = .data$feature
161+
),
162+
arrow = arrow(type = "closed", length = unit(0.1, "inches")),
163+
linewidth = 6
164+
) +
165+
geom_label(
166+
data = dat_small[round(dat_small$value, 2) != 0, ],
167+
mapping = aes(
168+
x = .data$time,
169+
y = .data$pos,
170+
label = round(.data$value, 2)
171+
),
172+
color = "black",
173+
size = 3,
174+
vjust = 0.5,
175+
hjust = 0.5,
176+
na.rm = TRUE
177+
) +
178+
geom_line(
179+
data = dat,
180+
mapping = aes(x = .data$time, y = .data$sum),
181+
color = "black",
182+
linewidth = 1.5
183+
) +
184+
facet_wrap(vars(.data$id),
185+
scales = "free_x",
186+
labeller = as_labeller(function(a)
187+
paste0("Instance ID: ", a))) +
188+
theme_minimal(base_size = 13) +
189+
theme(legend.position = "bottom") +
190+
ylim(-0.15, 0.16) +
191+
scale_x_continuous(expand = c(0,0)) +
192+
labs(
193+
x = "Time",
194+
y = paste0("Force Plot: ", label),
195+
color = "Feature",
196+
fill = "Feature"
197+
)
198+
199+
return(p)
200+
}
201+
202+
# Plot function
203+
plot_bars_real <- function(x, num_samples = 40, zero_feature = "Sex (male)", label = "") {
204+
# Convert input to a data frame and calculate derived columns
205+
dat <- as.data.frame(x)
206+
dat$id <- as.factor(dat$id)
207+
dat$feature <- factor(
208+
dat$feature,
209+
levels = c(
210+
"IDH Mutation (mutant)",
211+
"Image (sum)",
212+
"1p19q Codeletion (no)",
213+
"Age (43)",
214+
"Sex (male)"
215+
)
216+
)
217+
218+
# Process data to compute sum of attributions
219+
dat <- dat %>%
220+
group_by(id, time) %>%
221+
mutate(sum = sum(value))
222+
223+
# Sample time points for visualization
224+
t_interest <- sort(unique(dat$time))
225+
target_points <- seq(min(t_interest), max(t_interest), length.out = num_samples)
226+
selected_points <- sapply(target_points, function(x)
227+
t_interest[which.min(abs(t_interest - x))])
228+
dat_small <- dat[dat$time %in% selected_points, ]
229+
230+
# Plot
231+
p <- ggplot() +
232+
geom_line(
233+
data = dat,
234+
mapping = aes(x = .data$time, y = .data$sum),
235+
color = "black",
236+
linewidth = 1
237+
) +
238+
geom_bar(
239+
data = dat_small,
240+
mapping = aes(
241+
x = .data$time,
242+
y = .data$value,
243+
fill = .data$feature,
244+
color = .data$feature
245+
),
246+
stat = "identity",
247+
position = "stack"
248+
) +
249+
scale_color_viridis_d(name = "Feature") +
250+
scale_fill_viridis_d(alpha = 0.4, name = "Feature") +
251+
facet_wrap(vars(.data$id),
252+
scales = "free_x",
253+
labeller = as_labeller(function(a)
254+
paste0("Instance ID: ", a))) +
255+
theme_minimal(base_size = 15) +
256+
theme(legend.position = "bottom") +
257+
labs(
258+
x = "Time",
259+
y = paste0("Contribution: ", label),
260+
color = "Feature",
261+
fill = "Feature"
262+
)
263+
264+
return(p)
265+
}
266+
267+
plot_result <- function(result, img, path = NULL, name = "res", num_images = 7, as_force = TRUE) {
268+
df_img <- result[[1]]
269+
df_tab <- result[[2]]
270+
271+
# Get prediction data.table
272+
col_idx <- colnames(df_tab)[colnames(df_tab) %in% c("time", "pred", "pred_diff")]
273+
df_pred <- unique(df_tab[, ..col_idx])
274+
275+
# Summarize image features
276+
df_img_sum <- df_img %>%
277+
group_by(id, time, method) %>%
278+
summarise(value = sum(value))
279+
280+
# Add image summary to table
281+
df_tab <- rbind(df_tab,
282+
cbind(df_img_sum, feature = "Image (sum)"), fill = TRUE)
283+
df_tab$feature <- factor(df_tab$feature, levels = unique(df_tab$feature),
284+
labels = c("Age (43)", "Sex (male)", "IDH Mutation (mutant)", "1p19q Codeletion (no)", "Image (sum)"))
285+
286+
# Plot force plot ------------------------------------------------------------
287+
if (as_force) {
288+
p_force <- plot_force_real(df_tab, label = "GradSHAP(t)", num_samples = 20)
289+
} else {
290+
p_force <- NULL
291+
}
292+
293+
# Plot bar plot --------------------------------------------------------------
294+
p_bar <- plot_bars_real(df_tab, label = "GradSHAP(t)")
295+
296+
# Plot as lines
297+
if ("pred_diff" %in% colnames(df_pred)) {
298+
df_pred$pred <- df_pred$pred_diff
299+
}
300+
p_line <- ggplot(df_tab, aes(x = time, y = value, color = feature)) +
301+
geom_line(linewidth = 1) +
302+
geom_line(data = df_pred, aes(y = pred), color = "black", linewidth = 1) +
303+
geom_hline(yintercept = 0, linetype = "dashed", linewidth = 1) +
304+
facet_wrap(vars(id), scales = "free_x") +
305+
scale_color_viridis_d() +
306+
theme_minimal(base_size = 13) +
307+
theme(legend.position = "bottom",
308+
legend.box.margin = margin(),
309+
plot.margin = margin()) +
310+
labs(x = "Time (months)", y = NULL, color = "Feature")
311+
312+
# Aggregate over channels
313+
df_img <- df_img %>%
314+
group_by(id, time, height, width) %>%
315+
summarise(value = sum(value), pred = unique(pred))
316+
317+
# Normalize over time
318+
fun <- function(x) {
319+
q1 <- quantile(x, 0.005)
320+
q2 <- quantile(x, 0.995)
321+
pmax(pmin(x, q2), q1)
322+
}
323+
324+
df_img <- df_img %>%
325+
group_by(id, time) %>%
326+
mutate(value = fun(value))
327+
328+
# Plot image explanation
329+
time_bins <- unique(df_img$time)
330+
times <- as.integer(seq(0, length(time_bins), length.out = num_images + 2))[-c(1, num_images + 2)]
331+
time_bins <- time_bins[times]
332+
333+
df <- df_img[df_img$time %in% time_bins, ]
334+
p_img_exp <- ggplot() +
335+
geom_tile(aes(x = width, y = height, fill = value), data = df) +
336+
facet_grid(cols = vars(time)) +
337+
scale_x_discrete(expand = c(0,0)) +
338+
scale_y_discrete(expand = c(0,0)) +
339+
scale_fill_gradient2(low = "blue", mid = "white",
340+
high = "red", transform = "pseudo_log") +
341+
theme_minimal() +
342+
guides(fill = "none") +
343+
labs(x = NULL, y = NULL) +
344+
theme(axis.text = element_blank(),
345+
plot.margin = margin(0, 0, 0, 0),
346+
strip.text = element_blank())
347+
348+
# Load image
349+
library(ggmap)
350+
p_img <- ggimage(img)
351+
352+
# Save figures
353+
if (!is.null(path)) {
354+
if (!dir.exists(dirname(path))) {
355+
dir.create(dirname(path))
356+
}
357+
ggsave(p_bar + scale_x_continuous(expand = c(0,0)), filename = paste0(path, name, "_bar.pdf"), width = 10, height = 5)
358+
ggsave(p_img_exp, filename = paste0(path, name, "_img_exp.pdf"), width = 7, height = 1)
359+
ggsave(p_img, filename = paste0(path, name, "_img_orig.pdf"), width = 1, height = 1)
360+
if (!is.null(p_force)) {
361+
ggsave(p_force, filename = paste0(path, name, "_force.pdf"), width = 9, height = 5)
362+
}
363+
}
364+
365+
list(p_bar, p_img_exp, p_img, p_force)
366+
}
367+

0 commit comments

Comments
 (0)