Skip to content

Commit

Permalink
Merge pull request #430 from mlr-org/calibration_index
Browse files Browse the repository at this point in the history
Add ICI measure + smoothed calibration plot
  • Loading branch information
bblodfon authored Jan 6, 2025
2 parents 48dc1f8 + 46b4ddd commit 273911d
Show file tree
Hide file tree
Showing 46 changed files with 824 additions and 375 deletions.
81 changes: 29 additions & 52 deletions DESCRIPTION
Original file line number Diff line number Diff line change
@@ -1,49 +1,25 @@
Package: mlr3proba
Title: Probabilistic Supervised Learning for 'mlr3'
Version: 0.7.1
Authors@R:
c(person(given = "Raphael",
family = "Sonabend",
role = "aut",
email = "[email protected]",
comment = c(ORCID = "0000-0001-9225-4654")),
person(given = "Franz",
family = "Kiraly",
role = "aut",
email = "[email protected]"),
person(given = "Michel",
family = "Lang",
role = "aut",
email = "[email protected]",
comment = c(ORCID = "0000-0001-9754-0393")),
person(given = "Nurul Ain",
family = "Toha",
role = "ctb",
email = "[email protected]"),
person(given = "Andreas",
family = "Bender",
role = "ctb",
email = "[email protected]",
comment = c(ORCID = "0000-0001-5628-8611")),
person(given = "John",
family = "Zobolas",
role = c("cre", "aut"),
email = "[email protected]",
comment = c(ORCID = "0000-0002-3609-8674")),
person(given = "Lukas",
family = "Burk",
email = "[email protected]",
role = "ctb",
comment = c(ORCID = "0000-0001-7528-3795")),
person(given = "Philip",
family = "Studener",
role = "aut",
email = "[email protected]"),
person(given = "Maximilian",
family = "Muecke",
email = "[email protected]",
role = "ctb",
comment = c(ORCID = "0009-0000-9432-9795")))
Version: 0.7.3
Authors@R: c(
person("Raphael", "Sonabend", , "[email protected]", role = "aut",
comment = c(ORCID = "0000-0001-9225-4654")),
person("Franz", "Kiraly", , "[email protected]", role = "aut"),
person("Michel", "Lang", , "[email protected]", role = "aut",
comment = c(ORCID = "0000-0001-9754-0393")),
person("Nurul Ain", "Toha", , "[email protected]", role = "ctb"),
person("Andreas", "Bender", , "[email protected]", role = "ctb",
comment = c(ORCID = "0000-0001-5628-8611")),
person("John", "Zobolas", , "[email protected]", role = c("cre", "aut"),
comment = c(ORCID = "0000-0002-3609-8674")),
person("Lukas", "Burk", , "[email protected]", role = "ctb",
comment = c(ORCID = "0000-0001-7528-3795")),
person("Philip", "Studener", , "[email protected]", role = "aut"),
person("Maximilian", "Muecke", , "[email protected]", role = "ctb",
comment = c(ORCID = "0009-0000-9432-9795")),
person("Lee Xingzhuo", "Li", , "[email protected]", role = "ctb",
comment = c(ORCID = "0000-0001-5259-5198"))
)
Description: Provides extensions for probabilistic supervised learning for
'mlr3'. This includes extending the regression task to probabilistic
and interval regression, adding a survival task, and other specialized
Expand All @@ -61,35 +37,36 @@ Imports:
ggplot2,
mlr3misc (>= 0.7.0),
mlr3pipelines (>= 0.7.0),
mlr3viz,
paradox (>= 1.0.0),
R6,
Rcpp (>= 1.0.4),
survival
Suggests:
abind,
coxed,
GGally,
knitr,
lgr,
lifecycle,
mlr3learners,
mlr3viz,
pammtools,
param6 (>= 0.2.4),
polspline,
pracma,
rpart,
set6 (>= 0.2.6),
simsurv,
survAUC,
testthat (>= 3.0.0),
abind,
coxed,
mlr3learners,
pammtools
testthat (>= 3.0.0)
LinkingTo:
Rcpp
Remotes:
xoopR/distr6,
xoopR/param6,
xoopR/set6
Config/testthat/edition: 3
ByteCompile: true
Config/testthat/edition: 3
Encoding: UTF-8
LazyData: true
NeedsCompilation: no
Expand All @@ -116,6 +93,7 @@ Collate:
'MeasureSurvDCalibration.R'
'MeasureSurvGraf.R'
'MeasureSurvHungAUC.R'
'MeasureSurvICI.R'
'MeasureSurvIntLogloss.R'
'MeasureSurvLogloss.R'
'MeasureSurvMAE.R'
Expand Down Expand Up @@ -170,7 +148,6 @@ Collate:
'mlr3proba-package.R'
'pecs.R'
'pipelines.R'
'plot.R'
'plot_probregr.R'
'scoring_rule_erv.R'
'surv_measures.R'
Expand Down
4 changes: 1 addition & 3 deletions NAMESPACE
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,6 @@ S3method(is_missing_prediction_data,PredictionDataDens)
S3method(is_missing_prediction_data,PredictionDataSurv)
S3method(pecs,PredictionSurv)
S3method(pecs,list)
S3method(plot,LearnerSurv)
S3method(plot,TaskDens)
S3method(plot,TaskSurv)
export(.c_weight_survival_score)
Expand All @@ -51,6 +50,7 @@ export(MeasureSurvCindex)
export(MeasureSurvDCalibration)
export(MeasureSurvGraf)
export(MeasureSurvHungAUC)
export(MeasureSurvICI)
export(MeasureSurvIntLogloss)
export(MeasureSurvLogloss)
export(MeasureSurvMAE)
Expand Down Expand Up @@ -105,15 +105,13 @@ import(mlr3misc)
import(paradox)
importFrom(R6,R6Class)
importFrom(Rcpp,sourceCpp)
importFrom(graphics,plot)
importFrom(mlr3pipelines,"%>>%")
importFrom(mlr3pipelines,Graph)
importFrom(mlr3pipelines,as_graph)
importFrom(mlr3pipelines,gunion)
importFrom(mlr3pipelines,pipeline_greplicate)
importFrom(mlr3pipelines,po)
importFrom(mlr3pipelines,ppl)
importFrom(mlr3viz,fortify)
importFrom(stats,density)
importFrom(stats,model.frame)
importFrom(stats,model.matrix)
Expand Down
9 changes: 9 additions & 0 deletions NEWS.md
Original file line number Diff line number Diff line change
@@ -1,3 +1,12 @@
# mlr3proba 0.7.3

