From 6ef7b876ab648b58d7be030482032493612cf070 Mon Sep 17 00:00:00 2001 From: john Date: Thu, 2 Jan 2025 19:46:54 +0200 Subject: [PATCH 01/29] remove obsolete 'data_formats' --- R/LearnerDens.R | 1 - R/LearnerSurv.R | 1 - man-roxygen/param_data_formats.R | 3 --- 3 files changed, 5 deletions(-) delete mode 100644 man-roxygen/param_data_formats.R diff --git a/R/LearnerDens.R b/R/LearnerDens.R index fb470929a..6fa426e4e 100644 --- a/R/LearnerDens.R +++ b/R/LearnerDens.R @@ -14,7 +14,6 @@ #' @template param_predict_types #' @template param_feature_types #' @template param_learner_properties -#' @template param_data_formats #' @template param_packages #' @template param_label #' @template param_man diff --git a/R/LearnerSurv.R b/R/LearnerSurv.R index 84bb2ed23..37e71c0cc 100644 --- a/R/LearnerSurv.R +++ b/R/LearnerSurv.R @@ -17,7 +17,6 @@ #' @template param_predict_types #' @template param_feature_types #' @template param_learner_properties -#' @template param_data_formats #' @template param_packages #' @template param_label #' @template param_man diff --git a/man-roxygen/param_data_formats.R b/man-roxygen/param_data_formats.R deleted file mode 100644 index 0b834d54f..000000000 --- a/man-roxygen/param_data_formats.R +++ /dev/null @@ -1,3 +0,0 @@ -#' @param data_formats (`character()`)\cr -#' Set of supported data formats which can be processed during `$train()` and `$predict()`, -#' e.g. `"data.table"`. From 48a8aebd720bc8ee53649294c701dc9c995a0524 Mon Sep 17 00:00:00 2001 From: john Date: Thu, 2 Jan 2025 20:01:26 +0200 Subject: [PATCH 02/29] import 'as_ped' from pammtools --- NAMESPACE | 1 + R/PipeOpTaskSurvClassifDiscTime.R | 1 + R/zzz.R | 1 + 3 files changed, 3 insertions(+) diff --git a/NAMESPACE b/NAMESPACE index 2febffe65..12cb7a73a 100644 --- a/NAMESPACE +++ b/NAMESPACE @@ -114,6 +114,7 @@ importFrom(mlr3pipelines,pipeline_greplicate) importFrom(mlr3pipelines,po) importFrom(mlr3pipelines,ppl) importFrom(mlr3viz,fortify) +importFrom(pammtools,as_ped) importFrom(stats,density) importFrom(stats,model.frame) importFrom(stats,model.matrix) diff --git a/R/PipeOpTaskSurvClassifDiscTime.R b/R/PipeOpTaskSurvClassifDiscTime.R index ed600fede..5e0f3d8d4 100644 --- a/R/PipeOpTaskSurvClassifDiscTime.R +++ b/R/PipeOpTaskSurvClassifDiscTime.R @@ -97,6 +97,7 @@ PipeOpTaskSurvClassifDiscTime = R6Class("PipeOpTaskSurvClassifDiscTime", ) super$initialize( id = id, + packages = c("pammtools"), param_set = param_set, input = data.table( name = "input", diff --git a/R/zzz.R b/R/zzz.R index 356457010..af49732f0 100644 --- a/R/zzz.R +++ b/R/zzz.R @@ -28,6 +28,7 @@ NULL #' @importFrom mlr3viz fortify #' @importFrom utils getFromNamespace #' @importFrom mlr3pipelines po as_graph %>>% pipeline_greplicate gunion Graph ppl +#' @importFrom pammtools as_ped "_PACKAGE" # nolint end From 5e7d6c60ecc2d23045ea0950fb7ae293651c7218 Mon Sep 17 00:00:00 2001 From: john Date: Thu, 2 Jan 2025 20:11:55 +0200 Subject: [PATCH 03/29] replace use of fortify() --- NAMESPACE | 1 - R/autoplot.R | 4 ++-- R/zzz.R | 1 - man/autoplot.TaskDens.Rd | 2 +- man/autoplot.TaskSurv.Rd | 2 +- 5 files changed, 4 insertions(+), 6 deletions(-) diff --git a/NAMESPACE b/NAMESPACE index 12cb7a73a..970b713f9 100644 --- a/NAMESPACE +++ b/NAMESPACE @@ -113,7 +113,6 @@ importFrom(mlr3pipelines,gunion) importFrom(mlr3pipelines,pipeline_greplicate) importFrom(mlr3pipelines,po) importFrom(mlr3pipelines,ppl) -importFrom(mlr3viz,fortify) importFrom(pammtools,as_ped) importFrom(stats,density) importFrom(stats,model.frame) diff --git a/R/autoplot.R b/R/autoplot.R index 236820d98..99543a51c 100644 --- a/R/autoplot.R +++ b/R/autoplot.R @@ -34,8 +34,8 @@ #' library(ggplot2) #' #' task = tsk("lung") +#' task$head() #' -#' head(fortify(task)) #' autoplot(task) # KM #' autoplot(task) # KM of the censoring distribution #' autoplot(task, rhs = "sex") @@ -105,8 +105,8 @@ plot.TaskSurv = function(x, ...) { #' library(mlr3viz) #' library(ggplot2) #' task = tsk("precip") +#' task$head() #' -#' head(fortify(task)) #' autoplot(task, bins = 15) #' autoplot(task, type = "freq", bins = 15) #' autoplot(task, type = "overlay", bins = 15) diff --git a/R/zzz.R b/R/zzz.R index af49732f0..513f32055 100644 --- a/R/zzz.R +++ b/R/zzz.R @@ -25,7 +25,6 @@ NULL #' @importFrom utils data head tail #' @importFrom stats model.matrix model.frame sd predict density #' @importFrom survival Surv -#' @importFrom mlr3viz fortify #' @importFrom utils getFromNamespace #' @importFrom mlr3pipelines po as_graph %>>% pipeline_greplicate gunion Graph ppl #' @importFrom pammtools as_ped diff --git a/man/autoplot.TaskDens.Rd b/man/autoplot.TaskDens.Rd index 0df103bd3..68c43f813 100644 --- a/man/autoplot.TaskDens.Rd +++ b/man/autoplot.TaskDens.Rd @@ -38,8 +38,8 @@ library(mlr3proba) library(mlr3viz) library(ggplot2) task = tsk("precip") +task$head() -head(fortify(task)) autoplot(task, bins = 15) autoplot(task, type = "freq", bins = 15) autoplot(task, type = "overlay", bins = 15) diff --git a/man/autoplot.TaskSurv.Rd b/man/autoplot.TaskSurv.Rd index cacb6acc9..18955d518 100644 --- a/man/autoplot.TaskSurv.Rd +++ b/man/autoplot.TaskSurv.Rd @@ -54,8 +54,8 @@ library(mlr3proba) library(ggplot2) task = tsk("lung") +task$head() -head(fortify(task)) autoplot(task) # KM autoplot(task) # KM of the censoring distribution autoplot(task, rhs = "sex") From 561e1f3930447f101dfd35091181cf6d84067099 Mon Sep 17 00:00:00 2001 From: john Date: Thu, 2 Jan 2025 20:52:59 +0200 Subject: [PATCH 04/29] add ICI paper --- R/bibentries.R | 15 +++++++++++++++ 1 file changed, 15 insertions(+) diff --git a/R/bibentries.R b/R/bibentries.R index 990e77148..afbe72025 100644 --- a/R/bibentries.R +++ b/R/bibentries.R @@ -401,5 +401,20 @@ bibentries = c( url = "http://jmlr.org/papers/v24/19-1030.html", volume = "24", year = "2023" + ), + austin2020 = bibentry("article", + author = "Austin, Peter C. and Harrell, Frank E. and van Klaveren, David", + doi = "10.1002/SIM.8570", + issn = "10970258", + journal = "Statistics in Medicine", + month = "sep", + number = "21", + pages = "2714", + pmid = "32548928", + publisher = "John Wiley and Sons Ltd", + title = "Graphical calibration curves and the integrated calibration index (ICI) for survival models", + url = "https://pmc.ncbi.nlm.nih.gov/articles/PMC7497089/", + volume = "39", + year = "2020" ) ) From 8c362ea0c308ae05a52d0d575094c381987b8ca7 Mon Sep 17 00:00:00 2001 From: john Date: Fri, 3 Jan 2025 02:13:52 +0200 Subject: [PATCH 05/29] refactor test --- tests/testthat/test_mlr_measures.R | 25 +++++++++++-------------- 1 file changed, 11 insertions(+), 14 deletions(-) diff --git a/tests/testthat/test_mlr_measures.R b/tests/testthat/test_mlr_measures.R index 486e3fd3a..ea8b67403 100644 --- a/tests/testthat/test_mlr_measures.R +++ b/tests/testthat/test_mlr_measures.R @@ -1,8 +1,8 @@ set.seed(1L) -task = tsk("rats")$filter(sample(300, 20L)) +task = tsk("rats")$filter(sample(300, 50L)) learner = suppressWarnings(lrn("surv.coxph")$train(task)) pred = learner$predict(task) -pred$data$response = 1:20 +pred$data$response = 1:50 pred$predict_types = c(pred$predict_types, "response") test_that("mlr_measures", { @@ -40,12 +40,9 @@ test_that("mlr_measures", { } }) -learner = suppressWarnings(lrn("surv.coxph")$train(task)) -prediction = learner$predict(task) - test_that("unintegrated_prob_losses", { msr = msr("surv.logloss") - expect_silent(prediction$score(msr)) + expect_silent(pred$score(msr)) }) test_that("integrated losses with use of times", { @@ -59,28 +56,28 @@ test_that("integrated losses with use of times", { } # between 64 and 104 - test_unique_times = sort(unique(prediction$truth[,1])) - expect_true(all(test_unique_times > 63)) + test_unique_times = sort(unique(pred$truth[,1])) + expect_true(all(test_unique_times > 38)) expect_true(all(test_unique_times < 105)) # no `times` => use test set's unique time points - expect_silent(prediction$score(lapply(losses, msr, integrated = TRUE, proper = TRUE))) + expect_silent(pred$score(lapply(losses, msr, integrated = TRUE, proper = TRUE))) # all `times` outside the test set range for (loss in losses) { - expect_warning(prediction$score(msr(loss, integrated = TRUE, proper = TRUE, times = 34:38)), "requested times") + expect_warning(pred$score(msr(loss, integrated = TRUE, proper = TRUE, times = 34:38)), "requested times") } # some `times` outside the test set range for (loss in losses) { - expect_warning(prediction$score(msr(loss, integrated = TRUE, proper = TRUE, times = 100:110)), "requested times") + expect_warning(pred$score(msr(loss, integrated = TRUE, proper = TRUE, times = 100:110)), "requested times") } # one time point, inside the range, no warnings - expect_silent(prediction$score(lapply(losses, msr, integrated = FALSE, proper = TRUE, times = 80))) + expect_silent(pred$score(lapply(losses, msr, integrated = FALSE, proper = TRUE, times = 80))) }) test_that("dcalib works", { expect_equal( - pchisq(prediction$score(msr("surv.dcalib", B = 14)), df = 13, lower.tail = FALSE), - suppressWarnings(prediction$score(msr("surv.dcalib", B = 14, chisq = TRUE))) + pchisq(pred$score(msr("surv.dcalib", B = 14)), df = 13, lower.tail = FALSE), + suppressWarnings(pred$score(msr("surv.dcalib", B = 14, chisq = TRUE))) ) }) From 8298fd0cd54bf41ebe2b2aa1bae6610fa432f03d Mon Sep 17 00:00:00 2001 From: john Date: Fri, 3 Jan 2025 22:25:02 +0200 Subject: [PATCH 06/29] fix conflict (pammtools is on Suggests, do not import) --- NAMESPACE | 1 - R/zzz.R | 1 - 2 files changed, 2 deletions(-) diff --git a/NAMESPACE b/NAMESPACE index 970b713f9..7c492ea7e 100644 --- a/NAMESPACE +++ b/NAMESPACE @@ -113,7 +113,6 @@ importFrom(mlr3pipelines,gunion) importFrom(mlr3pipelines,pipeline_greplicate) importFrom(mlr3pipelines,po) importFrom(mlr3pipelines,ppl) -importFrom(pammtools,as_ped) importFrom(stats,density) importFrom(stats,model.frame) importFrom(stats,model.matrix) diff --git a/R/zzz.R b/R/zzz.R index 513f32055..6b2e6c320 100644 --- a/R/zzz.R +++ b/R/zzz.R @@ -27,7 +27,6 @@ NULL #' @importFrom survival Surv #' @importFrom utils getFromNamespace #' @importFrom mlr3pipelines po as_graph %>>% pipeline_greplicate gunion Graph ppl -#' @importFrom pammtools as_ped "_PACKAGE" # nolint end From 38c8d59aaf18980bebd455689752183a693bb2a8 Mon Sep 17 00:00:00 2001 From: john Date: Sat, 4 Jan 2025 00:17:27 +0200 Subject: [PATCH 07/29] remove obsolete plotting function --- DESCRIPTION | 1 - NAMESPACE | 2 -- R/plot.R | 50 -------------------------------------------- pkgdown/_pkgdown.yml | 1 - 4 files changed, 54 deletions(-) delete mode 100644 R/plot.R diff --git a/DESCRIPTION b/DESCRIPTION index c1eb371d1..4fabb4ba8 100644 --- a/DESCRIPTION +++ b/DESCRIPTION @@ -170,7 +170,6 @@ Collate: 'mlr3proba-package.R' 'pecs.R' 'pipelines.R' - 'plot.R' 'plot_probregr.R' 'scoring_rule_erv.R' 'surv_measures.R' diff --git a/NAMESPACE b/NAMESPACE index 7c492ea7e..8ad5026a5 100644 --- a/NAMESPACE +++ b/NAMESPACE @@ -27,7 +27,6 @@ S3method(is_missing_prediction_data,PredictionDataDens) S3method(is_missing_prediction_data,PredictionDataSurv) S3method(pecs,PredictionSurv) S3method(pecs,list) -S3method(plot,LearnerSurv) S3method(plot,TaskDens) S3method(plot,TaskSurv) export(.c_weight_survival_score) @@ -105,7 +104,6 @@ import(mlr3misc) import(paradox) importFrom(R6,R6Class) importFrom(Rcpp,sourceCpp) -importFrom(graphics,plot) importFrom(mlr3pipelines,"%>>%") importFrom(mlr3pipelines,Graph) importFrom(mlr3pipelines,as_graph) diff --git a/R/plot.R b/R/plot.R deleted file mode 100644 index ced552f53..000000000 --- a/R/plot.R +++ /dev/null @@ -1,50 +0,0 @@ -#' @title Visualization of fitted `LearnerSurv` objects -#' @description Wrapper around `predict.LearnerSurv` and `plot.Matdist`. -#' -#' @importFrom graphics plot -#' @param x ([LearnerSurv]) -#' @param task ([TaskSurv]) -#' @param fun (`character`) \cr -#' Passed to `distr6::plot.Matdist` -#' @param row_ids (`integer()`) \cr -#' Passed to `Learner$predict` -#' @param newdata (`data.frame()`) \cr -#' If not missing `Learner$predict_newdata` is called instead of `Learner$predict`. -#' @param ... Additional arguments passed to `distr6::plot.Matdist` -#' -#' -#' @examples -#' \dontrun{ -#' library(mlr3) -#' task = tsk("rats") -#' -#' # Prediction Error Curves for prediction object -#' learn = lrn("surv.coxph") -#' learn$train(task) -#' -#' plot(learn, task, "survival", ind = 10) -#' plot(learn, task, "survival", row_ids = 1:5) -#' plot(learn, task, "survival", newdata = task$data()[1:5, ]) -#' plot(learn, task, "survival", newdata = task$data()[1:5, ], ylim = c(0, 1)) -#' } -#' @export -plot.LearnerSurv = function( - x, - task, - fun = c("survival", "pdf", "cdf", "quantile", "hazard", "cumhazard"), - row_ids = NULL, - newdata, - ...) { - - fun = match.arg(fun) - - if (missing(newdata)) { - pred = x$predict(task = task, row_ids = row_ids) - } - else { - pred = x$predict_newdata(newdata = newdata, task = task) - } - - plot(pred$distr, fun = fun, ...) - -} diff --git a/pkgdown/_pkgdown.yml b/pkgdown/_pkgdown.yml index 74c4a966d..83edb11e8 100644 --- a/pkgdown/_pkgdown.yml +++ b/pkgdown/_pkgdown.yml @@ -121,7 +121,6 @@ reference: - title: Visualisation contents: - pecs - - plot.LearnerSurv - plot_probregr - starts_with("autoplot") - title: Datasets From 58fc659cae597cc672845804cbd4405501311161 Mon Sep 17 00:00:00 2001 From: john Date: Sat, 4 Jan 2025 00:18:02 +0200 Subject: [PATCH 08/29] tidy up description --- DESCRIPTION | 73 ++++++++++++++++++----------------------------------- 1 file changed, 24 insertions(+), 49 deletions(-) diff --git a/DESCRIPTION b/DESCRIPTION index 4fabb4ba8..13e22a11a 100644 --- a/DESCRIPTION +++ b/DESCRIPTION @@ -1,49 +1,23 @@ Package: mlr3proba Title: Probabilistic Supervised Learning for 'mlr3' Version: 0.7.1 -Authors@R: - c(person(given = "Raphael", - family = "Sonabend", - role = "aut", - email = "raphaelsonabend@gmail.com", - comment = c(ORCID = "0000-0001-9225-4654")), - person(given = "Franz", - family = "Kiraly", - role = "aut", - email = "f.kiraly@ucl.ac.uk"), - person(given = "Michel", - family = "Lang", - role = "aut", - email = "michellang@gmail.com", - comment = c(ORCID = "0000-0001-9754-0393")), - person(given = "Nurul Ain", - family = "Toha", - role = "ctb", - email = "nurul.toha.15@ucl.ac.uk"), - person(given = "Andreas", - family = "Bender", - role = "ctb", - email = "bender.at.R@gmail.com", - comment = c(ORCID = "0000-0001-5628-8611")), - person(given = "John", - family = "Zobolas", - role = c("cre", "aut"), - email = "bblodfon@gmail.com", - comment = c(ORCID = "0000-0002-3609-8674")), - person(given = "Lukas", - family = "Burk", - email = "github@quantenbrot.de", - role = "ctb", - comment = c(ORCID = "0000-0001-7528-3795")), - person(given = "Philip", - family = "Studener", - role = "aut", - email = "philip.studener@gmx.de"), - person(given = "Maximilian", - family = "Muecke", - email = "muecke.maximilian@gmail.com", - role = "ctb", - comment = c(ORCID = "0009-0000-9432-9795"))) +Authors@R: c( + person("Raphael", "Sonabend", , "raphaelsonabend@gmail.com", role = "aut", + comment = c(ORCID = "0000-0001-9225-4654")), + person("Franz", "Kiraly", , "f.kiraly@ucl.ac.uk", role = "aut"), + person("Michel", "Lang", , "michellang@gmail.com", role = "aut", + comment = c(ORCID = "0000-0001-9754-0393")), + person("Nurul Ain", "Toha", , "nurul.toha.15@ucl.ac.uk", role = "ctb"), + person("Andreas", "Bender", , "bender.at.R@gmail.com", role = "ctb", + comment = c(ORCID = "0000-0001-5628-8611")), + person("John", "Zobolas", , "bblodfon@gmail.com", role = c("cre", "aut"), + comment = c(ORCID = "0000-0002-3609-8674")), + person("Lukas", "Burk", , "github@quantenbrot.de", role = "ctb", + comment = c(ORCID = "0000-0001-7528-3795")), + person("Philip", "Studener", , "philip.studener@gmx.de", role = "aut"), + person("Maximilian", "Muecke", , "muecke.maximilian@gmail.com", role = "ctb", + comment = c(ORCID = "0009-0000-9432-9795")) + ) Description: Provides extensions for probabilistic supervised learning for 'mlr3'. This includes extending the regression task to probabilistic and interval regression, adding a survival task, and other specialized @@ -67,29 +41,30 @@ Imports: Rcpp (>= 1.0.4), survival Suggests: + abind, + coxed, GGally, knitr, lgr, lifecycle, + mlr3learners, + pammtools, param6 (>= 0.2.4), + polspline, pracma, rpart, set6 (>= 0.2.6), simsurv, survAUC, - testthat (>= 3.0.0), - abind, - coxed, - mlr3learners, - pammtools + testthat (>= 3.0.0) LinkingTo: Rcpp Remotes: xoopR/distr6, xoopR/param6, xoopR/set6 -Config/testthat/edition: 3 ByteCompile: true +Config/testthat/edition: 3 Encoding: UTF-8 LazyData: true NeedsCompilation: no From 86d46b83af6999d430593a1bcab4d0236f3f3235 Mon Sep 17 00:00:00 2001 From: john Date: Sat, 4 Jan 2025 00:19:29 +0200 Subject: [PATCH 09/29] refactor: autoplot.PredictionSurv --- R/autoplot.R | 232 ++++++++++++++++++++------------- man/autoplot.PredictionSurv.Rd | 77 ++++++----- man/autoplot.TaskDens.Rd | 2 +- man/plot.LearnerSurv.Rd | 49 ------- tests/testthat/test_autoplot.R | 6 +- 5 files changed, 193 insertions(+), 173 deletions(-) delete mode 100644 man/plot.LearnerSurv.Rd diff --git a/R/autoplot.R b/R/autoplot.R index 99543a51c..80379125a 100644 --- a/R/autoplot.R +++ b/R/autoplot.R @@ -42,7 +42,7 @@ #' autoplot(task, type = "duo") #' @export autoplot.TaskSurv = function(object, type = "target", theme = theme_minimal(), reverse = FALSE, ...) { # nolint - assert_string(type) + assert_choice(type, choices = c("target", "duo", "pairs"), null.ok = FALSE) require_namespaces(c("survival", "GGally")) switch(type, @@ -91,7 +91,7 @@ plot.TaskSurv = function(x, ...) { #' * `"freq"`: histogram frequency plot with [ggplot2::geom_histogram()]. #' * `"overlay"`: histogram with overlaid density plot with [ggplot2::geom_histogram()] and #' [ggplot2::geom_density()]. -#' * `"freqpoly"`: frequency polygon plot with `ggplot2::geom_freqpoly`. +#' * `"freqpoly"`: frequency polygon plot with [ggplot2::geom_freqpoly]. #' @template param_theme #' @param ... (`any`): #' Additional arguments, possibly passed down to the underlying plot functions. @@ -113,30 +113,40 @@ plot.TaskSurv = function(x, ...) { #' autoplot(task, type = "freqpoly", bins = 15) #' @export autoplot.TaskDens = function(object, type = "dens", theme = theme_minimal(), ...) { # nolint - assert_choice(type, c("dens", "freq", "overlay", "freqpoly")) + assert_choice(type, c("dens", "freq", "overlay", "freqpoly"), null.ok = FALSE) p = ggplot(data = object, aes(x = .data[[object$feature_names]]), ...) - if (type == "dens") { - p + - geom_histogram(aes(y = after_stat(density)), fill = "white", color = "black", ...) + - ylab("Density") + - theme - } else if (type == "freq") { - p + geom_histogram(fill = "white", color = "black", ...) + - ylab("Count") + - theme - } else if (type == "overlay") { - p + - geom_histogram(aes(y = after_stat(density)), colour = "black", fill = "white", ...) + - geom_density(alpha = 0.2, fill = "#5dadc8") + - ylab("Density") + - theme - } else { - p + - geom_freqpoly(...) + - theme - } + switch(type, + "dens" = { + p + + geom_histogram(aes(y = after_stat(density)), fill = "white", color = "black", ...) + + ylab("Density") + + theme + }, + + "freq" = { + p + geom_histogram(fill = "white", color = "black", ...) + + ylab("Count") + + theme + }, + + "overlay" = { + p + + geom_histogram(aes(y = after_stat(density)), colour = "black", fill = "white", ...) + + geom_density(alpha = 0.2, fill = "#5dadc8") + + ylab("Density") + + theme + }, + + "freqpoly" = { + p + + geom_freqpoly(...) + + theme + }, + + stopf("Unknown plot type '%s'", type) + ) } #' @export @@ -149,32 +159,41 @@ plot.TaskDens = function(x, ...) { #' @description #' Generates plots for [mlr3proba::PredictionSurv], depending on argument `type`: #' -#' * `"calib"` (default): Calibration plot comparing the average predicted survival distribution -#' to a Kaplan-Meier prediction, this is *not* a comparison of a stratified `crank` or `lp` -#' prediction. `object` must have `distr` prediction. `geom_line()` is used for comparison split -#' between the prediction (`Pred`) and Kaplan-Meier estimate (`KM`). In addition labels are added -#' for the x (`T`) and y (`S(T)`) axes. -#' * `"dcalib"`: Distribution calibration plot. A model is D-calibrated if X% of deaths occur before -#' the X/100 quantile of the predicted distribution, e.g. if 50% of observations die before their -#' predicted median survival time. A model is D-calibrated if the resulting plot lies on x = y. -#' * `"preds"`: Matplots the survival curves for all predictions +#' - `"calib"` (default): **Calibration plot** comparing the average predicted +#' survival distribution (`Pred`) to a Kaplan-Meier prediction (`KM`), this is +#' *not* a comparison of a stratified `crank` or `lp`. +#' - `"dcalib"`: **Distribution calibration plot**. +#' A model is considered D-calibrated if, for any given quantile `p`, the +#' proportion of observed outcomes occurring before the predicted time quantile, +#' matches `p`. For example, 50% of events should occur before the predicted +#' median survival time (i.e. the time corresponding to a predicted survival +#' probability of 0.5). +#' This means that the resulting line plot will lie close to the straight line +#' y = x. +#' Note that we impute `NA`s from the predicted quantile function with the +#' maximum observed outcome time. +#' - `"isd"`: Plot the predicted **i**ndividual **s**urvival **d**istributions +#' (survival curves) for observations from the test set. +#' +#' @section Notes: +#' +#' 1. `object` must have a `distr` prediction, as all plot `type`s use the +#' predicted survival distribution/matrix. +#' 2. `type = "dcalib"` is drawn a bit differently from Haider et al. (2020), +#' though its still conceptually the same. #' #' @param object ([mlr3proba::PredictionSurv]). -#' @template param_type -#' @param task ([mlr3proba::TaskSurv]) \cr -#' If `type = "calib"` then `task` is passed to `$predict` in the Kaplan-Meier learner. +#' @param type (`character(1)`) \cr +#' Type of the plot, see Description. #' @param row_ids (`integer()`) \cr -#' If `type = "calib"` then `row_ids` is passed to `$predict` in the Kaplan-Meier learner. +#' If `type = "isd"`, specific observation ids (from the test set) for which +#' we draw their predicted survival distributions. #' @param times (`numeric()`) \cr -#' If `type = "calib"` then `times` is the values on the x-axis to plot over, -#' if `NULL` uses all times from `task`. -#' @param xyline (`logical(1)`) \cr -#' If `TRUE` (default) plots the x-y line for `type = "dcalib"`. +#' If `type = "calib"` then `times` is the values on the x-axis to plot over. +#' if `NULL` uses all time points from the predicted survival matrix (`object$data$distr`). #' @param cuts (`integer(1)`) \cr -#' Number of cuts in (0,1) to plot `dcalib` over, default is `11`. +#' Number of cuts in \eqn{(0,1)} to plot `dcalib` over, default is `11`. #' @template param_theme -#' @param extend_quantile `(logical(1))` \cr -#' If `TRUE` then `dcalib` will impute NAs from predicted quantile function with the maximum observed outcome time, e.g. if the last predicted survival probability is greater than 0.1, then the last predicted cdf is smaller than 0.9 so F^1(0.9) = NA, this would be imputed with max(times). Default is `FALSE`. #' @param ... (`any`): #' Additional arguments, currently unused. #' @@ -193,83 +212,122 @@ plot.TaskDens = function(x, ...) { #' p = learner$train(task, row_ids = 1:300)$predict(task, row_ids = 301:400) #' #' # calibration by comparison of average prediction to Kaplan-Meier -#' autoplot(p, type = "calib", task = task, row_ids = 301:400) +#' autoplot(p) +#' +#' # same as above, use specific time points +#' autoplot(p, times = seq(1, 1000, 5)) #' #' # Distribution-calibration (D-Calibration) -#' autoplot(p, type = "dcalib", extend_quantile = TRUE) +#' autoplot(p, type = "dcalib") +#' +#' # Predicted survival curves (all observations) +#' autoplot(p, type = "isd") +#' +#' # Predicted survival curves (specific observations) +#' autoplot(p, type = "isd", row_ids = c(301, 351, 399)) #' -#' # Predictions -#' autoplot(p, type = "preds") #' @export autoplot.PredictionSurv = function(object, type = "calib", - task = NULL, row_ids = NULL, times = NULL, xyline = TRUE, - cuts = 11L, theme = theme_minimal(), extend_quantile = FALSE, ...) { - + times = NULL, row_ids = NULL, cuts = 11L, theme = theme_minimal(), ...) { + assert_choice(type, c("calib", "dcalib", "isd"), null.ok = FALSE) assert("distr" %in% object$predict_types) + assert_number(cuts, na.ok = FALSE, lower = 1L, null.ok = FALSE) + assert_numeric(row_ids, any.missing = FALSE, lower = 1, null.ok = TRUE) switch(type, "calib" = { - assert_task(task) - if (is.null(times)) { - times = sort(unique(task$truth()[, 1L])) - } - - if (inherits(object$distr, "VectorDistribution")) { - pred_surv = 1 - distr6::as.MixtureDistribution(object$distr)$cdf(times) + # get predicted survival matrix + if (inherits(object$data$distr, "array")) { + surv = object$data$distr + if (length(dim(surv)) == 3L) { + # survival 3d array, extract median + surv = .ext_surv_mat(arr = surv, which.curve = 0.5) + } } else { - pred_surv = rowMeans(1 - object$distr$cdf(times)) + stop("Distribution prediction does not have a survival matrix or array + in the $data$distr slot") } + # get predicted time points + pred_times = as.numeric(colnames(surv)) + # which time points to use for plotting + times = times %??% pred_times - km = lrn("surv.kaplan") - km_pred = km$train(task, row_ids = row_ids)$predict(task, row_ids = row_ids) - km_surv = rowMeans(1 - km_pred$distr$cdf(times)) + # function to request S(t) for points "in-between" + extend_times = getFromNamespace("C_Vec_WeightedDiscreteCdf", ns = "distr6") + # rows => times, cols => obs + surv2 = extend_times(times, pred_times, cdf = t(1 - surv), FALSE, FALSE) - data = data.frame(x = times, y = c(km_surv, pred_surv), - Group = rep(c("KM", "Pred"), each = length(times))) + # average predicted probability across test set observations + pred_surv = rowMeans(surv2) - ggplot(data, aes(x = .data[["x"]], y = .data[["y"]], group = .data[["Group"]], color = .data[["Group"]])) + + # fit a Kaplan-Meier on the test data + km_fit = survival::survfit(object$truth ~ 1) + # make a S(t) one-column matrix (by default same probability for every observation) + km_surv = matrix(km_fit$surv, ncol = 1) # rows => times + # get KM's S(t) at the predicted time points + km_surv = extend_times(times, km_fit$time, cdf = 1 - km_surv, FALSE, FALSE)[,1] + + data = data.table( + x = times, + y = c(km_surv, pred_surv), + Group = rep(c("KM", "Pred"), each = length(times)) + ) + + ggplot(data, aes(x = .data[["x"]], y = .data[["y"]], group = .data[["Group"]], + color = .data[["Group"]])) + geom_line() + - labs(x = "T", y = "S(T)") + + labs(x = "Time", y = "Average Survival Probability") + theme + theme(legend.title = element_blank()) - }, "dcalib" = { p = seq.int(0, 1, length.out = cuts) true_times = object$truth[, 1L] q = map_dbl(p, function(.x) { + # time points at which observations had `.x` survival qi = as.numeric(object$distr$quantile(.x)) - if (extend_quantile) { - qi[is.na(qi)] = max(true_times) - } + qi[is.na(qi)] = max(true_times) sum(true_times <= qi) / length(object$row_ids) }) - pl = ggplot(data = data.frame(p, q), aes(x = p, y = q)) + - geom_line() - if (xyline) { - pl = pl + - annotate("segment", x = 0, y = 0, xend = 1, yend = 1, color = "lightgray") - } - pl + - labs(x = "True", y = "Predicted") + + ggplot(data = data.table(p, q), aes(x = p, y = q)) + + geom_bar(stat = "identity", fill = "skyblue", color = "black") + + geom_line(color = "red") + + scale_x_continuous(breaks = p) + + annotate("segment", x = 0, y = 0, xend = 1, yend = 1, color = "black", + linetype = "dashed") + + labs(x = "Survival Probability (Bins)", + y = "Observed Proportion") + theme }, - "preds" = { - v = 1 - distr6::gprm(object$distr, "cdf") - surv = data.frame( - Var1 = as.factor(seq_len(nrow(v))), - Var2 = rep(as.numeric(colnames(v)), each = nrow(v)), - value = invoke(c, .args = as.data.frame(v)) + "isd" = { + surv = object$data$distr # assume this is 2d survival matrix + data = data.table( + row_id = as.factor(object$row_ids), + time = rep(as.numeric(colnames(surv)), each = nrow(surv)), + surv_prob = invoke(c, .args = as.data.table(surv)) ) - ggplot(surv, aes(x = .data[["Var2"]], y = .data[["value"]], group = .data[["Var1"]], color = .data[["Var1"]])) + + # filter data to specific ids + if (!is.null(row_ids)) { + data = data[row_id %in% row_ids] + } + + p = + ggplot(data, aes(x = .data[["time"]], y = .data[["surv_prob"]], + group = .data[["row_id"]], color = .data[["row_id"]])) + geom_line() + - labs(x = "T", y = "S(T)") + - theme + - theme(legend.position = "n") + labs(x = "Time", y = "Survival Probability") + + theme + + # usually too many observations, so don't draw legend + if (is.null(row_ids)) { + p = p + theme(legend.position = "none") + } + + p }, stopf("Unknown plot type '%s'", type) diff --git a/man/autoplot.PredictionSurv.Rd b/man/autoplot.PredictionSurv.Rd index ddad53300..7a159aaa1 100644 --- a/man/autoplot.PredictionSurv.Rd +++ b/man/autoplot.PredictionSurv.Rd @@ -7,61 +7,66 @@ \method{autoplot}{PredictionSurv}( object, type = "calib", - task = NULL, - row_ids = NULL, times = NULL, - xyline = TRUE, + row_ids = NULL, cuts = 11L, theme = theme_minimal(), - extend_quantile = FALSE, ... ) } \arguments{ \item{object}{(\link{PredictionSurv}).} -\item{type}{(\code{character(1)})\cr -Name of the column giving the type of censoring. Default is 'right' censoring.} - -\item{task}{(\link{TaskSurv}) \cr -If \code{type = "calib"} then \code{task} is passed to \verb{$predict} in the Kaplan-Meier learner.} - -\item{row_ids}{(\code{integer()}) \cr -If \code{type = "calib"} then \code{row_ids} is passed to \verb{$predict} in the Kaplan-Meier learner.} +\item{type}{(\code{character(1)}) \cr +Type of the plot, see Description.} \item{times}{(\code{numeric()}) \cr -If \code{type = "calib"} then \code{times} is the values on the x-axis to plot over, -if \code{NULL} uses all times from \code{task}.} +If \code{type = "calib"} then \code{times} is the values on the x-axis to plot over. +if \code{NULL} uses all time points from the predicted survival matrix (\code{object$data$distr}).} -\item{xyline}{(\code{logical(1)}) \cr -If \code{TRUE} (default) plots the x-y line for \code{type = "dcalib"}.} +\item{row_ids}{(\code{integer()}) \cr +If \code{type = "isd"}, specific observation ids (from the test set) for which +we draw their predicted survival distributions.} \item{cuts}{(\code{integer(1)}) \cr -Number of cuts in (0,1) to plot \code{dcalib} over, default is \code{11}.} +Number of cuts in \eqn{(0,1)} to plot \code{dcalib} over, default is \code{11}.} \item{theme}{(\code{\link[ggplot2:theme]{ggplot2::theme()}})\cr The \code{\link[ggplot2:ggtheme]{ggplot2::theme_minimal()}} is applied by default to all plots.} -\item{extend_quantile}{\code{(logical(1))} \cr -If \code{TRUE} then \code{dcalib} will impute NAs from predicted quantile function with the maximum observed outcome time, e.g. if the last predicted survival probability is greater than 0.1, then the last predicted cdf is smaller than 0.9 so F^1(0.9) = NA, this would be imputed with max(times). Default is \code{FALSE}.} - \item{...}{(\code{any}): Additional arguments, currently unused.} } \description{ Generates plots for \link{PredictionSurv}, depending on argument \code{type}: \itemize{ -\item \code{"calib"} (default): Calibration plot comparing the average predicted survival distribution -to a Kaplan-Meier prediction, this is \emph{not} a comparison of a stratified \code{crank} or \code{lp} -prediction. \code{object} must have \code{distr} prediction. \code{geom_line()} is used for comparison split -between the prediction (\code{Pred}) and Kaplan-Meier estimate (\code{KM}). In addition labels are added -for the x (\code{T}) and y (\code{S(T)}) axes. -\item \code{"dcalib"}: Distribution calibration plot. A model is D-calibrated if X\% of deaths occur before -the X/100 quantile of the predicted distribution, e.g. if 50\% of observations die before their -predicted median survival time. A model is D-calibrated if the resulting plot lies on x = y. -\item \code{"preds"}: Matplots the survival curves for all predictions +\item \code{"calib"} (default): \strong{Calibration plot} comparing the average predicted +survival distribution (\code{Pred}) to a Kaplan-Meier prediction (\code{KM}), this is +\emph{not} a comparison of a stratified \code{crank} or \code{lp}. +\item \code{"dcalib"}: \strong{Distribution calibration plot}. +A model is considered D-calibrated if, for any given quantile \code{p}, the +proportion of observed outcomes occurring before the predicted time quantile, +matches \code{p}. For example, 50\% of events should occur before the predicted +median survival time (i.e. the time corresponding to a predicted survival +probability of 0.5). +This means that the resulting line plot will lie close to the straight line +y = x. +Note that we impute \code{NA}s from the predicted quantile function with the +maximum observed outcome time. +\item \code{"isd"}: Plot the predicted \strong{i}ndividual \strong{s}urvival \strong{d}istributions +(survival curves) for observations from the test set. } } +\section{Notes}{ + +\enumerate{ +\item \code{object} must have a \code{distr} prediction, as all plot \code{type}s use the +predicted survival distribution/matrix. +\item \code{type = "dcalib"} is drawn a bit differently from Haider et al. (2020), +though its still conceptually the same. +} +} + \examples{ \dontshow{if (mlr3misc::require_namespaces(c("mlr3viz", "ggplot2"), quietly = TRUE)) (if (getRversion() >= "3.4") withAutoprint else force)(\{ # examplesIf} library(mlr3) @@ -73,13 +78,19 @@ task = tsk("gbcs") p = learner$train(task, row_ids = 1:300)$predict(task, row_ids = 301:400) # calibration by comparison of average prediction to Kaplan-Meier -autoplot(p, type = "calib", task = task, row_ids = 301:400) +autoplot(p) + +# same as above, use specific time points +autoplot(p, times = seq(1, 1000, 5)) # Distribution-calibration (D-Calibration) -autoplot(p, type = "dcalib", extend_quantile = TRUE) +autoplot(p, type = "dcalib") + +# Predicted survival curves (all observations) +autoplot(p, type = "isd") -# Predictions -autoplot(p, type = "preds") +# Predicted survival curves (specific observations) +autoplot(p, type = "isd", row_ids = c(301, 351, 399)) \dontshow{\}) # examplesIf} } \references{ diff --git a/man/autoplot.TaskDens.Rd b/man/autoplot.TaskDens.Rd index 68c43f813..603ef15dc 100644 --- a/man/autoplot.TaskDens.Rd +++ b/man/autoplot.TaskDens.Rd @@ -16,7 +16,7 @@ Type of the plot. Available choices: \item \code{"freq"}: histogram frequency plot with \code{\link[ggplot2:geom_histogram]{ggplot2::geom_histogram()}}. \item \code{"overlay"}: histogram with overlaid density plot with \code{\link[ggplot2:geom_histogram]{ggplot2::geom_histogram()}} and \code{\link[ggplot2:geom_density]{ggplot2::geom_density()}}. -\item \code{"freqpoly"}: frequency polygon plot with \code{ggplot2::geom_freqpoly}. +\item \code{"freqpoly"}: frequency polygon plot with \link[ggplot2:geom_histogram]{ggplot2::geom_freqpoly}. }} \item{theme}{(\code{\link[ggplot2:theme]{ggplot2::theme()}})\cr diff --git a/man/plot.LearnerSurv.Rd b/man/plot.LearnerSurv.Rd deleted file mode 100644 index fb0cf6bc3..000000000 --- a/man/plot.LearnerSurv.Rd +++ /dev/null @@ -1,49 +0,0 @@ -% Generated by roxygen2: do not edit by hand -% Please edit documentation in R/plot.R -\name{plot.LearnerSurv} -\alias{plot.LearnerSurv} -\title{Visualization of fitted \code{LearnerSurv} objects} -\usage{ -\method{plot}{LearnerSurv}( - x, - task, - fun = c("survival", "pdf", "cdf", "quantile", "hazard", "cumhazard"), - row_ids = NULL, - newdata, - ... -) -} -\arguments{ -\item{x}{(\link{LearnerSurv})} - -\item{task}{(\link{TaskSurv})} - -\item{fun}{(\code{character}) \cr -Passed to \code{distr6::plot.Matdist}} - -\item{row_ids}{(\code{integer()}) \cr -Passed to \code{Learner$predict}} - -\item{newdata}{(\code{data.frame()}) \cr -If not missing \code{Learner$predict_newdata} is called instead of \code{Learner$predict}.} - -\item{...}{Additional arguments passed to \code{distr6::plot.Matdist}} -} -\description{ -Wrapper around \code{predict.LearnerSurv} and \code{plot.Matdist}. -} -\examples{ -\dontrun{ -library(mlr3) -task = tsk("rats") - -# Prediction Error Curves for prediction object -learn = lrn("surv.coxph") -learn$train(task) - -plot(learn, task, "survival", ind = 10) -plot(learn, task, "survival", row_ids = 1:5) -plot(learn, task, "survival", newdata = task$data()[1:5, ]) -plot(learn, task, "survival", newdata = task$data()[1:5, ], ylim = c(0, 1)) -} -} diff --git a/tests/testthat/test_autoplot.R b/tests/testthat/test_autoplot.R index c4dc90cd1..8b4df0134 100644 --- a/tests/testthat/test_autoplot.R +++ b/tests/testthat/test_autoplot.R @@ -6,13 +6,13 @@ test_that("autoplot.PredictionSurv", { learner = suppressWarnings(mlr3::lrn("surv.coxph")$train(task)) prediction = learner$predict(task) - p = autoplot(prediction, type = "calib", task = task) + p = autoplot(prediction, type = "calib") expect_true(is.ggplot(p)) - p = autoplot(prediction, type = "dcalib", extend_quantile = TRUE) + p = autoplot(prediction, type = "dcalib", cuts = 4) expect_true(is.ggplot(p)) - p = autoplot(prediction, type = "preds") + p = autoplot(prediction, type = "isd", row_ids = sample(task$row_ids, size = 5)) expect_true(is.ggplot(p)) }) From 22c56ecc8cab9dc204d3c2683495933ae43c6003 Mon Sep 17 00:00:00 2001 From: john Date: Sat, 4 Jan 2025 00:28:53 +0200 Subject: [PATCH 10/29] add ICI measure --- DESCRIPTION | 1 + NAMESPACE | 1 + R/MeasureSurvICI.R | 97 ++++++++++++++++ man/mlr_measures_surv.calib_alpha.Rd | 3 + man/mlr_measures_surv.calib_beta.Rd | 2 + man/mlr_measures_surv.calib_index.Rd | 149 +++++++++++++++++++++++++ man/mlr_measures_surv.chambless_auc.Rd | 1 + man/mlr_measures_surv.cindex.Rd | 1 + man/mlr_measures_surv.dcalib.Rd | 5 +- man/mlr_measures_surv.graf.Rd | 2 + man/mlr_measures_surv.hung_auc.Rd | 1 + man/mlr_measures_surv.intlogloss.Rd | 2 + man/mlr_measures_surv.logloss.Rd | 2 + man/mlr_measures_surv.mae.Rd | 1 + man/mlr_measures_surv.mse.Rd | 1 + man/mlr_measures_surv.nagelk_r2.Rd | 1 + man/mlr_measures_surv.oquigley_r2.Rd | 1 + man/mlr_measures_surv.rcll.Rd | 2 + man/mlr_measures_surv.rmse.Rd | 1 + man/mlr_measures_surv.schmid.Rd | 2 + man/mlr_measures_surv.song_auc.Rd | 1 + man/mlr_measures_surv.song_tnr.Rd | 1 + man/mlr_measures_surv.song_tpr.Rd | 1 + man/mlr_measures_surv.uno_auc.Rd | 1 + man/mlr_measures_surv.uno_tnr.Rd | 1 + man/mlr_measures_surv.uno_tpr.Rd | 1 + man/mlr_measures_surv.xu_r2.Rd | 1 + 27 files changed, 282 insertions(+), 1 deletion(-) create mode 100644 R/MeasureSurvICI.R create mode 100644 man/mlr_measures_surv.calib_index.Rd diff --git a/DESCRIPTION b/DESCRIPTION index 13e22a11a..dffcc2a9b 100644 --- a/DESCRIPTION +++ b/DESCRIPTION @@ -91,6 +91,7 @@ Collate: 'MeasureSurvDCalibration.R' 'MeasureSurvGraf.R' 'MeasureSurvHungAUC.R' + 'MeasureSurvICI.R' 'MeasureSurvIntLogloss.R' 'MeasureSurvLogloss.R' 'MeasureSurvMAE.R' diff --git a/NAMESPACE b/NAMESPACE index 8ad5026a5..39c79eec5 100644 --- a/NAMESPACE +++ b/NAMESPACE @@ -50,6 +50,7 @@ export(MeasureSurvCindex) export(MeasureSurvDCalibration) export(MeasureSurvGraf) export(MeasureSurvHungAUC) +export(MeasureSurvICI) export(MeasureSurvIntLogloss) export(MeasureSurvLogloss) export(MeasureSurvMAE) diff --git a/R/MeasureSurvICI.R b/R/MeasureSurvICI.R new file mode 100644 index 000000000..3435e0b4a --- /dev/null +++ b/R/MeasureSurvICI.R @@ -0,0 +1,97 @@ +#' @template surv_measure +#' @templateVar title Integrated Calibration Index +#' @templateVar fullname MeasureSurvICI +#' @templateVar eps 1e-4 +#' @template param_eps +#' +#' @description +#' Calculates the Integrated Calibration Index (ICI), which is the absolute +#' difference between predicted survival probabilities and smoothed survival +#' frequencies (calculated using hazard regression via the \CRANpkg{polspline}) +#' at a specific time point. +#' +#' @references +#' `r format_bib("austin2020")` +#' +#' @family calibration survival measures +#' @family distr survival measures +#' @export +MeasureSurvICI = R6Class("MeasureSurvICI", + inherit = MeasureSurv, + public = list( + #' @description Creates a new instance of this [R6][R6::R6Class] class. + initialize = function() { + param_set = ps( + time = p_dbl(0, Inf), + method = p_fct(default = "ICI", levels = c("ICI", "E50", "E90", "Emax")) + ) + param_set$set_values(method = "ICI") + + super$initialize( + id = "surv.calib_index", + packages = c("polspline"), + range = c(0, Inf), + minimize = TRUE, + predict_type = "distr", + label = "Integrated Calibration Index", + man = "mlr3proba::mlr_measures_surv.calib_index", + param_set = param_set + ) + } + ), + + private = list( + .score = function(prediction, ...) { + # test set survival outcome + times = prediction$truth[, 1L] + status = prediction$truth[, 2L] + + # get predicted survival matrix + if (inherits(prediction$data$distr, "array")) { + surv = prediction$data$distr + if (length(dim(surv)) == 3L) { + # survival 3d array, extract median + surv = .ext_surv_mat(arr = surv, which.curve = 0.5) + } + } else { + stop("Distribution prediction does not have a survival matrix or array + in the $data$distr slot") + } + + # time point for calibration + time = self$param_set$values$time %??% median(times) + + # get cdf at the specified time point + extend_times_cdf = getFromNamespace("C_Vec_WeightedDiscreteCdf", ns = "distr6") + pred_times = as.numeric(colnames(surv)) + cdf = as.vector(extend_times_cdf(time, pred_times, cdf = t(1 - surv), TRUE, FALSE)) + #browser() + # cdf 1 => 0.9999 + + # get the cdf complement (survival) log-log transformed + llsurv = log(-log(1 - cdf)) + + hare_fit = polspline::hare(data = times, delta = status, cov = as.matrix(llsurv)) + cdf_hare = polspline::phare(q = time, cov = llsurv, fit = hare_fit) + + method = self$param_set$values$method + if (method == "ICI") { + # Mean difference (ICI) + result = mean(abs(cdf - cdf_hare)) + } else if (method == "E50") { + # Median (E50) + result = median(abs(cdf - cdf_hare)) + } else if (method == "E90") { + # 90th percentile (E90) + result = quantile(abs(cdf - cdf_hare), probs = 0.90) + } else if (method == "Emax") { + # Maximum absolute difference (Emax) + result = max(abs(cdf - cdf_hare)) + } + + return(result) + } + ) +) + +register_measure("surv.calib_index", MeasureSurvICI) diff --git a/man/mlr_measures_surv.calib_alpha.Rd b/man/mlr_measures_surv.calib_alpha.Rd index df8fdceb2..66301020b 100644 --- a/man/mlr_measures_surv.calib_alpha.Rd +++ b/man/mlr_measures_surv.calib_alpha.Rd @@ -85,6 +85,7 @@ Van Houwelingen, C. H (2000). \seealso{ Other survival measures: \code{\link{mlr_measures_surv.calib_beta}}, +\code{\link{mlr_measures_surv.calib_index}}, \code{\link{mlr_measures_surv.chambless_auc}}, \code{\link{mlr_measures_surv.cindex}}, \code{\link{mlr_measures_surv.dcalib}}, @@ -109,9 +110,11 @@ Other survival measures: Other calibration survival measures: \code{\link{mlr_measures_surv.calib_beta}}, +\code{\link{mlr_measures_surv.calib_index}}, \code{\link{mlr_measures_surv.dcalib}} Other distr survival measures: +\code{\link{mlr_measures_surv.calib_index}}, \code{\link{mlr_measures_surv.dcalib}}, \code{\link{mlr_measures_surv.graf}}, \code{\link{mlr_measures_surv.intlogloss}}, diff --git a/man/mlr_measures_surv.calib_beta.Rd b/man/mlr_measures_surv.calib_beta.Rd index b0aa2a6b5..25309c338 100644 --- a/man/mlr_measures_surv.calib_beta.Rd +++ b/man/mlr_measures_surv.calib_beta.Rd @@ -70,6 +70,7 @@ Van Houwelingen, C. H (2000). \seealso{ Other survival measures: \code{\link{mlr_measures_surv.calib_alpha}}, +\code{\link{mlr_measures_surv.calib_index}}, \code{\link{mlr_measures_surv.chambless_auc}}, \code{\link{mlr_measures_surv.cindex}}, \code{\link{mlr_measures_surv.dcalib}}, @@ -94,6 +95,7 @@ Other survival measures: Other calibration survival measures: \code{\link{mlr_measures_surv.calib_alpha}}, +\code{\link{mlr_measures_surv.calib_index}}, \code{\link{mlr_measures_surv.dcalib}} Other lp survival measures: diff --git a/man/mlr_measures_surv.calib_index.Rd b/man/mlr_measures_surv.calib_index.Rd new file mode 100644 index 000000000..d76a0d191 --- /dev/null +++ b/man/mlr_measures_surv.calib_index.Rd @@ -0,0 +1,149 @@ +% Generated by roxygen2: do not edit by hand +% Please edit documentation in R/MeasureSurvICI.R +\name{mlr_measures_surv.calib_index} +\alias{mlr_measures_surv.calib_index} +\alias{MeasureSurvICI} +\title{Integrated Calibration Index Survival Measure} +\description{ +Calculates the Integrated Calibration Index (ICI), which is the absolute +difference between predicted survival probabilities and smoothed survival +frequencies (calculated using hazard regression via the \CRANpkg{polspline}) +at a specific time point. +} +\section{Dictionary}{ + +This \link[mlr3:Measure]{Measure} can be instantiated via the \link[mlr3misc:Dictionary]{dictionary} +\link[mlr3:mlr_measures]{mlr_measures} or with the associated sugar function \link[mlr3:mlr_sugar]{msr()}: + +\if{html}{\out{
}}\preformatted{MeasureSurvICI$new() +mlr_measures$get("surv.calib_index") +msr("surv.calib_index") +}\if{html}{\out{
}} +} + +\section{Parameters}{ +\tabular{lllll}{ + Id \tab Type \tab Default \tab Levels \tab Range \cr + time \tab numeric \tab - \tab \tab \eqn{[0, \infty)}{[0, Inf)} \cr + method \tab character \tab ICI \tab ICI, E50, E90, Emax \tab - \cr +} +} + +\section{Meta Information}{ + +\itemize{ +\item Type: \code{"surv"} +\item Range: \eqn{[0, \infty)}{[0, Inf)} +\item Minimize: \code{TRUE} +\item Required prediction: \code{distr} +} +} + +\section{Parameter details}{ + +\itemize{ +\item \code{eps} (\code{numeric(1)})\cr +Very small number to substitute zero values in order to prevent errors +in e.g. log(0) and/or division-by-zero calculations. +Default value is 1e-04. +} +} + +\references{ +Austin, C. P, Harrell, E. F, van Klaveren, David (2020). +\dQuote{Graphical calibration curves and the integrated calibration index (ICI) for survival models.} +\emph{Statistics in Medicine}, \bold{39}(21), 2714. +ISSN 10970258, \doi{10.1002/SIM.8570}, \url{https://pmc.ncbi.nlm.nih.gov/articles/PMC7497089/}. +} +\seealso{ +Other survival measures: +\code{\link{mlr_measures_surv.calib_alpha}}, +\code{\link{mlr_measures_surv.calib_beta}}, +\code{\link{mlr_measures_surv.chambless_auc}}, +\code{\link{mlr_measures_surv.cindex}}, +\code{\link{mlr_measures_surv.dcalib}}, +\code{\link{mlr_measures_surv.graf}}, +\code{\link{mlr_measures_surv.hung_auc}}, +\code{\link{mlr_measures_surv.intlogloss}}, +\code{\link{mlr_measures_surv.logloss}}, +\code{\link{mlr_measures_surv.mae}}, +\code{\link{mlr_measures_surv.mse}}, +\code{\link{mlr_measures_surv.nagelk_r2}}, +\code{\link{mlr_measures_surv.oquigley_r2}}, +\code{\link{mlr_measures_surv.rcll}}, +\code{\link{mlr_measures_surv.rmse}}, +\code{\link{mlr_measures_surv.schmid}}, +\code{\link{mlr_measures_surv.song_auc}}, +\code{\link{mlr_measures_surv.song_tnr}}, +\code{\link{mlr_measures_surv.song_tpr}}, +\code{\link{mlr_measures_surv.uno_auc}}, +\code{\link{mlr_measures_surv.uno_tnr}}, +\code{\link{mlr_measures_surv.uno_tpr}}, +\code{\link{mlr_measures_surv.xu_r2}} + +Other calibration survival measures: +\code{\link{mlr_measures_surv.calib_alpha}}, +\code{\link{mlr_measures_surv.calib_beta}}, +\code{\link{mlr_measures_surv.dcalib}} + +Other distr survival measures: +\code{\link{mlr_measures_surv.calib_alpha}}, +\code{\link{mlr_measures_surv.dcalib}}, +\code{\link{mlr_measures_surv.graf}}, +\code{\link{mlr_measures_surv.intlogloss}}, +\code{\link{mlr_measures_surv.logloss}}, +\code{\link{mlr_measures_surv.rcll}}, +\code{\link{mlr_measures_surv.schmid}} +} +\concept{calibration survival measures} +\concept{distr survival measures} +\concept{survival measures} +\section{Super classes}{ +\code{\link[mlr3:Measure]{mlr3::Measure}} -> \code{\link[mlr3proba:MeasureSurv]{mlr3proba::MeasureSurv}} -> \code{MeasureSurvICI} +} +\section{Methods}{ +\subsection{Public methods}{ +\itemize{ +\item \href{#method-MeasureSurvICI-new}{\code{MeasureSurvICI$new()}} +\item \href{#method-MeasureSurvICI-clone}{\code{MeasureSurvICI$clone()}} +} +} +\if{html}{\out{ +
Inherited methods + +
+}} +\if{html}{\out{
}} +\if{html}{\out{}} +\if{latex}{\out{\hypertarget{method-MeasureSurvICI-new}{}}} +\subsection{Method \code{new()}}{ +Creates a new instance of this \link[R6:R6Class]{R6} class. +\subsection{Usage}{ +\if{html}{\out{
}}\preformatted{MeasureSurvICI$new()}\if{html}{\out{
}} +} + +} +\if{html}{\out{
}} +\if{html}{\out{}} +\if{latex}{\out{\hypertarget{method-MeasureSurvICI-clone}{}}} +\subsection{Method \code{clone()}}{ +The objects of this class are cloneable with this method. +\subsection{Usage}{ +\if{html}{\out{
}}\preformatted{MeasureSurvICI$clone(deep = FALSE)}\if{html}{\out{
}} +} + +\subsection{Arguments}{ +\if{html}{\out{
}} +\describe{ +\item{\code{deep}}{Whether to make a deep clone.} +} +\if{html}{\out{
}} +} +} +} diff --git a/man/mlr_measures_surv.chambless_auc.Rd b/man/mlr_measures_surv.chambless_auc.Rd index ff361021d..67a517499 100644 --- a/man/mlr_measures_surv.chambless_auc.Rd +++ b/man/mlr_measures_surv.chambless_auc.Rd @@ -100,6 +100,7 @@ Chambless LE, Diao G (2006). Other survival measures: \code{\link{mlr_measures_surv.calib_alpha}}, \code{\link{mlr_measures_surv.calib_beta}}, +\code{\link{mlr_measures_surv.calib_index}}, \code{\link{mlr_measures_surv.cindex}}, \code{\link{mlr_measures_surv.dcalib}}, \code{\link{mlr_measures_surv.graf}}, diff --git a/man/mlr_measures_surv.cindex.Rd b/man/mlr_measures_surv.cindex.Rd index 9fba1e787..6aea658b0 100644 --- a/man/mlr_measures_surv.cindex.Rd +++ b/man/mlr_measures_surv.cindex.Rd @@ -138,6 +138,7 @@ Uno H, Cai T, Pencina MJ, D'Agostino RB, Wei LJ (2011). Other survival measures: \code{\link{mlr_measures_surv.calib_alpha}}, \code{\link{mlr_measures_surv.calib_beta}}, +\code{\link{mlr_measures_surv.calib_index}}, \code{\link{mlr_measures_surv.chambless_auc}}, \code{\link{mlr_measures_surv.dcalib}}, \code{\link{mlr_measures_surv.graf}}, diff --git a/man/mlr_measures_surv.dcalib.Rd b/man/mlr_measures_surv.dcalib.Rd index d096a8c35..7595f9f5c 100644 --- a/man/mlr_measures_surv.dcalib.Rd +++ b/man/mlr_measures_surv.dcalib.Rd @@ -95,6 +95,7 @@ Haider, Humza, Hoehn, Bret, Davis, Sarah, Greiner, Russell (2020). Other survival measures: \code{\link{mlr_measures_surv.calib_alpha}}, \code{\link{mlr_measures_surv.calib_beta}}, +\code{\link{mlr_measures_surv.calib_index}}, \code{\link{mlr_measures_surv.chambless_auc}}, \code{\link{mlr_measures_surv.cindex}}, \code{\link{mlr_measures_surv.graf}}, @@ -118,10 +119,12 @@ Other survival measures: Other calibration survival measures: \code{\link{mlr_measures_surv.calib_alpha}}, -\code{\link{mlr_measures_surv.calib_beta}} +\code{\link{mlr_measures_surv.calib_beta}}, +\code{\link{mlr_measures_surv.calib_index}} Other distr survival measures: \code{\link{mlr_measures_surv.calib_alpha}}, +\code{\link{mlr_measures_surv.calib_index}}, \code{\link{mlr_measures_surv.graf}}, \code{\link{mlr_measures_surv.intlogloss}}, \code{\link{mlr_measures_surv.logloss}}, diff --git a/man/mlr_measures_surv.graf.Rd b/man/mlr_measures_surv.graf.Rd index f0e7f59bd..0131409c1 100644 --- a/man/mlr_measures_surv.graf.Rd +++ b/man/mlr_measures_surv.graf.Rd @@ -308,6 +308,7 @@ ISSN 1533-7928, \url{http://jmlr.org/papers/v24/19-1030.html}. Other survival measures: \code{\link{mlr_measures_surv.calib_alpha}}, \code{\link{mlr_measures_surv.calib_beta}}, +\code{\link{mlr_measures_surv.calib_index}}, \code{\link{mlr_measures_surv.chambless_auc}}, \code{\link{mlr_measures_surv.cindex}}, \code{\link{mlr_measures_surv.dcalib}}, @@ -337,6 +338,7 @@ Other Probabilistic survival measures: Other distr survival measures: \code{\link{mlr_measures_surv.calib_alpha}}, +\code{\link{mlr_measures_surv.calib_index}}, \code{\link{mlr_measures_surv.dcalib}}, \code{\link{mlr_measures_surv.intlogloss}}, \code{\link{mlr_measures_surv.logloss}}, diff --git a/man/mlr_measures_surv.hung_auc.Rd b/man/mlr_measures_surv.hung_auc.Rd index 8bba42e1d..9a813c172 100644 --- a/man/mlr_measures_surv.hung_auc.Rd +++ b/man/mlr_measures_surv.hung_auc.Rd @@ -100,6 +100,7 @@ Hung H, Chiang C (2010). Other survival measures: \code{\link{mlr_measures_surv.calib_alpha}}, \code{\link{mlr_measures_surv.calib_beta}}, +\code{\link{mlr_measures_surv.calib_index}}, \code{\link{mlr_measures_surv.chambless_auc}}, \code{\link{mlr_measures_surv.cindex}}, \code{\link{mlr_measures_surv.dcalib}}, diff --git a/man/mlr_measures_surv.intlogloss.Rd b/man/mlr_measures_surv.intlogloss.Rd index a2078837c..c76e2296f 100644 --- a/man/mlr_measures_surv.intlogloss.Rd +++ b/man/mlr_measures_surv.intlogloss.Rd @@ -306,6 +306,7 @@ ISSN 1533-7928, \url{http://jmlr.org/papers/v24/19-1030.html}. Other survival measures: \code{\link{mlr_measures_surv.calib_alpha}}, \code{\link{mlr_measures_surv.calib_beta}}, +\code{\link{mlr_measures_surv.calib_index}}, \code{\link{mlr_measures_surv.chambless_auc}}, \code{\link{mlr_measures_surv.cindex}}, \code{\link{mlr_measures_surv.dcalib}}, @@ -335,6 +336,7 @@ Other Probabilistic survival measures: Other distr survival measures: \code{\link{mlr_measures_surv.calib_alpha}}, +\code{\link{mlr_measures_surv.calib_index}}, \code{\link{mlr_measures_surv.dcalib}}, \code{\link{mlr_measures_surv.graf}}, \code{\link{mlr_measures_surv.logloss}}, diff --git a/man/mlr_measures_surv.logloss.Rd b/man/mlr_measures_surv.logloss.Rd index bc7241647..5f21eb745 100644 --- a/man/mlr_measures_surv.logloss.Rd +++ b/man/mlr_measures_surv.logloss.Rd @@ -111,6 +111,7 @@ Sonabend, Raphael, Zobolas, John, Kopper, Philipp, Burk, Lukas, Bender, Andreas Other survival measures: \code{\link{mlr_measures_surv.calib_alpha}}, \code{\link{mlr_measures_surv.calib_beta}}, +\code{\link{mlr_measures_surv.calib_index}}, \code{\link{mlr_measures_surv.chambless_auc}}, \code{\link{mlr_measures_surv.cindex}}, \code{\link{mlr_measures_surv.dcalib}}, @@ -140,6 +141,7 @@ Other Probabilistic survival measures: Other distr survival measures: \code{\link{mlr_measures_surv.calib_alpha}}, +\code{\link{mlr_measures_surv.calib_index}}, \code{\link{mlr_measures_surv.dcalib}}, \code{\link{mlr_measures_surv.graf}}, \code{\link{mlr_measures_surv.intlogloss}}, diff --git a/man/mlr_measures_surv.mae.Rd b/man/mlr_measures_surv.mae.Rd index 82ab5ade0..d68832849 100644 --- a/man/mlr_measures_surv.mae.Rd +++ b/man/mlr_measures_surv.mae.Rd @@ -56,6 +56,7 @@ Default is \code{FALSE} (returns the mean). Other survival measures: \code{\link{mlr_measures_surv.calib_alpha}}, \code{\link{mlr_measures_surv.calib_beta}}, +\code{\link{mlr_measures_surv.calib_index}}, \code{\link{mlr_measures_surv.chambless_auc}}, \code{\link{mlr_measures_surv.cindex}}, \code{\link{mlr_measures_surv.dcalib}}, diff --git a/man/mlr_measures_surv.mse.Rd b/man/mlr_measures_surv.mse.Rd index 68821e010..a641b82b5 100644 --- a/man/mlr_measures_surv.mse.Rd +++ b/man/mlr_measures_surv.mse.Rd @@ -56,6 +56,7 @@ Default is \code{FALSE} (returns the mean). Other survival measures: \code{\link{mlr_measures_surv.calib_alpha}}, \code{\link{mlr_measures_surv.calib_beta}}, +\code{\link{mlr_measures_surv.calib_index}}, \code{\link{mlr_measures_surv.chambless_auc}}, \code{\link{mlr_measures_surv.cindex}}, \code{\link{mlr_measures_surv.dcalib}}, diff --git a/man/mlr_measures_surv.nagelk_r2.Rd b/man/mlr_measures_surv.nagelk_r2.Rd index b2cc1dd2f..fa9fa21de 100644 --- a/man/mlr_measures_surv.nagelk_r2.Rd +++ b/man/mlr_measures_surv.nagelk_r2.Rd @@ -51,6 +51,7 @@ Nagelkerke, JD N, others (1991). Other survival measures: \code{\link{mlr_measures_surv.calib_alpha}}, \code{\link{mlr_measures_surv.calib_beta}}, +\code{\link{mlr_measures_surv.calib_index}}, \code{\link{mlr_measures_surv.chambless_auc}}, \code{\link{mlr_measures_surv.cindex}}, \code{\link{mlr_measures_surv.dcalib}}, diff --git a/man/mlr_measures_surv.oquigley_r2.Rd b/man/mlr_measures_surv.oquigley_r2.Rd index 7a614a9a4..0e959c50e 100644 --- a/man/mlr_measures_surv.oquigley_r2.Rd +++ b/man/mlr_measures_surv.oquigley_r2.Rd @@ -52,6 +52,7 @@ O'Quigley J, Xu R, Stare J (2005). Other survival measures: \code{\link{mlr_measures_surv.calib_alpha}}, \code{\link{mlr_measures_surv.calib_beta}}, +\code{\link{mlr_measures_surv.calib_index}}, \code{\link{mlr_measures_surv.chambless_auc}}, \code{\link{mlr_measures_surv.cindex}}, \code{\link{mlr_measures_surv.dcalib}}, diff --git a/man/mlr_measures_surv.rcll.Rd b/man/mlr_measures_surv.rcll.Rd index 1a1a4b47a..8b0ba75db 100644 --- a/man/mlr_measures_surv.rcll.Rd +++ b/man/mlr_measures_surv.rcll.Rd @@ -97,6 +97,7 @@ Rindt, David, Hu, Robert, Steinsaltz, David, Sejdinovic, Dino (2022). Other survival measures: \code{\link{mlr_measures_surv.calib_alpha}}, \code{\link{mlr_measures_surv.calib_beta}}, +\code{\link{mlr_measures_surv.calib_index}}, \code{\link{mlr_measures_surv.chambless_auc}}, \code{\link{mlr_measures_surv.cindex}}, \code{\link{mlr_measures_surv.dcalib}}, @@ -126,6 +127,7 @@ Other Probabilistic survival measures: Other distr survival measures: \code{\link{mlr_measures_surv.calib_alpha}}, +\code{\link{mlr_measures_surv.calib_index}}, \code{\link{mlr_measures_surv.dcalib}}, \code{\link{mlr_measures_surv.graf}}, \code{\link{mlr_measures_surv.intlogloss}}, diff --git a/man/mlr_measures_surv.rmse.Rd b/man/mlr_measures_surv.rmse.Rd index 89dc8b8cc..2b629020a 100644 --- a/man/mlr_measures_surv.rmse.Rd +++ b/man/mlr_measures_surv.rmse.Rd @@ -56,6 +56,7 @@ Default is \code{FALSE} (returns the mean). Other survival measures: \code{\link{mlr_measures_surv.calib_alpha}}, \code{\link{mlr_measures_surv.calib_beta}}, +\code{\link{mlr_measures_surv.calib_index}}, \code{\link{mlr_measures_surv.chambless_auc}}, \code{\link{mlr_measures_surv.cindex}}, \code{\link{mlr_measures_surv.dcalib}}, diff --git a/man/mlr_measures_surv.schmid.Rd b/man/mlr_measures_surv.schmid.Rd index 7dade0c4a..bebe5201e 100644 --- a/man/mlr_measures_surv.schmid.Rd +++ b/man/mlr_measures_surv.schmid.Rd @@ -310,6 +310,7 @@ ISSN 1533-7928, \url{http://jmlr.org/papers/v24/19-1030.html}. Other survival measures: \code{\link{mlr_measures_surv.calib_alpha}}, \code{\link{mlr_measures_surv.calib_beta}}, +\code{\link{mlr_measures_surv.calib_index}}, \code{\link{mlr_measures_surv.chambless_auc}}, \code{\link{mlr_measures_surv.cindex}}, \code{\link{mlr_measures_surv.dcalib}}, @@ -339,6 +340,7 @@ Other Probabilistic survival measures: Other distr survival measures: \code{\link{mlr_measures_surv.calib_alpha}}, +\code{\link{mlr_measures_surv.calib_index}}, \code{\link{mlr_measures_surv.dcalib}}, \code{\link{mlr_measures_surv.graf}}, \code{\link{mlr_measures_surv.intlogloss}}, diff --git a/man/mlr_measures_surv.song_auc.Rd b/man/mlr_measures_surv.song_auc.Rd index e2f2bb0aa..65e92d9fc 100644 --- a/man/mlr_measures_surv.song_auc.Rd +++ b/man/mlr_measures_surv.song_auc.Rd @@ -108,6 +108,7 @@ Song, Xiao, Zhou, Xiao-Hua (2008). Other survival measures: \code{\link{mlr_measures_surv.calib_alpha}}, \code{\link{mlr_measures_surv.calib_beta}}, +\code{\link{mlr_measures_surv.calib_index}}, \code{\link{mlr_measures_surv.chambless_auc}}, \code{\link{mlr_measures_surv.cindex}}, \code{\link{mlr_measures_surv.dcalib}}, diff --git a/man/mlr_measures_surv.song_tnr.Rd b/man/mlr_measures_surv.song_tnr.Rd index aa2717532..4ce33b2ec 100644 --- a/man/mlr_measures_surv.song_tnr.Rd +++ b/man/mlr_measures_surv.song_tnr.Rd @@ -74,6 +74,7 @@ Song, Xiao, Zhou, Xiao-Hua (2008). Other survival measures: \code{\link{mlr_measures_surv.calib_alpha}}, \code{\link{mlr_measures_surv.calib_beta}}, +\code{\link{mlr_measures_surv.calib_index}}, \code{\link{mlr_measures_surv.chambless_auc}}, \code{\link{mlr_measures_surv.cindex}}, \code{\link{mlr_measures_surv.dcalib}}, diff --git a/man/mlr_measures_surv.song_tpr.Rd b/man/mlr_measures_surv.song_tpr.Rd index 62549705a..edd507e09 100644 --- a/man/mlr_measures_surv.song_tpr.Rd +++ b/man/mlr_measures_surv.song_tpr.Rd @@ -75,6 +75,7 @@ Song, Xiao, Zhou, Xiao-Hua (2008). Other survival measures: \code{\link{mlr_measures_surv.calib_alpha}}, \code{\link{mlr_measures_surv.calib_beta}}, +\code{\link{mlr_measures_surv.calib_index}}, \code{\link{mlr_measures_surv.chambless_auc}}, \code{\link{mlr_measures_surv.cindex}}, \code{\link{mlr_measures_surv.dcalib}}, diff --git a/man/mlr_measures_surv.uno_auc.Rd b/man/mlr_measures_surv.uno_auc.Rd index 332082fbb..46818c0b5 100644 --- a/man/mlr_measures_surv.uno_auc.Rd +++ b/man/mlr_measures_surv.uno_auc.Rd @@ -100,6 +100,7 @@ Uno H, Cai T, Tian L, Wei LJ (2007). Other survival measures: \code{\link{mlr_measures_surv.calib_alpha}}, \code{\link{mlr_measures_surv.calib_beta}}, +\code{\link{mlr_measures_surv.calib_index}}, \code{\link{mlr_measures_surv.chambless_auc}}, \code{\link{mlr_measures_surv.cindex}}, \code{\link{mlr_measures_surv.dcalib}}, diff --git a/man/mlr_measures_surv.uno_tnr.Rd b/man/mlr_measures_surv.uno_tnr.Rd index e780be55c..a81379344 100644 --- a/man/mlr_measures_surv.uno_tnr.Rd +++ b/man/mlr_measures_surv.uno_tnr.Rd @@ -73,6 +73,7 @@ Uno H, Cai T, Tian L, Wei LJ (2007). Other survival measures: \code{\link{mlr_measures_surv.calib_alpha}}, \code{\link{mlr_measures_surv.calib_beta}}, +\code{\link{mlr_measures_surv.calib_index}}, \code{\link{mlr_measures_surv.chambless_auc}}, \code{\link{mlr_measures_surv.cindex}}, \code{\link{mlr_measures_surv.dcalib}}, diff --git a/man/mlr_measures_surv.uno_tpr.Rd b/man/mlr_measures_surv.uno_tpr.Rd index c881876f3..abde862a9 100644 --- a/man/mlr_measures_surv.uno_tpr.Rd +++ b/man/mlr_measures_surv.uno_tpr.Rd @@ -73,6 +73,7 @@ Uno H, Cai T, Tian L, Wei LJ (2007). Other survival measures: \code{\link{mlr_measures_surv.calib_alpha}}, \code{\link{mlr_measures_surv.calib_beta}}, +\code{\link{mlr_measures_surv.calib_index}}, \code{\link{mlr_measures_surv.chambless_auc}}, \code{\link{mlr_measures_surv.cindex}}, \code{\link{mlr_measures_surv.dcalib}}, diff --git a/man/mlr_measures_surv.xu_r2.Rd b/man/mlr_measures_surv.xu_r2.Rd index 303c9c44d..23d0c2d1f 100644 --- a/man/mlr_measures_surv.xu_r2.Rd +++ b/man/mlr_measures_surv.xu_r2.Rd @@ -52,6 +52,7 @@ Xu R, O'Quigley J (1999). Other survival measures: \code{\link{mlr_measures_surv.calib_alpha}}, \code{\link{mlr_measures_surv.calib_beta}}, +\code{\link{mlr_measures_surv.calib_index}}, \code{\link{mlr_measures_surv.chambless_auc}}, \code{\link{mlr_measures_surv.cindex}}, \code{\link{mlr_measures_surv.dcalib}}, From 108debc65929d0deb487cf56ca44ce63b34cc144 Mon Sep 17 00:00:00 2001 From: john Date: Sat, 4 Jan 2025 10:40:31 +0200 Subject: [PATCH 11/29] avoid use of fortify() by giving directly the data.table object --- R/autoplot.R | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/R/autoplot.R b/R/autoplot.R index 80379125a..e34854cb6 100644 --- a/R/autoplot.R +++ b/R/autoplot.R @@ -59,14 +59,14 @@ autoplot.TaskSurv = function(object, type = "target", theme = theme_minimal(), r }, "duo" = { - GGally::ggduo(object, + GGally::ggduo(object$data(), columnsX = object$target_names, columnsY = object$feature_names, ...) + theme }, "pairs" = { - GGally::ggpairs(object, ...) + + GGally::ggpairs(object$data(), ...) + theme }, @@ -115,7 +115,7 @@ plot.TaskSurv = function(x, ...) { autoplot.TaskDens = function(object, type = "dens", theme = theme_minimal(), ...) { # nolint assert_choice(type, c("dens", "freq", "overlay", "freqpoly"), null.ok = FALSE) - p = ggplot(data = object, aes(x = .data[[object$feature_names]]), ...) + p = ggplot(data = object$data(), aes(x = .data[[object$feature_names]]), ...) switch(type, "dens" = { From c205649886c19527c9ccbfd1fb13f6cd5b47d7d8 Mon Sep 17 00:00:00 2001 From: john Date: Sat, 4 Jan 2025 10:41:13 +0200 Subject: [PATCH 12/29] refine autoplot tests --- tests/testthat/test_autoplot.R | 38 +++++++++++++++------------------- 1 file changed, 17 insertions(+), 21 deletions(-) diff --git a/tests/testthat/test_autoplot.R b/tests/testthat/test_autoplot.R index 8b4df0134..bf0f86fae 100644 --- a/tests/testthat/test_autoplot.R +++ b/tests/testthat/test_autoplot.R @@ -1,8 +1,19 @@ -test_that("autoplot.PredictionSurv", { - skip_if_not_installed("mlr3proba") - require_namespaces("mlr3proba") +skip_if_not_installed("mlr3proba") +require_namespaces("mlr3proba") +task = tsk("rats") + +test_that("autoplot.TaskSurv", { + p = autoplot(task, type = "target") + expect_true(is.ggplot(p)) + + p = autoplot(task, type = "pairs") + expect_s3_class(p, "ggmatrix") + + p = autoplot(task, type = "duo") + expect_s3_class(p, "ggmatrix") +}) - task = mlr3::tsk("rats")$filter(1:100) +test_that("autoplot.PredictionSurv", { learner = suppressWarnings(mlr3::lrn("surv.coxph")$train(task)) prediction = learner$predict(task) @@ -19,7 +30,8 @@ test_that("autoplot.PredictionSurv", { test_that("autoplot.TaskDens", { skip_if_not_installed("mlr3proba") require_namespaces("mlr3proba") - task = mlr3::tsk("precip") + + task = tsk("precip") p = autoplot(task, type = "dens") expect_true(is.ggplot(p)) @@ -33,19 +45,3 @@ test_that("autoplot.TaskDens", { p = autoplot(task, type = "freqpoly") expect_true(is.ggplot(p)) }) - -test_that("autoplot.TaskSurv", { - skip_if_not_installed("mlr3proba") - - require_namespaces("mlr3proba") - task = mlr3::tsk("rats") - - p = autoplot(task, type = "target") - expect_true(is.ggplot(p)) - - p = autoplot(task, type = "pairs") - expect_s3_class(p, "ggmatrix") - - p = autoplot(task, type = "duo") - expect_s3_class(p, "ggmatrix") -}) From efdcc2612e202c8d3a995389d3151837fe5bc104 Mon Sep 17 00:00:00 2001 From: john Date: Sat, 4 Jan 2025 10:51:04 +0200 Subject: [PATCH 13/29] fix NOTEs --- R/MeasureSurvICI.R | 4 ++-- R/autoplot.R | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/R/MeasureSurvICI.R b/R/MeasureSurvICI.R index 3435e0b4a..08fda8a75 100644 --- a/R/MeasureSurvICI.R +++ b/R/MeasureSurvICI.R @@ -80,10 +80,10 @@ MeasureSurvICI = R6Class("MeasureSurvICI", result = mean(abs(cdf - cdf_hare)) } else if (method == "E50") { # Median (E50) - result = median(abs(cdf - cdf_hare)) + result = stats::median(abs(cdf - cdf_hare)) } else if (method == "E90") { # 90th percentile (E90) - result = quantile(abs(cdf - cdf_hare), probs = 0.90) + result = stats::quantile(abs(cdf - cdf_hare), probs = 0.90) } else if (method == "Emax") { # Maximum absolute difference (Emax) result = max(abs(cdf - cdf_hare)) diff --git a/R/autoplot.R b/R/autoplot.R index e34854cb6..921f06144 100644 --- a/R/autoplot.R +++ b/R/autoplot.R @@ -312,7 +312,7 @@ autoplot.PredictionSurv = function(object, type = "calib", # filter data to specific ids if (!is.null(row_ids)) { - data = data[row_id %in% row_ids] + data = data[get("row_id") %in% row_ids] } p = From 3bd1c3d7f60db60f11f93b08e3b084ae948d1726 Mon Sep 17 00:00:00 2001 From: john Date: Sat, 4 Jan 2025 15:38:59 +0200 Subject: [PATCH 14/29] better doc, add example, refactor --- R/MeasureSurvICI.R | 84 +++++++++++++++++++++++++++++++++++++--------- 1 file changed, 68 insertions(+), 16 deletions(-) diff --git a/R/MeasureSurvICI.R b/R/MeasureSurvICI.R index 08fda8a75..de2b35af0 100644 --- a/R/MeasureSurvICI.R +++ b/R/MeasureSurvICI.R @@ -1,20 +1,72 @@ #' @template surv_measure #' @templateVar title Integrated Calibration Index #' @templateVar fullname MeasureSurvICI -#' @templateVar eps 1e-4 -#' @template param_eps #' #' @description -#' Calculates the Integrated Calibration Index (ICI), which is the absolute -#' difference between predicted survival probabilities and smoothed survival -#' frequencies (calculated using hazard regression via the \CRANpkg{polspline}) -#' at a specific time point. +#' Calculates the Integrated Calibration Index (ICI), see Austin et al. (2020). +#' +#' @details +#' Each individual \eqn{i} from the test set, has an observed survival outcome +#' \eqn{(t_i, \delta_i)} (time and censoring indicator) and predicted survival +#' function \eqn{S_i(t)}. +#' The predicted probability of an event occurring before a specific time point +#' \eqn{t_0}, is defined as \eqn{\hat{P_i}(t_0) = F_i(t_0) = 1 - S_i(t_0)}. +#' +#' Using hazard regression (via the \CRANpkg{polspline} R package), a *smoothed* +#' calibration curve is estimated by fitting the following model: +#' \deqn{log(h(t)) = g(log(− log(1 − \hat{P}_{t_0})), t)} +#' +#' Any \eqn{\hat{P}_{t_0} = 1} is set to \eqn{0.9999} to avoid calculating \eqn{log(0)}. +#' From this model, the *smoothed* probability of occurrence at \eqn{t_0} for +#' observation \eqn{i} is obtained as \eqn{\hat{P}_i^c(t_0)}. +#' +#' The **Integrated Calibration Index** is then computed as: +#' \deqn{ICI = \frac{1}{N} \sum_{i=1}^N | \hat{P}_i^c(t_0) - \hat{P}_i(t_0) |} +#' +#' This measure evaluates **point-calibration** at a specific time point, which +#' must be specified by the user. +#' +#' @section Parameter details: +#' - `time` (`numeric(1)`)\cr +#' The specific time point \eqn{t_0} at which calibration is evaluated. +#' If `NULL`, the median observed time from the test set is used. +#' - `method` (`character(1)`)\cr +#' Specifies the summary statistic used to calculate the final calibration score. +#' - `"ICI"` (default): Uses the mean of absolute differences \eqn{| \hat{P}_i^c(t_0) - \hat{P}_i(t_0) |} across all observations. +#' - `"E50"`: Uses the median of absolute differences instead of the mean. +#' - `"E90"`: Uses the 90th percentile of absolute differences, emphasizing higher deviations. +#' - `"Emax"`: Uses the maximum absolute difference, capturing the largest discrepancy between predicted and smoothed probabilities. #' #' @references #' `r format_bib("austin2020")` #' #' @family calibration survival measures #' @family distr survival measures +#' @examples +#' library(mlr3) +#' +#' # Define a survival Task +#' task = tsk("lung") +#' +#' # Create train and test set +#' part = partition(task) +#' +#' # Train Cox learner on the train set +#' cox = lrn("surv.coxph") +#' cox$train(task, row_ids = part$train) +#' +#' # Make predictions for the test set +#' p = cox$predict(task, row_ids = part$test) +#' +#' # ICI at median test time +#' p$score(msr("surv.calib_index")) +#' +#' # ICI at specific time point +#' p$score(msr("surv.calib_index", time = 365)) +#' +#' # E50 at specific time point +#' p$score(msr("surv.calib_index", method = "E50", time = 365)) +#' #' @export MeasureSurvICI = R6Class("MeasureSurvICI", inherit = MeasureSurv, @@ -43,7 +95,7 @@ MeasureSurvICI = R6Class("MeasureSurvICI", private = list( .score = function(prediction, ...) { # test set survival outcome - times = prediction$truth[, 1L] + times = prediction$truth[, 1L] status = prediction$truth[, 2L] # get predicted survival matrix @@ -65,28 +117,28 @@ MeasureSurvICI = R6Class("MeasureSurvICI", extend_times_cdf = getFromNamespace("C_Vec_WeightedDiscreteCdf", ns = "distr6") pred_times = as.numeric(colnames(surv)) cdf = as.vector(extend_times_cdf(time, pred_times, cdf = t(1 - surv), TRUE, FALSE)) - #browser() - # cdf 1 => 0.9999 + # to avoid log(0) later, same as in paper's Appendix + cdf[cdf == 1] = 0.9999 # get the cdf complement (survival) log-log transformed - llsurv = log(-log(1 - cdf)) + cll = log(-log(1 - cdf)) - hare_fit = polspline::hare(data = times, delta = status, cov = as.matrix(llsurv)) - cdf_hare = polspline::phare(q = time, cov = llsurv, fit = hare_fit) + hare_fit = polspline::hare(data = times, delta = status, cov = as.matrix(cll)) + smoothed_cdf = polspline::phare(q = time, cov = cll, fit = hare_fit) method = self$param_set$values$method if (method == "ICI") { # Mean difference (ICI) - result = mean(abs(cdf - cdf_hare)) + result = mean(abs(cdf - smoothed_cdf)) } else if (method == "E50") { # Median (E50) - result = stats::median(abs(cdf - cdf_hare)) + result = stats::median(abs(cdf - smoothed_cdf)) } else if (method == "E90") { # 90th percentile (E90) - result = stats::quantile(abs(cdf - cdf_hare), probs = 0.90) + result = stats::quantile(abs(cdf - smoothed_cdf), probs = 0.90) } else if (method == "Emax") { # Maximum absolute difference (Emax) - result = max(abs(cdf - cdf_hare)) + result = max(abs(cdf - smoothed_cdf)) } return(result) From 5458de9f47b4e95ad5178982b876ab8a535284f7 Mon Sep 17 00:00:00 2001 From: john Date: Sat, 4 Jan 2025 15:39:12 +0200 Subject: [PATCH 15/29] add test for ICI --- tests/testthat/test_mlr_measures.R | 22 ++++++++++++++++++++++ 1 file changed, 22 insertions(+) diff --git a/tests/testthat/test_mlr_measures.R b/tests/testthat/test_mlr_measures.R index ea8b67403..4e103c971 100644 --- a/tests/testthat/test_mlr_measures.R +++ b/tests/testthat/test_mlr_measures.R @@ -119,6 +119,28 @@ test_that("calib_alpha works", { expect_equal(unname(pred$score(m3)), -1) }) +test_that("calib_index works", { + m = msr("surv.calib_index") + expect_equal(m$range, c(0, Inf)) + expect_true(m$minimize) + expect_equal(m$param_set$values$method, "ICI") + res = pred$score(m) + expect_gt(res, 0) + + m2 = msr("surv.calib_index", method = "E90") + expect_equal(m2$param_set$values$method, "E90") + res2 = pred$score(m2) + expect_gt(res2, res) + + m3 = msr("surv.calib_index", method = "Emax") + expect_equal(m3$param_set$values$method, "Emax") + expect_gt(pred$score(m3), res2) + + m4 = msr("surv.calib_index", time = 100) + expect_equal(m4$param_set$values$time, 100) + expect_false(pred$score(m4) == res) +}) + test_that("graf training data for weights", { m = msr("surv.graf", proper = TRUE) t = tsk("rats") From d48c44c362d324e969a98d42f92a3e2f6e34a513 Mon Sep 17 00:00:00 2001 From: john Date: Sat, 4 Jan 2025 15:40:23 +0200 Subject: [PATCH 16/29] updocs --- man/mlr_measures_surv.calib_index.Rd | 67 ++++++++++++++++++++++++---- 1 file changed, 59 insertions(+), 8 deletions(-) diff --git a/man/mlr_measures_surv.calib_index.Rd b/man/mlr_measures_surv.calib_index.Rd index d76a0d191..6e24a9bf8 100644 --- a/man/mlr_measures_surv.calib_index.Rd +++ b/man/mlr_measures_surv.calib_index.Rd @@ -5,10 +5,28 @@ \alias{MeasureSurvICI} \title{Integrated Calibration Index Survival Measure} \description{ -Calculates the Integrated Calibration Index (ICI), which is the absolute -difference between predicted survival probabilities and smoothed survival -frequencies (calculated using hazard regression via the \CRANpkg{polspline}) -at a specific time point. +Calculates the Integrated Calibration Index (ICI), see Austin et al. (2020). +} +\details{ +Each individual \eqn{i} from the test set, has an observed survival outcome +\eqn{(t_i, \delta_i)} (time and censoring indicator) and predicted survival +function \eqn{S_i(t)}. +The predicted probability of an event occurring before a specific time point +\eqn{t_0}, is defined as \eqn{\hat{P_i}(t_0) = F_i(t_0) = 1 - S_i(t_0)}. + +Using hazard regression (via the \CRANpkg{polspline} R package), a \emph{smoothed} +calibration curve is estimated by fitting the following model: +\deqn{log(h(t)) = g(log(− log(1 − \hat{P}_{t_0})), t)} + +Any \eqn{\hat{P}_{t_0} = 1} is set to \eqn{0.9999} to avoid calculating \eqn{log(0)}. +From this model, the \emph{smoothed} probability of occurrence at \eqn{t_0} for +observation \eqn{i} is obtained as \eqn{\hat{P}_i^c(t_0)}. + +The \strong{Integrated Calibration Index} is then computed as: +\deqn{ICI = \frac{1}{N} \sum_{i=1}^N | \hat{P}_i^c(t_0) - \hat{P}_i(t_0) |} + +This measure evaluates \strong{point-calibration} at a specific time point, which +must be specified by the user. } \section{Dictionary}{ @@ -42,13 +60,46 @@ msr("surv.calib_index") \section{Parameter details}{ \itemize{ -\item \code{eps} (\code{numeric(1)})\cr -Very small number to substitute zero values in order to prevent errors -in e.g. log(0) and/or division-by-zero calculations. -Default value is 1e-04. +\item \code{time} (\code{numeric(1)})\cr +The specific time point \eqn{t_0} at which calibration is evaluated. +If \code{NULL}, the median observed time from the test set is used. +\item \code{method} (\code{character(1)})\cr +Specifies the summary statistic used to calculate the final calibration score. +\itemize{ +\item \code{"ICI"} (default): Uses the mean of absolute differences \eqn{| \hat{P}_i^c(t_0) - \hat{P}_i(t_0) |} across all observations. +\item \code{"E50"}: Uses the median of absolute differences instead of the mean. +\item \code{"E90"}: Uses the 90th percentile of absolute differences, emphasizing higher deviations. +\item \code{"Emax"}: Uses the maximum absolute difference, capturing the largest discrepancy between predicted and smoothed probabilities. +} } } +\examples{ +library(mlr3) + +# Define a survival Task +task = tsk("lung") + +# Create train and test set +part = partition(task) + +# Train Cox learner on the train set +cox = lrn("surv.coxph") +cox$train(task, row_ids = part$train) + +# Make predictions for the test set +p = cox$predict(task, row_ids = part$test) + +# ICI at median test time +p$score(msr("surv.calib_index")) + +# ICI at specific time point +p$score(msr("surv.calib_index", time = 365)) + +# E50 at specific time point +p$score(msr("surv.calib_index", method = "E50", time = 365)) + +} \references{ Austin, C. P, Harrell, E. F, van Klaveren, David (2020). \dQuote{Graphical calibration curves and the integrated calibration index (ICI) for survival models.} From 7d89a3094701f36c2a94d9196dbf14fb9a59c7b8 Mon Sep 17 00:00:00 2001 From: john Date: Sat, 4 Jan 2025 15:43:04 +0200 Subject: [PATCH 17/29] update news --- NEWS.md | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/NEWS.md b/NEWS.md index df9ef4396..60a45e0a9 100644 --- a/NEWS.md +++ b/NEWS.md @@ -1,3 +1,11 @@ +# mlr3proba 0.7.3 + +* feat: added new calibration measure => `msr("surv.calib_index")` +* refactor: `autoplot.PredictionSurv` + * The default `"calib"` plot uses the survival matrix directly now (so faster) + * `"dcalib"` has extra barplot + better documentation + * **BREAKING CHANGE**: `"preds"` is now called `"isd"` (individual survival distribution) and `row_ids` can be used to filter the observations for which you draw the survival curves. + # mlr3proba 0.7.2 * fix: `lrn("surv.coxph")` is now trained with `model=TRUE` which fixes an issue with using observation weights [stackoverflow link](https://stackoverflow.com/questions/79297386/mlr3-predicted-values-for-surv-coxph-learner-with-case-weights). From 87ff0f9a0a354d24604eeeda2a64532f3cca8926 Mon Sep 17 00:00:00 2001 From: john Date: Sat, 4 Jan 2025 17:06:52 +0200 Subject: [PATCH 18/29] avoid Infs when log(cdf == 0) --- R/MeasureSurvICI.R | 5 ++++- man/mlr_measures_surv.calib_index.Rd | 4 +++- 2 files changed, 7 insertions(+), 2 deletions(-) diff --git a/R/MeasureSurvICI.R b/R/MeasureSurvICI.R index de2b35af0..3ac10fd7e 100644 --- a/R/MeasureSurvICI.R +++ b/R/MeasureSurvICI.R @@ -16,7 +16,9 @@ #' calibration curve is estimated by fitting the following model: #' \deqn{log(h(t)) = g(log(− log(1 − \hat{P}_{t_0})), t)} #' -#' Any \eqn{\hat{P}_{t_0} = 1} is set to \eqn{0.9999} to avoid calculating \eqn{log(0)}. +#' Note that we substitute any \eqn{\hat{P}_{t_0} = 1} with \eqn{0.9999} and any +#' \eqn{\hat{P}_{t_0} = 0} with \eqn{0.0001} to avoid arithmetic issues arising +#' from calculating \eqn{log(0)}. #' From this model, the *smoothed* probability of occurrence at \eqn{t_0} for #' observation \eqn{i} is obtained as \eqn{\hat{P}_i^c(t_0)}. #' @@ -119,6 +121,7 @@ MeasureSurvICI = R6Class("MeasureSurvICI", cdf = as.vector(extend_times_cdf(time, pred_times, cdf = t(1 - surv), TRUE, FALSE)) # to avoid log(0) later, same as in paper's Appendix cdf[cdf == 1] = 0.9999 + cdf[cdf == 0] = 0.0001 # get the cdf complement (survival) log-log transformed cll = log(-log(1 - cdf)) diff --git a/man/mlr_measures_surv.calib_index.Rd b/man/mlr_measures_surv.calib_index.Rd index 6e24a9bf8..fd9ebb904 100644 --- a/man/mlr_measures_surv.calib_index.Rd +++ b/man/mlr_measures_surv.calib_index.Rd @@ -18,7 +18,9 @@ Using hazard regression (via the \CRANpkg{polspline} R package), a \emph{smoothe calibration curve is estimated by fitting the following model: \deqn{log(h(t)) = g(log(− log(1 − \hat{P}_{t_0})), t)} -Any \eqn{\hat{P}_{t_0} = 1} is set to \eqn{0.9999} to avoid calculating \eqn{log(0)}. +Note that we substitute any \eqn{\hat{P}_{t_0} = 1} with \eqn{0.9999} and any +\eqn{\hat{P}_{t_0} = 0} with \eqn{0.0001} to avoid arithmetic issues arising +from calculating \eqn{log(0)}. From this model, the \emph{smoothed} probability of occurrence at \eqn{t_0} for observation \eqn{i} is obtained as \eqn{\hat{P}_i^c(t_0)}. From 0c076b35361a65c9fc81102d726986646542790d Mon Sep 17 00:00:00 2001 From: john Date: Sat, 4 Jan 2025 17:18:11 +0200 Subject: [PATCH 19/29] add "scalib" option for PredictionSurv objects --- R/autoplot.R | 83 +++++++++++++++++++++++++++++----- man/autoplot.PredictionSurv.Rd | 24 +++++++++- tests/testthat/test_autoplot.R | 3 ++ 3 files changed, 97 insertions(+), 13 deletions(-) diff --git a/R/autoplot.R b/R/autoplot.R index 921f06144..56aa71193 100644 --- a/R/autoplot.R +++ b/R/autoplot.R @@ -168,10 +168,17 @@ plot.TaskDens = function(x, ...) { #' matches `p`. For example, 50% of events should occur before the predicted #' median survival time (i.e. the time corresponding to a predicted survival #' probability of 0.5). -#' This means that the resulting line plot will lie close to the straight line -#' y = x. +#' Good calibration means that the resulting line plot will lie close to the +#' straight line \eqn{y = x}. #' Note that we impute `NA`s from the predicted quantile function with the #' maximum observed outcome time. +#' - `"scalib"`: **Smoothed calibration plot** at a specific time point. +#' For a range of predicted probabilities of event occurrence in \eqn{[0,1]} (x-axis), +#' the y-axis has the smoothed observed proportions calculated using hazard +#' regression. +#' See Austin et al. (2020) and [MeasureSurvICI] for more details. +#' Good calibration means that the resulting line plot will lie close to the +#' straight line \eqn{y = x}. #' - `"isd"`: Plot the predicted **i**ndividual **s**urvival **d**istributions #' (survival curves) for observations from the test set. #' @@ -193,6 +200,9 @@ plot.TaskDens = function(x, ...) { #' if `NULL` uses all time points from the predicted survival matrix (`object$data$distr`). #' @param cuts (`integer(1)`) \cr #' Number of cuts in \eqn{(0,1)} to plot `dcalib` over, default is `11`. +#' @param time (`numeric(1)`) \cr +#' The specific time point at which the smoothed calibration plot is created. +#' Must be always provided if `type = "scalib"`. #' @template param_theme #' @param ... (`any`): #' Additional arguments, currently unused. @@ -200,7 +210,7 @@ plot.TaskDens = function(x, ...) { #' @template section_theme #' #' @references -#' `r format_bib("haider_2020")` +#' `r format_bib("haider_2020", "austin2020")` #' #' @examplesIf mlr3misc::require_namespaces(c("mlr3viz", "ggplot2"), quietly = TRUE) #' library(mlr3) @@ -220,6 +230,9 @@ plot.TaskDens = function(x, ...) { #' # Distribution-calibration (D-Calibration) #' autoplot(p, type = "dcalib") #' +#' # Smoothed Calibration (S-Calibration) +#' autoplot(p, type = "scalib", time = 750) +#' #' # Predicted survival curves (all observations) #' autoplot(p, type = "isd") #' @@ -228,8 +241,8 @@ plot.TaskDens = function(x, ...) { #' #' @export autoplot.PredictionSurv = function(object, type = "calib", - times = NULL, row_ids = NULL, cuts = 11L, theme = theme_minimal(), ...) { - assert_choice(type, c("calib", "dcalib", "isd"), null.ok = FALSE) + times = NULL, row_ids = NULL, cuts = 11L, time = NULL, theme = theme_minimal(), ...) { + assert_choice(type, c("calib", "dcalib", "scalib", "isd"), null.ok = FALSE) assert("distr" %in% object$predict_types) assert_number(cuts, na.ok = FALSE, lower = 1L, null.ok = FALSE) assert_numeric(row_ids, any.missing = FALSE, lower = 1, null.ok = TRUE) @@ -292,16 +305,65 @@ autoplot.PredictionSurv = function(object, type = "calib", }) ggplot(data = data.table(p, q), aes(x = p, y = q)) + - geom_bar(stat = "identity", fill = "skyblue", color = "black") + - geom_line(color = "red") + + geom_bar(stat = "identity", fill = "#5dadc8") + + geom_line(color = "black") + scale_x_continuous(breaks = p) + - annotate("segment", x = 0, y = 0, xend = 1, yend = 1, color = "black", + annotate("segment", x = 0, y = 0, xend = 1, yend = 1, alpha = 0.5, linetype = "dashed") + labs(x = "Survival Probability (Bins)", y = "Observed Proportion") + theme }, + "scalib" = { + requireNamespace("polspline") + # test set survival outcome + times = object$truth[, 1L] + status = object$truth[, 2L] + # time point for plotting calibration curve + time = assert_number(time, na.ok = FALSE, lower = 0, null.ok = FALSE) + + # get predicted survival matrix + if (inherits(object$data$distr, "array")) { + surv = object$data$distr + if (length(dim(surv)) == 3L) { + # survival 3d array, extract median + surv = .ext_surv_mat(arr = surv, which.curve = 0.5) + } + } else { + stop("Distribution prediction does not have a survival matrix or array + in the $data$distr slot") + } + + # get cdf at the specified time point + extend_times_cdf = getFromNamespace("C_Vec_WeightedDiscreteCdf", ns = "distr6") + pred_times = as.numeric(colnames(surv)) + cdf = as.vector(extend_times_cdf(time, pred_times, cdf = t(1 - surv), TRUE, FALSE)) + # to avoid log(0) later, same as in paper's Appendix + cdf[cdf == 1] = 0.9999 + + # get the cdf complement (survival) log-log transformed + cll = log(-log(1 - cdf)) + + hare_fit = polspline::hare(data = times, delta = status, cov = as.matrix(cll)) + + # make a wide-range of cdf probabilities + cdf_grid = seq(0.001, 0.999, 0.001) + cll_grid = log(-log(1 - cdf_grid)) + + smoothed_cdf_grid = polspline::phare(q = time, cov = cll_grid, fit = hare_fit) + data = data.table(pred = cdf_grid, obs = smoothed_cdf_grid) + + ggplot(data, aes(x = pred, y = obs)) + + geom_line() + + annotate("segment", x = 0, y = 0, xend = 1, yend = 1, alpha = 0.5, + linetype = "dashed") + + labs(x = "Predicted probability", + y = "Observed probability", + title = paste0("t = ", time)) + + theme + }, + "isd" = { surv = object$data$distr # assume this is 2d survival matrix data = data.table( @@ -315,9 +377,8 @@ autoplot.PredictionSurv = function(object, type = "calib", data = data[get("row_id") %in% row_ids] } - p = - ggplot(data, aes(x = .data[["time"]], y = .data[["surv_prob"]], - group = .data[["row_id"]], color = .data[["row_id"]])) + + p = ggplot(data, aes(x = .data[["time"]], y = .data[["surv_prob"]], + group = .data[["row_id"]], color = .data[["row_id"]])) + geom_line() + labs(x = "Time", y = "Survival Probability") + theme diff --git a/man/autoplot.PredictionSurv.Rd b/man/autoplot.PredictionSurv.Rd index 7a159aaa1..efbab80e6 100644 --- a/man/autoplot.PredictionSurv.Rd +++ b/man/autoplot.PredictionSurv.Rd @@ -10,6 +10,7 @@ times = NULL, row_ids = NULL, cuts = 11L, + time = NULL, theme = theme_minimal(), ... ) @@ -31,6 +32,10 @@ we draw their predicted survival distributions.} \item{cuts}{(\code{integer(1)}) \cr Number of cuts in \eqn{(0,1)} to plot \code{dcalib} over, default is \code{11}.} +\item{time}{(\code{numeric(1)}) \cr +The specific time point at which the smoothed calibration plot is created. +Must be always provided if \code{type = "scalib"}.} + \item{theme}{(\code{\link[ggplot2:theme]{ggplot2::theme()}})\cr The \code{\link[ggplot2:ggtheme]{ggplot2::theme_minimal()}} is applied by default to all plots.} @@ -49,10 +54,17 @@ proportion of observed outcomes occurring before the predicted time quantile, matches \code{p}. For example, 50\% of events should occur before the predicted median survival time (i.e. the time corresponding to a predicted survival probability of 0.5). -This means that the resulting line plot will lie close to the straight line -y = x. +Good calibration means that the resulting line plot will lie close to the +straight line \eqn{y = x}. Note that we impute \code{NA}s from the predicted quantile function with the maximum observed outcome time. +\item \code{"scalib"}: \strong{Smoothed calibration plot} at a specific time point. +For a range of predicted probabilities of event occurrence in \eqn{[0,1]} (x-axis), +the y-axis has the smoothed observed proportions calculated using hazard +regression. +See Austin et al. (2020) and \link{MeasureSurvICI} for more details. +Good calibration means that the resulting line plot will lie close to the +straight line \eqn{y = x}. \item \code{"isd"}: Plot the predicted \strong{i}ndividual \strong{s}urvival \strong{d}istributions (survival curves) for observations from the test set. } @@ -86,6 +98,9 @@ autoplot(p, times = seq(1, 1000, 5)) # Distribution-calibration (D-Calibration) autoplot(p, type = "dcalib") +# Smoothed Calibration (S-Calibration) +autoplot(p, type = "scalib", time = 750) + # Predicted survival curves (all observations) autoplot(p, type = "isd") @@ -98,4 +113,9 @@ Haider, Humza, Hoehn, Bret, Davis, Sarah, Greiner, Russell (2020). \dQuote{Effective Ways to Build and Evaluate Individual Survival Distributions.} \emph{Journal of Machine Learning Research}, \bold{21}(85), 1--63. \url{https://jmlr.org/papers/v21/18-772.html}. + +Austin, C. P, Harrell, E. F, van Klaveren, David (2020). +\dQuote{Graphical calibration curves and the integrated calibration index (ICI) for survival models.} +\emph{Statistics in Medicine}, \bold{39}(21), 2714. +ISSN 10970258, \doi{10.1002/SIM.8570}, \url{https://pmc.ncbi.nlm.nih.gov/articles/PMC7497089/}. } diff --git a/tests/testthat/test_autoplot.R b/tests/testthat/test_autoplot.R index bf0f86fae..f1f66f936 100644 --- a/tests/testthat/test_autoplot.R +++ b/tests/testthat/test_autoplot.R @@ -23,6 +23,9 @@ test_that("autoplot.PredictionSurv", { p = autoplot(prediction, type = "dcalib", cuts = 4) expect_true(is.ggplot(p)) + p = autoplot(prediction, type = "scalib", time = 95) + expect_true(is.ggplot(p)) + p = autoplot(prediction, type = "isd", row_ids = sample(task$row_ids, size = 5)) expect_true(is.ggplot(p)) }) From ed9f7154d77fdacc7013ffc7b6ed5ec464e792a0 Mon Sep 17 00:00:00 2001 From: john Date: Sat, 4 Jan 2025 17:20:24 +0200 Subject: [PATCH 20/29] update news --- NEWS.md | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/NEWS.md b/NEWS.md index 60a45e0a9..62ecf3ee7 100644 --- a/NEWS.md +++ b/NEWS.md @@ -1,10 +1,11 @@ # mlr3proba 0.7.3 * feat: added new calibration measure => `msr("surv.calib_index")` -* refactor: `autoplot.PredictionSurv` - * The default `"calib"` plot uses the survival matrix directly now (so faster) +* refator + feat: `autoplot.PredictionSurv` + * The default `"calib"` plot uses the survival matrix directly now which is faster * `"dcalib"` has extra barplot + better documentation - * **BREAKING CHANGE**: `"preds"` is now called `"isd"` (individual survival distribution) and `row_ids` can be used to filter the observations for which you draw the survival curves. + * Added new `type = "scalib"` which contructs the smoothed calibration plots as in Austin et al. (2020) + * **BREAKING CHANGE**: `"preds"` is now called `"isd"` (individual survival distribution). `row_ids` can now be used to filter the observations for which you draw the survival curves. # mlr3proba 0.7.2 From c9aaf320e75e66e956c89d4cb66ffa9ac0755583 Mon Sep 17 00:00:00 2001 From: john Date: Sat, 4 Jan 2025 17:32:23 +0200 Subject: [PATCH 21/29] refine doc --- R/autoplot.R | 21 +++++++++++---------- man/autoplot.PredictionSurv.Rd | 21 +++++++++++---------- 2 files changed, 22 insertions(+), 20 deletions(-) diff --git a/R/autoplot.R b/R/autoplot.R index 56aa71193..6aeab47b8 100644 --- a/R/autoplot.R +++ b/R/autoplot.R @@ -173,14 +173,14 @@ plot.TaskDens = function(x, ...) { #' Note that we impute `NA`s from the predicted quantile function with the #' maximum observed outcome time. #' - `"scalib"`: **Smoothed calibration plot** at a specific time point. -#' For a range of predicted probabilities of event occurrence in \eqn{[0,1]} (x-axis), +#' For a range of probabilities of event occurrence in \eqn{[0,1]} (x-axis), #' the y-axis has the smoothed observed proportions calculated using hazard -#' regression. +#' regression (model is fitted using the predicted probabilities). #' See Austin et al. (2020) and [MeasureSurvICI] for more details. #' Good calibration means that the resulting line plot will lie close to the #' straight line \eqn{y = x}. #' - `"isd"`: Plot the predicted **i**ndividual **s**urvival **d**istributions -#' (survival curves) for observations from the test set. +#' (survival curves) for the test set's observations. #' #' @section Notes: #' @@ -197,12 +197,13 @@ plot.TaskDens = function(x, ...) { #' we draw their predicted survival distributions. #' @param times (`numeric()`) \cr #' If `type = "calib"` then `times` is the values on the x-axis to plot over. -#' if `NULL` uses all time points from the predicted survival matrix (`object$data$distr`). +#' If `NULL`, we use all time points from the predicted survival matrix (`object$data$distr`). #' @param cuts (`integer(1)`) \cr -#' Number of cuts in \eqn{(0,1)} to plot `dcalib` over, default is `11`. +#' If `type = "calib"`, number of cuts in \eqn{(0,1)}, which define the bins on +#' the x-axis of the D-calibration plot. Default is `11`. #' @param time (`numeric(1)`) \cr -#' The specific time point at which the smoothed calibration plot is created. -#' Must be always provided if `type = "scalib"`. +#' If `type = "scalib"`, a specific time point at which the smoothed calibration +#' plot is constructed. #' @template param_theme #' @param ... (`any`): #' Additional arguments, currently unused. @@ -219,7 +220,7 @@ plot.TaskDens = function(x, ...) { #' #' learner = lrn("surv.coxph") #' task = tsk("gbcs") -#' p = learner$train(task, row_ids = 1:300)$predict(task, row_ids = 301:400) +#' p = learner$train(task, row_ids = 1:600)$predict(task, row_ids = 601:686) #' #' # calibration by comparison of average prediction to Kaplan-Meier #' autoplot(p) @@ -231,13 +232,13 @@ plot.TaskDens = function(x, ...) { #' autoplot(p, type = "dcalib") #' #' # Smoothed Calibration (S-Calibration) -#' autoplot(p, type = "scalib", time = 750) +#' autoplot(p, type = "scalib", time = 1750) #' #' # Predicted survival curves (all observations) #' autoplot(p, type = "isd") #' #' # Predicted survival curves (specific observations) -#' autoplot(p, type = "isd", row_ids = c(301, 351, 399)) +#' autoplot(p, type = "isd", row_ids = c(601, 651, 686)) #' #' @export autoplot.PredictionSurv = function(object, type = "calib", diff --git a/man/autoplot.PredictionSurv.Rd b/man/autoplot.PredictionSurv.Rd index efbab80e6..1be11ea7f 100644 --- a/man/autoplot.PredictionSurv.Rd +++ b/man/autoplot.PredictionSurv.Rd @@ -23,18 +23,19 @@ Type of the plot, see Description.} \item{times}{(\code{numeric()}) \cr If \code{type = "calib"} then \code{times} is the values on the x-axis to plot over. -if \code{NULL} uses all time points from the predicted survival matrix (\code{object$data$distr}).} +If \code{NULL}, we use all time points from the predicted survival matrix (\code{object$data$distr}).} \item{row_ids}{(\code{integer()}) \cr If \code{type = "isd"}, specific observation ids (from the test set) for which we draw their predicted survival distributions.} \item{cuts}{(\code{integer(1)}) \cr -Number of cuts in \eqn{(0,1)} to plot \code{dcalib} over, default is \code{11}.} +If \code{type = "calib"}, number of cuts in \eqn{(0,1)}, which define the bins on +the x-axis of the D-calibration plot. Default is \code{11}.} \item{time}{(\code{numeric(1)}) \cr -The specific time point at which the smoothed calibration plot is created. -Must be always provided if \code{type = "scalib"}.} +If \code{type = "scalib"}, a specific time point at which the smoothed calibration +plot is constructed.} \item{theme}{(\code{\link[ggplot2:theme]{ggplot2::theme()}})\cr The \code{\link[ggplot2:ggtheme]{ggplot2::theme_minimal()}} is applied by default to all plots.} @@ -59,14 +60,14 @@ straight line \eqn{y = x}. Note that we impute \code{NA}s from the predicted quantile function with the maximum observed outcome time. \item \code{"scalib"}: \strong{Smoothed calibration plot} at a specific time point. -For a range of predicted probabilities of event occurrence in \eqn{[0,1]} (x-axis), +For a range of probabilities of event occurrence in \eqn{[0,1]} (x-axis), the y-axis has the smoothed observed proportions calculated using hazard -regression. +regression (model is fitted using the predicted probabilities). See Austin et al. (2020) and \link{MeasureSurvICI} for more details. Good calibration means that the resulting line plot will lie close to the straight line \eqn{y = x}. \item \code{"isd"}: Plot the predicted \strong{i}ndividual \strong{s}urvival \strong{d}istributions -(survival curves) for observations from the test set. +(survival curves) for the test set's observations. } } \section{Notes}{ @@ -87,7 +88,7 @@ library(mlr3viz) learner = lrn("surv.coxph") task = tsk("gbcs") -p = learner$train(task, row_ids = 1:300)$predict(task, row_ids = 301:400) +p = learner$train(task, row_ids = 1:600)$predict(task, row_ids = 601:686) # calibration by comparison of average prediction to Kaplan-Meier autoplot(p) @@ -99,13 +100,13 @@ autoplot(p, times = seq(1, 1000, 5)) autoplot(p, type = "dcalib") # Smoothed Calibration (S-Calibration) -autoplot(p, type = "scalib", time = 750) +autoplot(p, type = "scalib", time = 1750) # Predicted survival curves (all observations) autoplot(p, type = "isd") # Predicted survival curves (specific observations) -autoplot(p, type = "isd", row_ids = c(301, 351, 399)) +autoplot(p, type = "isd", row_ids = c(601, 651, 686)) \dontshow{\}) # examplesIf} } \references{ From 4401671c6b099fe7b4c874d90da192fc686c0cd9 Mon Sep 17 00:00:00 2001 From: john Date: Sat, 4 Jan 2025 17:39:59 +0200 Subject: [PATCH 22/29] remove unneeded 0 --- R/MeasureSurvICI.R | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/R/MeasureSurvICI.R b/R/MeasureSurvICI.R index 3ac10fd7e..1ad99e640 100644 --- a/R/MeasureSurvICI.R +++ b/R/MeasureSurvICI.R @@ -138,7 +138,7 @@ MeasureSurvICI = R6Class("MeasureSurvICI", result = stats::median(abs(cdf - smoothed_cdf)) } else if (method == "E90") { # 90th percentile (E90) - result = stats::quantile(abs(cdf - smoothed_cdf), probs = 0.90) + result = stats::quantile(abs(cdf - smoothed_cdf), probs = 0.9) } else if (method == "Emax") { # Maximum absolute difference (Emax) result = max(abs(cdf - smoothed_cdf)) From 438d777b321663f64cfdfcae152cc694bc7a8d4b Mon Sep 17 00:00:00 2001 From: john Date: Sun, 5 Jan 2025 11:50:14 +0200 Subject: [PATCH 23/29] add param eps --- R/MeasureSurvICI.R | 24 +++++++++++++++--------- man/mlr_measures_surv.calib_index.Rd | 15 ++++++++++++--- tests/testthat/test_mlr_measures.R | 5 ++++- 3 files changed, 31 insertions(+), 13 deletions(-) diff --git a/R/MeasureSurvICI.R b/R/MeasureSurvICI.R index 1ad99e640..4471eeeed 100644 --- a/R/MeasureSurvICI.R +++ b/R/MeasureSurvICI.R @@ -1,6 +1,8 @@ #' @template surv_measure #' @templateVar title Integrated Calibration Index #' @templateVar fullname MeasureSurvICI +#' @templateVar eps 1e-4 +#' @template param_eps #' #' @description #' Calculates the Integrated Calibration Index (ICI), see Austin et al. (2020). @@ -16,9 +18,9 @@ #' calibration curve is estimated by fitting the following model: #' \deqn{log(h(t)) = g(log(− log(1 − \hat{P}_{t_0})), t)} #' -#' Note that we substitute any \eqn{\hat{P}_{t_0} = 1} with \eqn{0.9999} and any -#' \eqn{\hat{P}_{t_0} = 0} with \eqn{0.0001} to avoid arithmetic issues arising -#' from calculating \eqn{log(0)}. +#' Note that we substitute probabilities \eqn{\hat{P}_{t_0} = 0} with a small +#' \eqn{\epsilon} number to avoid arithmetic issues (\eqn{log(0)}). Same with +#' \eqn{\hat{P}_{t_0} = 1}, we use \eqn{1 - \epsilon}. #' From this model, the *smoothed* probability of occurrence at \eqn{t_0} for #' observation \eqn{i} is obtained as \eqn{\hat{P}_i^c(t_0)}. #' @@ -77,9 +79,10 @@ MeasureSurvICI = R6Class("MeasureSurvICI", initialize = function() { param_set = ps( time = p_dbl(0, Inf), + eps = p_dbl(0, 1, default = 1e-4), method = p_fct(default = "ICI", levels = c("ICI", "E50", "E90", "Emax")) ) - param_set$set_values(method = "ICI") + param_set$set_values(method = "ICI", eps = 1e-4) super$initialize( id = "surv.calib_index", @@ -112,16 +115,19 @@ MeasureSurvICI = R6Class("MeasureSurvICI", in the $data$distr slot") } + pv = self$param_set$values + # time point for calibration - time = self$param_set$values$time %??% median(times) + time = pv$time %??% median(times) # get cdf at the specified time point extend_times_cdf = getFromNamespace("C_Vec_WeightedDiscreteCdf", ns = "distr6") pred_times = as.numeric(colnames(surv)) cdf = as.vector(extend_times_cdf(time, pred_times, cdf = t(1 - surv), TRUE, FALSE)) # to avoid log(0) later, same as in paper's Appendix - cdf[cdf == 1] = 0.9999 - cdf[cdf == 0] = 0.0001 + eps = pv$eps + cdf[cdf == 1] = 1 - eps + cdf[cdf == 0] = eps # get the cdf complement (survival) log-log transformed cll = log(-log(1 - cdf)) @@ -129,7 +135,7 @@ MeasureSurvICI = R6Class("MeasureSurvICI", hare_fit = polspline::hare(data = times, delta = status, cov = as.matrix(cll)) smoothed_cdf = polspline::phare(q = time, cov = cll, fit = hare_fit) - method = self$param_set$values$method + method = pv$method if (method == "ICI") { # Mean difference (ICI) result = mean(abs(cdf - smoothed_cdf)) @@ -144,7 +150,7 @@ MeasureSurvICI = R6Class("MeasureSurvICI", result = max(abs(cdf - smoothed_cdf)) } - return(result) + result } ) ) diff --git a/man/mlr_measures_surv.calib_index.Rd b/man/mlr_measures_surv.calib_index.Rd index fd9ebb904..9e69ffe40 100644 --- a/man/mlr_measures_surv.calib_index.Rd +++ b/man/mlr_measures_surv.calib_index.Rd @@ -18,9 +18,9 @@ Using hazard regression (via the \CRANpkg{polspline} R package), a \emph{smoothe calibration curve is estimated by fitting the following model: \deqn{log(h(t)) = g(log(− log(1 − \hat{P}_{t_0})), t)} -Note that we substitute any \eqn{\hat{P}_{t_0} = 1} with \eqn{0.9999} and any -\eqn{\hat{P}_{t_0} = 0} with \eqn{0.0001} to avoid arithmetic issues arising -from calculating \eqn{log(0)}. +Note that we substitute probabilities \eqn{\hat{P}_{t_0} = 0} with a small +\eqn{\epsilon} number to avoid arithmetic issues (\eqn{log(0)}). Same with +\eqn{\hat{P}_{t_0} = 1}, we use \eqn{1 - \epsilon}. From this model, the \emph{smoothed} probability of occurrence at \eqn{t_0} for observation \eqn{i} is obtained as \eqn{\hat{P}_i^c(t_0)}. @@ -45,6 +45,7 @@ msr("surv.calib_index") \tabular{lllll}{ Id \tab Type \tab Default \tab Levels \tab Range \cr time \tab numeric \tab - \tab \tab \eqn{[0, \infty)}{[0, Inf)} \cr + eps \tab numeric \tab 1e-04 \tab \tab \eqn{[0, 1]}{[0, 1]} \cr method \tab character \tab ICI \tab ICI, E50, E90, Emax \tab - \cr } } @@ -61,6 +62,14 @@ msr("surv.calib_index") \section{Parameter details}{ +\itemize{ +\item \code{eps} (\code{numeric(1)})\cr +Very small number to substitute zero values in order to prevent errors +in e.g. log(0) and/or division-by-zero calculations. +Default value is 1e-04. +} + + \itemize{ \item \code{time} (\code{numeric(1)})\cr The specific time point \eqn{t_0} at which calibration is evaluated. diff --git a/tests/testthat/test_mlr_measures.R b/tests/testthat/test_mlr_measures.R index 4e103c971..24b078fff 100644 --- a/tests/testthat/test_mlr_measures.R +++ b/tests/testthat/test_mlr_measures.R @@ -123,10 +123,12 @@ test_that("calib_index works", { m = msr("surv.calib_index") expect_equal(m$range, c(0, Inf)) expect_true(m$minimize) - expect_equal(m$param_set$values$method, "ICI") + expect_equal(m$param_set$values$method, "ICI") # mean abs diffs + expect_equal(m$param_set$values$eps, 0.0001) res = pred$score(m) expect_gt(res, 0) + # scores for E90 and Emax represent more extreme (larger) differences than the mean m2 = msr("surv.calib_index", method = "E90") expect_equal(m2$param_set$values$method, "E90") res2 = pred$score(m2) @@ -136,6 +138,7 @@ test_that("calib_index works", { expect_equal(m3$param_set$values$method, "Emax") expect_gt(pred$score(m3), res2) + # different time point m4 = msr("surv.calib_index", time = 100) expect_equal(m4$param_set$values$time, 100) expect_false(pred$score(m4) == res) From 3a5a2f3571d1c622655af0328310fed0b90305d9 Mon Sep 17 00:00:00 2001 From: john Date: Mon, 6 Jan 2025 13:44:14 +0200 Subject: [PATCH 24/29] some doc refinements --- R/MeasureSurvICI.R | 5 +++-- man/mlr_measures_surv.calib_index.Rd | 5 +++-- 2 files changed, 6 insertions(+), 4 deletions(-) diff --git a/R/MeasureSurvICI.R b/R/MeasureSurvICI.R index 4471eeeed..c78d513e5 100644 --- a/R/MeasureSurvICI.R +++ b/R/MeasureSurvICI.R @@ -24,7 +24,8 @@ #' From this model, the *smoothed* probability of occurrence at \eqn{t_0} for #' observation \eqn{i} is obtained as \eqn{\hat{P}_i^c(t_0)}. #' -#' The **Integrated Calibration Index** is then computed as: +#' The **Integrated Calibration Index** is then computed across the \eqn{N} +#' test set observations as: #' \deqn{ICI = \frac{1}{N} \sum_{i=1}^N | \hat{P}_i^c(t_0) - \hat{P}_i(t_0) |} #' #' This measure evaluates **point-calibration** at a specific time point, which @@ -62,7 +63,7 @@ #' # Make predictions for the test set #' p = cox$predict(task, row_ids = part$test) #' -#' # ICI at median test time +#' # ICI at median test set time #' p$score(msr("surv.calib_index")) #' #' # ICI at specific time point diff --git a/man/mlr_measures_surv.calib_index.Rd b/man/mlr_measures_surv.calib_index.Rd index 9e69ffe40..ca2de283e 100644 --- a/man/mlr_measures_surv.calib_index.Rd +++ b/man/mlr_measures_surv.calib_index.Rd @@ -24,7 +24,8 @@ Note that we substitute probabilities \eqn{\hat{P}_{t_0} = 0} with a small From this model, the \emph{smoothed} probability of occurrence at \eqn{t_0} for observation \eqn{i} is obtained as \eqn{\hat{P}_i^c(t_0)}. -The \strong{Integrated Calibration Index} is then computed as: +The \strong{Integrated Calibration Index} is then computed across the \eqn{N} +test set observations as: \deqn{ICI = \frac{1}{N} \sum_{i=1}^N | \hat{P}_i^c(t_0) - \hat{P}_i(t_0) |} This measure evaluates \strong{point-calibration} at a specific time point, which @@ -101,7 +102,7 @@ cox$train(task, row_ids = part$train) # Make predictions for the test set p = cox$predict(task, row_ids = part$test) -# ICI at median test time +# ICI at median test set time p$score(msr("surv.calib_index")) # ICI at specific time point From bc8e40f3dd5ca5b4efb5d3af32ad66c12ca95dcb Mon Sep 17 00:00:00 2001 From: john Date: Mon, 6 Jan 2025 13:44:31 +0200 Subject: [PATCH 25/29] add Lee as contributor (ICI measure) --- DESCRIPTION | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/DESCRIPTION b/DESCRIPTION index dffcc2a9b..b23581dde 100644 --- a/DESCRIPTION +++ b/DESCRIPTION @@ -16,7 +16,9 @@ Authors@R: c( comment = c(ORCID = "0000-0001-7528-3795")), person("Philip", "Studener", , "philip.studener@gmx.de", role = "aut"), person("Maximilian", "Muecke", , "muecke.maximilian@gmail.com", role = "ctb", - comment = c(ORCID = "0009-0000-9432-9795")) + comment = c(ORCID = "0009-0000-9432-9795")), + person("Lee Xingzhuo", "Li", , "xingzhuo_li@yahoo.com.au", role = "ctb", + comment = c(ORCID = "0000-0001-5259-5198")) ) Description: Provides extensions for probabilistic supervised learning for 'mlr3'. This includes extending the regression task to probabilistic From 98de7ff0b55660e6508e30bf11e60d0ebbcb5a0c Mon Sep 17 00:00:00 2001 From: john Date: Mon, 6 Jan 2025 16:38:45 +0200 Subject: [PATCH 26/29] refine install section + measures table + add ICI --- README.Rmd | 27 ++++++++------------------- README.md | 35 +++++++++++------------------------ 2 files changed, 19 insertions(+), 43 deletions(-) diff --git a/README.Rmd b/README.Rmd index 30d4964dc..6c509a944 100644 --- a/README.Rmd +++ b/README.Rmd @@ -56,23 +56,11 @@ Please follow one of the two following methods to install it: ### R-universe +Install the latest released version: ```r install.packages("mlr3proba", repos = "https://mlr-org.r-universe.dev") ``` -Or for easier installation going forward: - -1. Run `usethis::edit_r_environ()` then in the file that opened add or edit `options` to look something like: -```r -options(repos = c( - raphaels1 = "https://raphaels1.r-universe.dev", - mlrorg = "https://mlr-org.r-universe.dev", # add this line - CRAN = "https://cloud.r-project.org" -)) -``` -2. Save and close the file, restart your `R` session -3. Run `install.packages("mlr3proba")` as usual - ### GitHub Install the latest development version: @@ -95,12 +83,13 @@ Some commonly used measures are the following: | ID | Measure | Package | Category | Prediction Type | :--| :------ | :------ | :------ | :------- | -| [surv.dcalib](https://mlr3proba.mlr-org.com/reference/mlr_measures_surv.dcalib.html) | D-Calibration | [mlr3proba](https://CRAN.R-project.org/package=mlr3proba) | Calibration | distr -| [surv.cindex](https://mlr3proba.mlr-org.com/reference/mlr_measures_surv.cindex.html) | Concordance Index | [mlr3proba](https://CRAN.R-project.org/package=mlr3proba) | Discrimination | crank -| [surv.uno_auc](https://mlr3proba.mlr-org.com/reference/mlr_measures_surv.uno_auc.html) | Uno's AUC | [survAUC](https://CRAN.R-project.org/package=survAUC) | Discrimination | lp -| [surv.graf](https://mlr3proba.mlr-org.com/reference/mlr_measures_surv.graf.html) | Integrated Brier Score | [mlr3proba](https://CRAN.R-project.org/package=mlr3proba) | Scoring Rule | distr -| [surv.rcll](https://mlr3proba.mlr-org.com/reference/mlr_measures_surv.rcll.html) | Right-Censored Log loss | [mlr3proba](https://CRAN.R-project.org/package=mlr3proba) | Scoring Rule | distr -| [surv.intlogloss](https://mlr3proba.mlr-org.com/reference/mlr_measures_surv.intlogloss.html) | Integrated Log-Likelihood | [mlr3proba](https://CRAN.R-project.org/package=mlr3proba) | Scoring Rule | distr +| [surv.dcalib](https://mlr3proba.mlr-org.com/reference/mlr_measures_surv.dcalib.html) | D-Calibration | `mlr3proba` | Calibration | `distr` +| [surv.calib_index](https://mlr3proba.mlr-org.com/reference/mlr_measures_surv.calib_index.html) | One-point Calibration | `mlr3proba` | Calibration | `distr` +| [surv.cindex](https://mlr3proba.mlr-org.com/reference/mlr_measures_surv.cindex.html) | Concordance Index | `mlr3proba` | Discrimination | `crank` +| [surv.uno_auc](https://mlr3proba.mlr-org.com/reference/mlr_measures_surv.uno_auc.html) | Uno's AUC | `survAUC` | Discrimination | `lp` +| [surv.graf](https://mlr3proba.mlr-org.com/reference/mlr_measures_surv.graf.html) | Integrated Brier Score | `mlr3proba` | Scoring Rule | `distr` +| [surv.rcll](https://mlr3proba.mlr-org.com/reference/mlr_measures_surv.rcll.html) | Right-Censored Log loss | `mlr3proba` | Scoring Rule | `distr` +| [surv.intlogloss](https://mlr3proba.mlr-org.com/reference/mlr_measures_surv.intlogloss.html) | Integrated Log-Likelihood | `mlr3proba` | Scoring Rule | `distr` ## Bugs, Questions, Feedback diff --git a/README.md b/README.md index 0f68c4050..32b9423d2 100644 --- a/README.md +++ b/README.md @@ -61,26 +61,12 @@ following methods to install it: ### R-universe -``` r -install.packages("mlr3proba", repos = "https://mlr-org.r-universe.dev") -``` - -Or for easier installation going forward: - -1. Run `usethis::edit_r_environ()` then in the file that opened add or - edit `options` to look something like: +Install the latest released version: ``` r -options(repos = c( - raphaels1 = "https://raphaels1.r-universe.dev", - mlrorg = "https://mlr-org.r-universe.dev", # add this line - CRAN = "https://cloud.r-project.org" -)) +install.packages("mlr3proba", repos = "https://mlr-org.r-universe.dev") ``` -2. Save and close the file, restart your `R` session -3. Run `install.packages("mlr3proba")` as usual - ### GitHub Install the latest development version: @@ -111,14 +97,15 @@ list Some commonly used measures are the following: -| ID | Measure | Package | Category | Prediction Type | -|:---------------------------------------------------------------------------------------------|:--------------------------|:----------------------------------------------------------|:---------------|:----------------| -| [surv.dcalib](https://mlr3proba.mlr-org.com/reference/mlr_measures_surv.dcalib.html) | D-Calibration | [mlr3proba](https://CRAN.R-project.org/package=mlr3proba) | Calibration | distr | -| [surv.cindex](https://mlr3proba.mlr-org.com/reference/mlr_measures_surv.cindex.html) | Concordance Index | [mlr3proba](https://CRAN.R-project.org/package=mlr3proba) | Discrimination | crank | -| [surv.uno_auc](https://mlr3proba.mlr-org.com/reference/mlr_measures_surv.uno_auc.html) | Uno’s AUC | [survAUC](https://CRAN.R-project.org/package=survAUC) | Discrimination | lp | -| [surv.graf](https://mlr3proba.mlr-org.com/reference/mlr_measures_surv.graf.html) | Integrated Brier Score | [mlr3proba](https://CRAN.R-project.org/package=mlr3proba) | Scoring Rule | distr | -| [surv.rcll](https://mlr3proba.mlr-org.com/reference/mlr_measures_surv.rcll.html) | Right-Censored Log loss | [mlr3proba](https://CRAN.R-project.org/package=mlr3proba) | Scoring Rule | distr | -| [surv.intlogloss](https://mlr3proba.mlr-org.com/reference/mlr_measures_surv.intlogloss.html) | Integrated Log-Likelihood | [mlr3proba](https://CRAN.R-project.org/package=mlr3proba) | Scoring Rule | distr | +| ID | Measure | Package | Category | Prediction Type | +|:---|:---|:---|:---|:---| +| [surv.dcalib](https://mlr3proba.mlr-org.com/reference/mlr_measures_surv.dcalib.html) | D-Calibration | `mlr3proba` | Calibration | `distr` | +| [surv.calib_index](https://mlr3proba.mlr-org.com/reference/mlr_measures_surv.calib_index.html) | One-point Calibration | `mlr3proba` | Calibration | `distr` | +| [surv.cindex](https://mlr3proba.mlr-org.com/reference/mlr_measures_surv.cindex.html) | Concordance Index | `mlr3proba` | Discrimination | `crank` | +| [surv.uno_auc](https://mlr3proba.mlr-org.com/reference/mlr_measures_surv.uno_auc.html) | Uno’s AUC | `survAUC` | Discrimination | `lp` | +| [surv.graf](https://mlr3proba.mlr-org.com/reference/mlr_measures_surv.graf.html) | Integrated Brier Score | `mlr3proba` | Scoring Rule | `distr` | +| [surv.rcll](https://mlr3proba.mlr-org.com/reference/mlr_measures_surv.rcll.html) | Right-Censored Log loss | `mlr3proba` | Scoring Rule | `distr` | +| [surv.intlogloss](https://mlr3proba.mlr-org.com/reference/mlr_measures_surv.intlogloss.html) | Integrated Log-Likelihood | `mlr3proba` | Scoring Rule | `distr` | ## Bugs, Questions, Feedback From 9dd5e6903737ffb352758e0d3c2884dfb5a82e6e Mon Sep 17 00:00:00 2001 From: john Date: Mon, 6 Jan 2025 16:39:27 +0200 Subject: [PATCH 27/29] pump version --- DESCRIPTION | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/DESCRIPTION b/DESCRIPTION index b23581dde..d5a38f47c 100644 --- a/DESCRIPTION +++ b/DESCRIPTION @@ -1,6 +1,6 @@ Package: mlr3proba Title: Probabilistic Supervised Learning for 'mlr3' -Version: 0.7.1 +Version: 0.7.3 Authors@R: c( person("Raphael", "Sonabend", , "raphaelsonabend@gmail.com", role = "aut", comment = c(ORCID = "0000-0001-9225-4654")), From 3cbc8e66f4797d9b7c14d3a8a17b8a0d32725aa5 Mon Sep 17 00:00:00 2001 From: john Date: Mon, 6 Jan 2025 16:42:54 +0200 Subject: [PATCH 28/29] correct spelling --- NEWS.md | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/NEWS.md b/NEWS.md index 62ecf3ee7..61e350c60 100644 --- a/NEWS.md +++ b/NEWS.md @@ -1,10 +1,10 @@ # mlr3proba 0.7.3 * feat: added new calibration measure => `msr("surv.calib_index")` -* refator + feat: `autoplot.PredictionSurv` +* refactor + feat: `autoplot.PredictionSurv` * The default `"calib"` plot uses the survival matrix directly now which is faster * `"dcalib"` has extra barplot + better documentation - * Added new `type = "scalib"` which contructs the smoothed calibration plots as in Austin et al. (2020) + * Added new `type = "scalib"` which constructs the smoothed calibration plots as in Austin et al. (2020) * **BREAKING CHANGE**: `"preds"` is now called `"isd"` (individual survival distribution). `row_ids` can now be used to filter the observations for which you draw the survival curves. # mlr3proba 0.7.2 From 46b4dddfbb05b8fcb36b1af0aaf63edf71025995 Mon Sep 17 00:00:00 2001 From: john Date: Mon, 6 Jan 2025 16:54:22 +0200 Subject: [PATCH 29/29] fix NOTEs --- DESCRIPTION | 2 +- R/MeasureSurvICI.R | 2 +- R/autoplot.R | 1 + man/mlr3proba-package.Rd | 1 + 4 files changed, 4 insertions(+), 2 deletions(-) diff --git a/DESCRIPTION b/DESCRIPTION index d5a38f47c..3ecd1bf2e 100644 --- a/DESCRIPTION +++ b/DESCRIPTION @@ -37,7 +37,6 @@ Imports: ggplot2, mlr3misc (>= 0.7.0), mlr3pipelines (>= 0.7.0), - mlr3viz, paradox (>= 1.0.0), R6, Rcpp (>= 1.0.4), @@ -50,6 +49,7 @@ Suggests: lgr, lifecycle, mlr3learners, + mlr3viz, pammtools, param6 (>= 0.2.4), polspline, diff --git a/R/MeasureSurvICI.R b/R/MeasureSurvICI.R index c78d513e5..dc056faf6 100644 --- a/R/MeasureSurvICI.R +++ b/R/MeasureSurvICI.R @@ -119,7 +119,7 @@ MeasureSurvICI = R6Class("MeasureSurvICI", pv = self$param_set$values # time point for calibration - time = pv$time %??% median(times) + time = pv$time %??% stats::median(times) # get cdf at the specified time point extend_times_cdf = getFromNamespace("C_Vec_WeightedDiscreteCdf", ns = "distr6") diff --git a/R/autoplot.R b/R/autoplot.R index 6aeab47b8..7370ca963 100644 --- a/R/autoplot.R +++ b/R/autoplot.R @@ -353,6 +353,7 @@ autoplot.PredictionSurv = function(object, type = "calib", cll_grid = log(-log(1 - cdf_grid)) smoothed_cdf_grid = polspline::phare(q = time, cov = cll_grid, fit = hare_fit) + pred = obs = NULL data = data.table(pred = cdf_grid, obs = smoothed_cdf_grid) ggplot(data, aes(x = pred, y = obs)) + diff --git a/man/mlr3proba-package.Rd b/man/mlr3proba-package.Rd index 479717393..44ce0142c 100644 --- a/man/mlr3proba-package.Rd +++ b/man/mlr3proba-package.Rd @@ -34,6 +34,7 @@ Other contributors: \item Andreas Bender \email{bender.at.R@gmail.com} (\href{https://orcid.org/0000-0001-5628-8611}{ORCID}) [contributor] \item Lukas Burk \email{github@quantenbrot.de} (\href{https://orcid.org/0000-0001-7528-3795}{ORCID}) [contributor] \item Maximilian Muecke \email{muecke.maximilian@gmail.com} (\href{https://orcid.org/0009-0000-9432-9795}{ORCID}) [contributor] + \item Lee Xingzhuo Li \email{xingzhuo_li@yahoo.com.au} (\href{https://orcid.org/0000-0001-5259-5198}{ORCID}) [contributor] } }