|
| 1 | + |
| 2 | + |
| 3 | +################################################################################ |
| 4 | +# R torch Helper functions for the multi-modal model |
| 5 | +################################################################################ |
| 6 | + |
| 7 | +# Define R torch version of the multi-modal model ------------------------------ |
| 8 | + |
| 9 | +# Define multi modal model (in torch we need to re-build the model because |
| 10 | +# only the state dict is saved) |
# Multi-modal network: an image sub-network whose output is concatenated with
# tabular covariates (and optionally a time input) and passed through a small
# fully connected head.
#
# Constructor arguments:
#   net_images        Module applied to the image input. Its output is
#                     concatenated along dim 2, so it is presumably of shape
#                     (batch, n_img_out) -- TODO confirm against the caller.
#   tabular_features  Vector of tabular features; only its length is used
#                     (to size the first linear layer).
#   n_out             Number of output units.
#   n_img_out         Width of the output of `net_images`.
#   include_time      Scalar flag; if TRUE, one extra input unit is reserved
#                     for the time input.
#   out_bias          Whether the output layer has a bias term.
#
# Forward arguments:
#   input             List of two tensors: image input and tabular input.
#   time              Optional tensor appended to the concatenated features.
#                     It should be supplied exactly when the module was built
#                     with `include_time = TRUE`; otherwise the width of the
#                     concatenated input will not match `fc1`.
MultiModalModel <- nn_module(
  "MultiModalModel",

  initialize = function(net_images, tabular_features, n_out = 1, n_img_out = 64,
                        include_time = FALSE, out_bias = FALSE) {
    self$net_images <- net_images

    # `include_time` is a scalar flag, so use a plain `if` rather than the
    # vectorised `ifelse()`.
    stopifnot(length(include_time) == 1L, !is.na(include_time))
    n_time <- if (include_time) 1 else 0

    input_dim <- n_img_out + length(tabular_features) + n_time
    self$fc1 <- nn_linear(input_dim, 256)
    self$fc2 <- nn_linear(256, 128)
    self$relu <- nn_relu()
    self$drop_0_3 <- nn_dropout(0.3)
    # NOTE(review): despite its name, `drop_0_4` uses p = 0.3 (same as
    # `drop_0_3`). Kept as is; verify against the model definition that was
    # used for training.
    self$drop_0_4 <- nn_dropout(0.3)
    self$out <- nn_linear(128, n_out, bias = out_bias)
  },

  forward = function(input, time = NULL) {
    img <- input[[1]]
    tab <- input[[2]]

    # Image branch: sub-network followed by dropout
    img <- self$net_images(img)
    img <- self$drop_0_4(img)

    # Concatenate all available modalities along the feature dimension
    parts <- if (is.null(time)) list(img, tab) else list(img, tab, time)
    x <- torch_cat(parts, dim = 2)

    # Fully connected head
    x <- self$drop_0_3(x)
    x <- self$relu(self$fc1(x))
    x <- self$drop_0_3(x)
    x <- self$relu(self$fc2(x))

    self$out(x)
  }
)
| 47 | + |
| 48 | +################################################################################ |
| 49 | +# SurvSHAP utility functions |
| 50 | +################################################################################ |
| 51 | + |
| 52 | +# Predict survival function for the torch-based multi-modal model |
# Survival prediction wrapper for the torch-based multi-modal model.
#
# Arguments:
#   model    Callable model object that also provides `preprocess_fun`.
#   newdata  Matrix-like object with one row per instance. The last four
#            columns are the tabular covariates; all preceding columns are the
#            flattened image, which is reshaped to (batch, 3, 32, 32).
#   times    Not used in the body; the model output is returned as is.
#            NOTE(review): presumably kept to match the calling convention of
#            the explainer -- confirm that the model's own time grid is the
#            one expected by the caller.
#
# Returns the first element of the model output as an R array, with every
# axis except the first and the last squeezed.
predict_survival_function <- function(model, newdata, times) {
  flat <- torch_tensor(as.matrix(newdata))

  # Split the flat input into its image part and its tabular part
  img_tensor <- flat[, seq(1, dim(flat)[2] - 4)]$view(c(-1, 3, 32, 32))
  tab_tensor <- flat[, seq(dim(flat)[2] - 3, dim(flat)[2])]

  # Apply the model's own preprocessing before the forward pass
  prepared <- model$preprocess_fun(list(img_tensor, tab_tensor))
  surv <- model(prepared, target = "survival")[[1]]

  # Drop all inner axes, keeping only the first and the last one
  inner_axes <- seq_len(surv$dim())[-c(1, surv$dim())]
  surv <- surv$squeeze(dim = inner_axes)

  as.array(surv)
}
| 63 | + |
| 64 | +# Predict function for the torch-based multi-modal model |
# Predict function for the torch-based multi-modal model.
#
# Arguments:
#   model    Callable model object that also provides `preprocess_fun`.
#   newdata  Matrix-like object with one row per instance (flattened image
#            columns followed by the four tabular covariates).
#
# Returns exactly what `predict_survival_function()` returns for the same
# `model` and `newdata`. The two functions previously carried identical
# copies of the same body; delegating keeps them in sync.
predict_function <- function(model, newdata) {
  # `times` is never evaluated by predict_survival_function(), so NULL is a
  # safe placeholder.
  predict_survival_function(model, newdata, times = NULL)
}
| 75 | + |
| 76 | +################################################################################ |
| 77 | +# Plotting functions |
| 78 | +################################################################################ |
| 79 | + |
| 80 | +# Create force plot |
# Create a force plot: stacked attribution bars at a subset of time points,
# annotated with arrows and value labels, plus a black line showing the sum
# of all attributions per instance and time.
#
# Arguments:
#   x             Object coercible to a data.frame with the columns `id`,
#                 `time`, `feature` and `value`.
#   num_samples   Number of (approximately equidistant) time points at which
#                 bars are drawn.
#   zero_feature  Feature for which no arrow is drawn.
#                 NOTE(review): the label positions below are hard-coded for
#                 the five feature labels of this data set; `zero_feature`
#                 only affects the arrow layer.
#   label         Text appended to the y-axis title.
#
# Returns a ggplot object.
plot_force_real <- function(x, num_samples = 10, zero_feature = "Sex (male)",
                            label = "") {
  # Convert input to a data frame and fix the stacking order of the features
  dat <- as.data.frame(x)
  dat$id <- as.factor(dat$id)
  dat$feature <- factor(
    dat$feature,
    levels = c(
      "IDH Mutation (mutant)",
      "Image (sum)",
      "1p19q Codeletion (no)",
      "Age (43)",
      "Sex (male)"
    )
  )

  # Sum of attributions per instance and time point (drop the grouping
  # afterwards so it does not leak into later steps)
  dat <- dat %>%
    group_by(id, time) %>%
    mutate(sum = sum(value)) %>%
    ungroup()

  # Pick the observed time points closest to an equidistant grid
  t_interest <- sort(unique(dat$time))
  target_points <- seq(min(t_interest), max(t_interest),
                       length.out = num_samples)
  selected_points <- vapply(
    target_points,
    function(tp) t_interest[which.min(abs(t_interest - tp))],
    numeric(1)
  )
  dat_small <- dat[dat$time %in% selected_points, ]

  # Position of the value labels within each stacked bar. All branches return
  # doubles (`NA_real_` instead of a logical `NA`) so that the result type is
  # consistent across dplyr versions.
  dat_small <- dat_small %>%
    group_by(id, time) %>%
    mutate(
      pos = case_when(
        feature == "Sex (male)" ~ NA_real_,
        feature == "Age (43)" ~ value + value[feature == "Sex (male)"],
        feature == "1p19q Codeletion (no)" ~ value,
        feature == "IDH Mutation (mutant)" ~
          value + value[feature == "Age (43)"] +
          value[feature == "Sex (male)"],
        feature == "Image (sum)" ~
          value + value[feature == "1p19q Codeletion (no)"],
        TRUE ~ NA_real_
      )
    ) %>%
    ungroup()

  # Additional position variable for the arrows: shifted away from zero
  dat_small$pos_a <- ifelse(dat_small$pos > 0,
                            dat_small$pos + 0.03, dat_small$pos)
  dat_small$pos_a <- ifelse(dat_small$pos_a < 0,
                            dat_small$pos_a - 0.03, dat_small$pos_a)

  # Manual adjustment of the label position for very small positive values
  dat_small[(round(dat_small$value, 2) == 0.01), "pos"] <- 0.018

  # Rows with a non-negligible attribution (after rounding to two digits)
  has_value <- round(dat_small$value, 2) != 0

  # Plot
  ggplot() +
    geom_bar(
      data = dat_small,
      mapping = aes(
        x = .data$time,
        y = .data$value,
        fill = .data$feature,
        color = .data$feature
      ),
      stat = "identity",
      position = "stack"
    ) +
    scale_color_viridis_d(
      name = "Feature",
      guide = guide_legend(override.aes = list(linewidth = 7))
    ) +
    scale_fill_viridis_d(alpha = 0.4, name = "Feature") +
    geom_segment(
      data = dat_small[(dat_small$feature != zero_feature) & has_value, ],
      mapping = aes(
        x = .data$time,
        xend = .data$time,
        y = .data$pos_a,
        yend = .data$pos_a + (.data$value) * 0.01,
        color = .data$feature
      ),
      arrow = arrow(type = "closed", length = unit(0.1, "inches")),
      linewidth = 6
    ) +
    geom_label(
      data = dat_small[has_value, ],
      mapping = aes(
        x = .data$time,
        y = .data$pos,
        label = round(.data$value, 2)
      ),
      color = "black",
      size = 3,
      vjust = 0.5,
      hjust = 0.5,
      na.rm = TRUE
    ) +
    geom_line(
      data = dat,
      mapping = aes(x = .data$time, y = .data$sum),
      color = "black",
      linewidth = 1.5
    ) +
    facet_wrap(
      vars(.data$id),
      scales = "free_x",
      labeller = as_labeller(function(a) paste0("Instance ID: ", a))
    ) +
    theme_minimal(base_size = 13) +
    theme(legend.position = "bottom") +
    ylim(-0.15, 0.16) +
    scale_x_continuous(expand = c(0, 0)) +
    labs(
      x = "Time",
      y = paste0("Force Plot: ", label),
      color = "Feature",
      fill = "Feature"
    )
}
| 201 | + |
| 202 | +# Create stacked bar plot of the attributions over time |
# Create a stacked bar plot of the attributions at a subset of time points,
# overlaid on a black line showing the sum of all attributions per instance
# and time.
#
# Arguments:
#   x             Object coercible to a data.frame with the columns `id`,
#                 `time`, `feature` and `value`.
#   num_samples   Number of (approximately equidistant) time points at which
#                 bars are drawn.
#   zero_feature  Not used in the body.
#                 NOTE(review): presumably kept so the signature matches
#                 plot_force_real().
#   label         Text appended to the y-axis title.
#
# Returns a ggplot object.
plot_bars_real <- function(x, num_samples = 40, zero_feature = "Sex (male)",
                           label = "") {
  # Convert input to a data frame and fix the stacking order of the features
  dat <- as.data.frame(x)
  dat$id <- as.factor(dat$id)
  dat$feature <- factor(
    dat$feature,
    levels = c(
      "IDH Mutation (mutant)",
      "Image (sum)",
      "1p19q Codeletion (no)",
      "Age (43)",
      "Sex (male)"
    )
  )

  # Sum of attributions per instance and time point (drop the grouping
  # afterwards so it does not leak into later steps)
  dat <- dat %>%
    group_by(id, time) %>%
    mutate(sum = sum(value)) %>%
    ungroup()

  # Pick the observed time points closest to an equidistant grid
  t_interest <- sort(unique(dat$time))
  target_points <- seq(min(t_interest), max(t_interest),
                       length.out = num_samples)
  selected_points <- vapply(
    target_points,
    function(tp) t_interest[which.min(abs(t_interest - tp))],
    numeric(1)
  )
  dat_small <- dat[dat$time %in% selected_points, ]

  # Plot
  ggplot() +
    geom_line(
      data = dat,
      mapping = aes(x = .data$time, y = .data$sum),
      color = "black",
      linewidth = 1
    ) +
    geom_bar(
      data = dat_small,
      mapping = aes(
        x = .data$time,
        y = .data$value,
        fill = .data$feature,
        color = .data$feature
      ),
      stat = "identity",
      position = "stack"
    ) +
    scale_color_viridis_d(name = "Feature") +
    scale_fill_viridis_d(alpha = 0.4, name = "Feature") +
    facet_wrap(
      vars(.data$id),
      scales = "free_x",
      labeller = as_labeller(function(a) paste0("Instance ID: ", a))
    ) +
    theme_minimal(base_size = 15) +
    theme(legend.position = "bottom") +
    labs(
      x = "Time",
      y = paste0("Contribution: ", label),
      color = "Feature",
      fill = "Feature"
    )
}
| 266 | + |
# Build (and optionally save) all figures for one explanation result.
#
# Arguments:
#   result      List whose first element holds the image attributions (columns
#               `id`, `time`, `method`, `height`, `width`, `value`, `pred`)
#               and whose second element holds the tabular attributions.
#               The second element is indexed with data.table syntax, so it
#               must be a data.table.
#   img         Image passed to `ggmap::ggimage()` for the "original image"
#               panel.
#   path        Optional file-name prefix. Output files are named
#               `paste0(path, name, "_<kind>.pdf")`, i.e. `path` must end with
#               a path separator if it is meant to be a directory.
#   name        Base name of the output files.
#   num_images  Number of time points shown in the image explanation strip.
#   as_force    If TRUE, also create the force plot.
#
# Returns list(p_bar, p_img_exp, p_img, p_force); `p_force` is NULL when
# `as_force` is FALSE.
plot_result <- function(result, img, path = NULL, name = "res", num_images = 7,
                        as_force = TRUE) {
  df_img <- result[[1]]
  df_tab <- result[[2]]

  # Get prediction data.table
  col_idx <- colnames(df_tab)[
    colnames(df_tab) %in% c("time", "pred", "pred_diff")
  ]
  df_pred <- unique(df_tab[, ..col_idx])

  # Summarize image features (sum over all pixels and channels)
  df_img_sum <- df_img %>%
    group_by(id, time, method) %>%
    summarise(value = sum(value), .groups = "drop")

  # Add image summary to table
  df_tab <- rbind(df_tab,
                  cbind(df_img_sum, feature = "Image (sum)"), fill = TRUE)
  # NOTE(review): the labels are assigned by order of first appearance, so
  # this assumes exactly five distinct features in this order -- confirm
  # against the caller.
  df_tab$feature <- factor(
    df_tab$feature,
    levels = unique(df_tab$feature),
    labels = c("Age (43)", "Sex (male)", "IDH Mutation (mutant)",
               "1p19q Codeletion (no)", "Image (sum)")
  )

  # Plot force plot ------------------------------------------------------------
  if (as_force) {
    p_force <- plot_force_real(df_tab, label = "GradSHAP(t)", num_samples = 20)
  } else {
    p_force <- NULL
  }

  # Plot bar plot --------------------------------------------------------------
  p_bar <- plot_bars_real(df_tab, label = "GradSHAP(t)")

  # Plot as lines --------------------------------------------------------------
  # NOTE(review): `p_line` is built but neither saved nor returned.
  if ("pred_diff" %in% colnames(df_pred)) {
    df_pred$pred <- df_pred$pred_diff
  }
  p_line <- ggplot(df_tab, aes(x = time, y = value, color = feature)) +
    geom_line(linewidth = 1) +
    geom_line(data = df_pred, aes(y = pred), color = "black", linewidth = 1) +
    geom_hline(yintercept = 0, linetype = "dashed", linewidth = 1) +
    facet_wrap(vars(id), scales = "free_x") +
    scale_color_viridis_d() +
    theme_minimal(base_size = 13) +
    theme(legend.position = "bottom",
          legend.box.margin = margin(),
          plot.margin = margin()) +
    labs(x = "Time (months)", y = NULL, color = "Feature")

  # Aggregate over channels
  df_img <- df_img %>%
    group_by(id, time, height, width) %>%
    summarise(value = sum(value), pred = unique(pred), .groups = "drop")

  # Clip the attributions to their 0.5% / 99.5% quantiles, separately for
  # each instance and time point
  clip_to_quantiles <- function(v) {
    lower <- quantile(v, 0.005)
    upper <- quantile(v, 0.995)
    pmax(pmin(v, upper), lower)
  }

  df_img <- df_img %>%
    group_by(id, time) %>%
    mutate(value = clip_to_quantiles(value)) %>%
    ungroup()

  # Select `num_images` interior time points for the image explanation
  time_bins <- unique(df_img$time)
  times <- as.integer(
    seq(0, length(time_bins), length.out = num_images + 2)
  )[-c(1, num_images + 2)]
  time_bins <- time_bins[times]

  # Plot image explanation
  df <- df_img[df_img$time %in% time_bins, ]
  p_img_exp <- ggplot() +
    geom_tile(aes(x = width, y = height, fill = value), data = df) +
    facet_grid(cols = vars(time)) +
    scale_x_discrete(expand = c(0, 0)) +
    scale_y_discrete(expand = c(0, 0)) +
    scale_fill_gradient2(low = "blue", mid = "white",
                         high = "red", transform = "pseudo_log") +
    theme_minimal() +
    guides(fill = "none") +
    labs(x = NULL, y = NULL) +
    theme(axis.text = element_blank(),
          plot.margin = margin(0, 0, 0, 0),
          strip.text = element_blank())

  # Original image. Call ggmap via its namespace instead of attaching it from
  # inside the function.
  if (!requireNamespace("ggmap", quietly = TRUE)) {
    stop("Package 'ggmap' is required to plot the original image.",
         call. = FALSE)
  }
  p_img <- ggmap::ggimage(img)

  # Save figures
  if (!is.null(path)) {
    out_file <- function(suffix) paste0(path, name, suffix)

    # Derive the directory from an actual output file: `dirname(path)` drops
    # the last component when `path` ends with a separator and would then
    # create only the parent directory.
    out_dir <- dirname(out_file("_bar.pdf"))
    if (!dir.exists(out_dir)) {
      dir.create(out_dir, recursive = TRUE)
    }

    ggsave(p_bar + scale_x_continuous(expand = c(0, 0)),
           filename = out_file("_bar.pdf"), width = 10, height = 5)
    ggsave(p_img_exp, filename = out_file("_img_exp.pdf"),
           width = 7, height = 1)
    ggsave(p_img, filename = out_file("_img_orig.pdf"),
           width = 1, height = 1)
    if (!is.null(p_force)) {
      ggsave(p_force, filename = out_file("_force.pdf"),
             width = 9, height = 5)
    }
  }

  list(p_bar, p_img_exp, p_img, p_force)
}
| 367 | + |
0 commit comments