* feat: added new calibration measure => `msr("surv.calib_index")`
* refactor + feat: `autoplot.PredictionSurv`
* The default `"calib"` plot now uses the survival matrix directly, which is faster
* `"dcalib"` has extra barplot + better documentation
* Added new `type = "scalib"` which constructs the smoothed calibration plots as in Austin et al. (2020)
* **BREAKING CHANGE**: `"preds"` is now called `"isd"` (individual survival distribution). `row_ids` can now be used to filter the observations for which you draw the survival curves.

# mlr3proba 0.7.2

* fix: `lrn("surv.coxph")` is now trained with `model=TRUE` which fixes an issue with using observation weights [stackoverflow link](https://stackoverflow.com/questions/79297386/mlr3-predicted-values-for-surv-coxph-learner-with-case-weights).
Expand Down
1 change: 0 additions & 1 deletion R/LearnerDens.R
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@
#' @template param_predict_types
#' @template param_feature_types
#' @template param_learner_properties
#' @template param_data_formats
#' @template param_packages
#' @template param_label
#' @template param_man
Expand Down
1 change: 0 additions & 1 deletion R/LearnerSurv.R
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@
#' @template param_predict_types
#' @template param_feature_types
#' @template param_learner_properties
#' @template param_data_formats
#' @template param_packages
#' @template param_label
#' @template param_man
Expand Down
159 changes: 159 additions & 0 deletions R/MeasureSurvICI.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,159 @@
#' @template surv_measure
#' @templateVar title Integrated Calibration Index
#' @templateVar fullname MeasureSurvICI
#' @templateVar eps 1e-4
#' @template param_eps
#'
#' @description
#' Calculates the Integrated Calibration Index (ICI), see Austin et al. (2020).
#'
#' @details
#' Each individual \eqn{i} from the test set has an observed survival outcome
#' \eqn{(t_i, \delta_i)} (time and censoring indicator) and a predicted survival
#' function \eqn{S_i(t)}.
#' The predicted probability of an event occurring before a specific time point
#' \eqn{t_0} is defined as \eqn{\hat{P_i}(t_0) = F_i(t_0) = 1 - S_i(t_0)}.
#'
#' Using hazard regression (via the \CRANpkg{polspline} R package), a *smoothed*
#' calibration curve is estimated by fitting the following model:
#' \deqn{log(h(t)) = g(log(- log(1 - \hat{P}_{t_0})), t)}
#'
#' Note that we substitute probabilities \eqn{\hat{P}_{t_0} = 0} with a small
#' \eqn{\epsilon} number to avoid arithmetic issues (\eqn{log(0)}). Likewise,
#' for \eqn{\hat{P}_{t_0} = 1} we use \eqn{1 - \epsilon}.
#' From this model, the *smoothed* probability of occurrence at \eqn{t_0} for
#' observation \eqn{i} is obtained as \eqn{\hat{P}_i^c(t_0)}.
#'
#' The **Integrated Calibration Index** is then computed across the \eqn{N}
#' test set observations as:
#' \deqn{ICI = \frac{1}{N} \sum_{i=1}^N | \hat{P}_i^c(t_0) - \hat{P}_i(t_0) |}
#'
#' This measure evaluates **point-calibration** at a specific time point, which
#' must be specified by the user.
#'
#' @section Parameter details:
#' - `time` (`numeric(1)`)\cr
#' The specific time point \eqn{t_0} at which calibration is evaluated.
#' If `NULL`, the median observed time from the test set is used.
#' - `method` (`character(1)`)\cr
#' Specifies the summary statistic used to calculate the final calibration score.
#'   - `"ICI"` (default): Uses the mean of absolute differences \eqn{| \hat{P}_i^c(t_0) - \hat{P}_i(t_0) |} across all observations.
#'   - `"E50"`: Uses the median of absolute differences instead of the mean.
#'   - `"E90"`: Uses the 90th percentile of absolute differences, emphasizing higher deviations.
#'   - `"Emax"`: Uses the maximum absolute difference, capturing the largest discrepancy between predicted and smoothed probabilities.
#'
#' @references
#' `r format_bib("austin2020")`
#'
#' @family calibration survival measures
#' @family distr survival measures
#' @examples
#' library(mlr3)
#'
#' # Define a survival Task
#' task = tsk("lung")
#'
#' # Create train and test set
#' part = partition(task)
#'
#' # Train Cox learner on the train set
#' cox = lrn("surv.coxph")
#' cox$train(task, row_ids = part$train)
#'
#' # Make predictions for the test set
#' p = cox$predict(task, row_ids = part$test)
#'
#' # ICI at median test set time
#' p$score(msr("surv.calib_index"))
#'
#' # ICI at specific time point
#' p$score(msr("surv.calib_index", time = 365))
#'
#' # E50 at specific time point
#' p$score(msr("surv.calib_index", method = "E50", time = 365))
#'
#' @export
MeasureSurvICI = R6Class("MeasureSurvICI",
  inherit = MeasureSurv,
  public = list(
    #' @description Creates a new instance of this [R6][R6::R6Class] class.
    initialize = function() {
      param_set = ps(
        time = p_dbl(0, Inf),
        eps = p_dbl(0, 1, default = 1e-4),
        method = p_fct(default = "ICI", levels = c("ICI", "E50", "E90", "Emax"))
      )
      # make the documented defaults the active values
      param_set$set_values(method = "ICI", eps = 1e-4)

      super$initialize(
        id = "surv.calib_index",
        packages = c("polspline"),
        # NOTE(review): the score is an absolute difference of two
        # probabilities, so 1 looks like the effective upper bound - confirm
        # before tightening the declared range.
        range = c(0, Inf),
        minimize = TRUE,
        predict_type = "distr",
        label = "Integrated Calibration Index",
        man = "mlr3proba::mlr_measures_surv.calib_index",
        param_set = param_set
      )
    }
  ),

  private = list(
    # Computes the calibration summary selected by the `method` parameter at
    # the time point given by `time` (or the median observed time if unset).
    # Returns an unnamed numeric(1).
    .score = function(prediction, ...) {
      # test set survival outcome
      times = prediction$truth[, 1L]
      status = prediction$truth[, 2L]

      # get predicted survival matrix
      if (inherits(prediction$data$distr, "array")) {
        surv = prediction$data$distr
        if (length(dim(surv)) == 3L) {
          # survival 3d array, extract median
          surv = .ext_surv_mat(arr = surv, which.curve = 0.5)
        }
      } else {
        stop(
          "Distribution prediction does not have a survival matrix or array\n",
          "in the $data$distr slot"
        )
      }

      pv = self$param_set$values

      # time point for calibration
      time = pv$time %??% stats::median(times)

      # get cdf at the specified time point; the predicted time points are
      # read from the column names of the survival matrix
      # NOTE(review): the trailing TRUE/FALSE are presumably the lower-tail
      # and log-probability flags of the distr6 routine - confirm.
      extend_times_cdf = getFromNamespace("C_Vec_WeightedDiscreteCdf", ns = "distr6")
      pred_times = as.numeric(colnames(surv))
      cdf = as.vector(extend_times_cdf(time, pred_times, cdf = t(1 - surv), TRUE, FALSE))
      # to avoid log(0) later, same as in paper's Appendix
      eps = pv$eps
      cdf[cdf == 1] = 1 - eps
      cdf[cdf == 0] = eps

      # get the cdf complement (survival) log-log transformed
      cll = log(-log(1 - cdf))

      # hazard regression on the transformed predictions, then the smoothed
      # event probability at `time` for every observation
      hare_fit = polspline::hare(data = times, delta = status, cov = as.matrix(cll))
      smoothed_cdf = polspline::phare(q = time, cov = cll, fit = hare_fit)

      # per-observation gap between predicted and smoothed probabilities
      abs_diff = abs(cdf - smoothed_cdf)

      method = pv$method
      switch(method,
        ICI = mean(abs_diff),
        E50 = stats::median(abs_diff),
        # names = FALSE drops the "90%" label so that every method returns
        # an unnamed numeric(1)
        E90 = stats::quantile(abs_diff, probs = 0.9, names = FALSE),
        Emax = max(abs_diff),
        stop(sprintf("Unknown calibration summary method: '%s'", method))
      )
    }
  )
)

# Register the class under the key "surv.calib_index" (the key used by the
# `msr("surv.calib_index")` calls in the examples above).
# NOTE(review): `register_measure()` is defined elsewhere in the package;
# presumably it adds the measure to the mlr3 measure dictionary - confirm.
register_measure("surv.calib_index", MeasureSurvICI)
1 change: 1 addition & 0 deletions R/PipeOpTaskSurvClassifDiscTime.R
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,7 @@ PipeOpTaskSurvClassifDiscTime = R6Class("PipeOpTaskSurvClassifDiscTime",
)
super$initialize(
id = id,
packages = c("pammtools"),
param_set = param_set,
input = data.table(
name = "input",
Expand Down
Loading

0 comments on commit 273911d

Please sign in to comment.