-
-
Notifications
You must be signed in to change notification settings - Fork 20
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #430 from mlr-org/calibration_index
Add ICI measure + smoothed calibration plot
- Loading branch information
Showing
46 changed files
with
824 additions
and
375 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,49 +1,25 @@ | ||
Package: mlr3proba | ||
Title: Probabilistic Supervised Learning for 'mlr3' | ||
Version: 0.7.1 | ||
Authors@R: | ||
c(person(given = "Raphael", | ||
family = "Sonabend", | ||
role = "aut", | ||
email = "[email protected]", | ||
comment = c(ORCID = "0000-0001-9225-4654")), | ||
person(given = "Franz", | ||
family = "Kiraly", | ||
role = "aut", | ||
email = "[email protected]"), | ||
person(given = "Michel", | ||
family = "Lang", | ||
role = "aut", | ||
email = "[email protected]", | ||
comment = c(ORCID = "0000-0001-9754-0393")), | ||
person(given = "Nurul Ain", | ||
family = "Toha", | ||
role = "ctb", | ||
email = "[email protected]"), | ||
person(given = "Andreas", | ||
family = "Bender", | ||
role = "ctb", | ||
email = "[email protected]", | ||
comment = c(ORCID = "0000-0001-5628-8611")), | ||
person(given = "John", | ||
family = "Zobolas", | ||
role = c("cre", "aut"), | ||
email = "[email protected]", | ||
comment = c(ORCID = "0000-0002-3609-8674")), | ||
person(given = "Lukas", | ||
family = "Burk", | ||
email = "[email protected]", | ||
role = "ctb", | ||
comment = c(ORCID = "0000-0001-7528-3795")), | ||
person(given = "Philip", | ||
family = "Studener", | ||
role = "aut", | ||
email = "[email protected]"), | ||
person(given = "Maximilian", | ||
family = "Muecke", | ||
email = "[email protected]", | ||
role = "ctb", | ||
comment = c(ORCID = "0009-0000-9432-9795"))) | ||
Version: 0.7.3 | ||
Authors@R: c( | ||
person("Raphael", "Sonabend", , "[email protected]", role = "aut", | ||
comment = c(ORCID = "0000-0001-9225-4654")), | ||
person("Franz", "Kiraly", , "[email protected]", role = "aut"), | ||
person("Michel", "Lang", , "[email protected]", role = "aut", | ||
comment = c(ORCID = "0000-0001-9754-0393")), | ||
person("Nurul Ain", "Toha", , "[email protected]", role = "ctb"), | ||
person("Andreas", "Bender", , "[email protected]", role = "ctb", | ||
comment = c(ORCID = "0000-0001-5628-8611")), | ||
person("John", "Zobolas", , "[email protected]", role = c("cre", "aut"), | ||
comment = c(ORCID = "0000-0002-3609-8674")), | ||
person("Lukas", "Burk", , "[email protected]", role = "ctb", | ||
comment = c(ORCID = "0000-0001-7528-3795")), | ||
person("Philip", "Studener", , "[email protected]", role = "aut"), | ||
person("Maximilian", "Muecke", , "[email protected]", role = "ctb", | ||
comment = c(ORCID = "0009-0000-9432-9795")), | ||
person("Lee Xingzhuo", "Li", , "[email protected]", role = "ctb", | ||
comment = c(ORCID = "0000-0001-5259-5198")) | ||
) | ||
Description: Provides extensions for probabilistic supervised learning for | ||
'mlr3'. This includes extending the regression task to probabilistic | ||
and interval regression, adding a survival task, and other specialized | ||
|
@@ -61,35 +37,36 @@ Imports: | |
ggplot2, | ||
mlr3misc (>= 0.7.0), | ||
mlr3pipelines (>= 0.7.0), | ||
mlr3viz, | ||
paradox (>= 1.0.0), | ||
R6, | ||
Rcpp (>= 1.0.4), | ||
survival | ||
Suggests: | ||
abind, | ||
coxed, | ||
GGally, | ||
knitr, | ||
lgr, | ||
lifecycle, | ||
mlr3learners, | ||
mlr3viz, | ||
pammtools, | ||
param6 (>= 0.2.4), | ||
polspline, | ||
pracma, | ||
rpart, | ||
set6 (>= 0.2.6), | ||
simsurv, | ||
survAUC, | ||
testthat (>= 3.0.0), | ||
abind, | ||
coxed, | ||
mlr3learners, | ||
pammtools | ||
testthat (>= 3.0.0) | ||
LinkingTo: | ||
Rcpp | ||
Remotes: | ||
xoopR/distr6, | ||
xoopR/param6, | ||
xoopR/set6 | ||
Config/testthat/edition: 3 | ||
ByteCompile: true | ||
Config/testthat/edition: 3 | ||
Encoding: UTF-8 | ||
LazyData: true | ||
NeedsCompilation: no | ||
|
@@ -116,6 +93,7 @@ Collate: | |
'MeasureSurvDCalibration.R' | ||
'MeasureSurvGraf.R' | ||
'MeasureSurvHungAUC.R' | ||
'MeasureSurvICI.R' | ||
'MeasureSurvIntLogloss.R' | ||
'MeasureSurvLogloss.R' | ||
'MeasureSurvMAE.R' | ||
|
@@ -170,7 +148,6 @@ Collate: | |
'mlr3proba-package.R' | ||
'pecs.R' | ||
'pipelines.R' | ||
'plot.R' | ||
'plot_probregr.R' | ||
'scoring_rule_erv.R' | ||
'surv_measures.R' | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,159 @@ | ||
#' @template surv_measure | ||
#' @templateVar title Integrated Calibration Index | ||
#' @templateVar fullname MeasureSurvICI | ||
#' @templateVar eps 1e-4 | ||
#' @template param_eps | ||
#' | ||
#' @description | ||
#' Calculates the Integrated Calibration Index (ICI), see Austin et al. (2020). | ||
#' | ||
#' @details | ||
#' Each individual \eqn{i} from the test set, has an observed survival outcome | ||
#' \eqn{(t_i, \delta_i)} (time and censoring indicator) and predicted survival | ||
#' function \eqn{S_i(t)}. | ||
#' The predicted probability of an event occurring before a specific time point | ||
#' \eqn{t_0}, is defined as \eqn{\hat{P_i}(t_0) = F_i(t_0) = 1 - S_i(t_0)}. | ||
#' | ||
#' Using hazard regression (via the \CRANpkg{polspline} R package), a *smoothed* | ||
#' calibration curve is estimated by fitting the following model: | ||
#' \deqn{log(h(t)) = g(log(− log(1 − \hat{P}_{t_0})), t)} | ||
#' | ||
#' Note that we substitute probabilities \eqn{\hat{P}_{t_0} = 0} with a small | ||
#' \eqn{\epsilon} number to avoid arithmetic issues (\eqn{log(0)}). Same with | ||
#' \eqn{\hat{P}_{t_0} = 1}, we use \eqn{1 - \epsilon}. | ||
#' From this model, the *smoothed* probability of occurrence at \eqn{t_0} for | ||
#' observation \eqn{i} is obtained as \eqn{\hat{P}_i^c(t_0)}. | ||
#' | ||
#' The **Integrated Calibration Index** is then computed across the \eqn{N} | ||
#' test set observations as: | ||
#' \deqn{ICI = \frac{1}{N} \sum_{i=1}^N | \hat{P}_i^c(t_0) - \hat{P}_i(t_0) |} | ||
#' | ||
#' This measure evaluates **point-calibration** at a specific time point, which | ||
#' must be specified by the user. | ||
#' | ||
#' @section Parameter details: | ||
#' - `time` (`numeric(1)`)\cr | ||
#' The specific time point \eqn{t_0} at which calibration is evaluated. | ||
#' If `NULL`, the median observed time from the test set is used. | ||
#' - `method` (`character(1)`)\cr | ||
#' Specifies the summary statistic used to calculate the final calibration score. | ||
#' - `"ICI"` (default): Uses the mean of absolute differences \eqn{| \hat{P}_i^c(t_0) - \hat{P}_i(t_0) |} across all observations. | ||
#' - `"E50"`: Uses the median of absolute differences instead of the mean. | ||
#' - `"E90"`: Uses the 90th percentile of absolute differences, emphasizing higher deviations. | ||
#' - `"Emax"`: Uses the maximum absolute difference, capturing the largest discrepancy between predicted and smoothed probabilities. | ||
#' | ||
#' @references | ||
#' `r format_bib("austin2020")` | ||
#' | ||
#' @family calibration survival measures | ||
#' @family distr survival measures | ||
#' @examples | ||
#' library(mlr3) | ||
#' | ||
#' # Define a survival Task | ||
#' task = tsk("lung") | ||
#' | ||
#' # Create train and test set | ||
#' part = partition(task) | ||
#' | ||
#' # Train Cox learner on the train set | ||
#' cox = lrn("surv.coxph") | ||
#' cox$train(task, row_ids = part$train) | ||
#' | ||
#' # Make predictions for the test set | ||
#' p = cox$predict(task, row_ids = part$test) | ||
#' | ||
#' # ICI at median test set time | ||
#' p$score(msr("surv.calib_index")) | ||
#' | ||
#' # ICI at specific time point | ||
#' p$score(msr("surv.calib_index", time = 365)) | ||
#' | ||
#' # E50 at specific time point | ||
#' p$score(msr("surv.calib_index", method = "E50", time = 365)) | ||
#' | ||
#' @export | ||
MeasureSurvICI = R6Class("MeasureSurvICI", | ||
inherit = MeasureSurv, | ||
public = list( | ||
#' @description Creates a new instance of this [R6][R6::R6Class] class. | ||
initialize = function() { | ||
param_set = ps( | ||
time = p_dbl(0, Inf), | ||
eps = p_dbl(0, 1, default = 1e-4), | ||
method = p_fct(default = "ICI", levels = c("ICI", "E50", "E90", "Emax")) | ||
) | ||
param_set$set_values(method = "ICI", eps = 1e-4) | ||
|
||
super$initialize( | ||
id = "surv.calib_index", | ||
packages = c("polspline"), | ||
range = c(0, Inf), | ||
minimize = TRUE, | ||
predict_type = "distr", | ||
label = "Integrated Calibration Index", | ||
man = "mlr3proba::mlr_measures_surv.calib_index", | ||
param_set = param_set | ||
) | ||
} | ||
), | ||
|
||
private = list( | ||
.score = function(prediction, ...) { | ||
# test set survival outcome | ||
times = prediction$truth[, 1L] | ||
status = prediction$truth[, 2L] | ||
|
||
# get predicted survival matrix | ||
if (inherits(prediction$data$distr, "array")) { | ||
surv = prediction$data$distr | ||
if (length(dim(surv)) == 3L) { | ||
# survival 3d array, extract median | ||
surv = .ext_surv_mat(arr = surv, which.curve = 0.5) | ||
} | ||
} else { | ||
stop("Distribution prediction does not have a survival matrix or array | ||
in the $data$distr slot") | ||
} | ||
|
||
pv = self$param_set$values | ||
|
||
# time point for calibration | ||
time = pv$time %??% stats::median(times) | ||
|
||
# get cdf at the specified time point | ||
extend_times_cdf = getFromNamespace("C_Vec_WeightedDiscreteCdf", ns = "distr6") | ||
pred_times = as.numeric(colnames(surv)) | ||
cdf = as.vector(extend_times_cdf(time, pred_times, cdf = t(1 - surv), TRUE, FALSE)) | ||
# to avoid log(0) later, same as in paper's Appendix | ||
eps = pv$eps | ||
cdf[cdf == 1] = 1 - eps | ||
cdf[cdf == 0] = eps | ||
|
||
# get the cdf complement (survival) log-log transformed | ||
cll = log(-log(1 - cdf)) | ||
|
||
hare_fit = polspline::hare(data = times, delta = status, cov = as.matrix(cll)) | ||
smoothed_cdf = polspline::phare(q = time, cov = cll, fit = hare_fit) | ||
|
||
method = pv$method | ||
if (method == "ICI") { | ||
# Mean difference (ICI) | ||
result = mean(abs(cdf - smoothed_cdf)) | ||
} else if (method == "E50") { | ||
# Median (E50) | ||
result = stats::median(abs(cdf - smoothed_cdf)) | ||
} else if (method == "E90") { | ||
# 90th percentile (E90) | ||
result = stats::quantile(abs(cdf - smoothed_cdf), probs = 0.9) | ||
} else if (method == "Emax") { | ||
# Maximum absolute difference (Emax) | ||
result = max(abs(cdf - smoothed_cdf)) | ||
} | ||
|
||
result | ||
} | ||
) | ||
) | ||
|
||
register_measure("surv.calib_index", MeasureSurvICI) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.