-
Notifications
You must be signed in to change notification settings - Fork 344
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
refactor: adapted to dynamic convergence rules
- Loading branch information
1 parent
50f559b
commit 95091c9
Showing
4 changed files
with
60 additions
and
34 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -11,21 +11,29 @@ | |
#' | ||
#' @param OutputModels List. Output from \code{robyn_run()} | ||
#' @param n_cuts Integer. Default to 20 (5% cuts). Convergence is calculated | ||
#' on using first and last quantile cuts. Criteria 1: last quantile's sd | ||
#' < threshold_sd. Criteria 2: last quantile's median < first quantile's | ||
#' median - 2 * sd. Both have to happen to consider convergence. | ||
#' @param threshold_sd Numeric. Default to 0.025 that is empirically derived. | ||
#' on using first and last quantile cuts. By default, criteria 1: last | ||
#' quantile's sd < first 3 quantiles' mean sd. Criteria 2: last quantile's | ||
#' median < first quantile's median - 3 * first 3 quantiles' mean sd. Both | ||
#' have to be satisfied to consider convergence. | ||
#' @param sd_qtref Integer. Reference quantile of the error convergence rule | ||
#' for standard deviation. Defaults to 3. Error convergence rule for sd is | ||
#' defined as by default: last quantile's sd < first 3 quantiles' mean sd. | ||
#' @param med_lowb Integer. Lower bound distance of the error convergence rule | ||
#' for median. Default to 3. Error convergence rule for median is defined as | ||
#' by default: last quantile's median < first quantile's median - 3 * first 3 | ||
#' quantiles' mean sd. | ||
#' @param ... Additional parameters | ||
#' @examples | ||
#' \dontrun{ | ||
#' OutputModels <- robyn_converge( | ||
#' OutputModels = OutputModels, | ||
#' n_cuts = 10, | ||
#' threshold_sd = 0.025 | ||
#' n_cuts = 20, | ||
#' sd_qtref = 3, | ||
#' med_lowb = 3 | ||
#' ) | ||
#' } | ||
#' @export | ||
robyn_converge <- function(OutputModels, n_cuts = 20, threshold_sd = 0.025, ...) { | ||
robyn_converge <- function(OutputModels, n_cuts = 20, sd_qtref = 3, med_lowb = 3, ...) { | ||
|
||
# Gather all trials | ||
get_lists <- as.logical(grepl("trial", names(OutputModels)) * sapply(OutputModels, is.list)) | ||
|
@@ -54,8 +62,8 @@ robyn_converge <- function(OutputModels, n_cuts = 20, threshold_sd = 0.025, ...) | |
)) | ||
|
||
# Calculate sd and median on each cut to alert user on: | ||
# 1) last quantile's sd < threshold_sd | ||
# 2) last quantile's median < first quantile's median - 2 * sd | ||
# 1) last quantile's sd < mean sd of default first 3 qt | ||
# 2) last quantile's median < median of first qt - default 3 * mean sd of defualt first 3 qt | ||
errors <- dt_objfunc_cvg %>% | ||
group_by(.data$error_type, .data$cuts) %>% | ||
summarise( | ||
|
@@ -66,29 +74,37 @@ robyn_converge <- function(OutputModels, n_cuts = 20, threshold_sd = 0.025, ...) | |
) %>% | ||
group_by(.data$error_type) %>% | ||
mutate( | ||
med_var_P = abs(round(100 * (.data$median - lag(.data$median)) / .data$median, 2)), | ||
flag_sd = .data$std > threshold_sd | ||
med_var_P = abs(round(100 * (.data$median - lag(.data$median)) / .data$median, 2)) | ||
) %>% | ||
group_by(.data$error_type) %>% | ||
mutate(flag_med = dplyr::last(.data$median[1]) < dplyr::first(.data$median[2]) - 2 * dplyr::first(.data$std)) | ||
mutate(first_med = dplyr::first(.data$median), | ||
first_med_avg = mean(.data$median[1:sd_qtref]), | ||
last_med = dplyr::last(.data$median), | ||
first_sd = dplyr::first(.data$std), | ||
first_sd_avg = mean(.data$std[1:sd_qtref]), | ||
last_sd = dplyr::last(.data$std)) %>% | ||
mutate(med_thres = .data$first_med - med_lowb * .data$first_sd_avg, | ||
flag_med = .data$median < .data$first_med - med_lowb * .data$first_sd_avg, | ||
flag_sd = .data$std < .data$first_sd_avg) | ||
|
||
conv_msg <- NULL | ||
for (obj_fun in unique(errors$error_type)) { | ||
temp.df <- filter(errors, .data$error_type == obj_fun) %>% | ||
mutate(median = signif(median, 2)) | ||
last.qt <- tail(temp.df, 1) | ||
temp <- glued(paste( | ||
"{error_type} {did}converged: sd {sd} @qt.{quantile} {symb_sd} {sd_threh} &", | ||
"med {qtn_median} @qt.{quantile} {symb_med} {med_threh} [email protected]2*sd"), | ||
"{error_type} {did}converged: sd@qt.{quantile} {sd} {symb_sd} {sd_threh} &", | ||
"med@qt.{quantile} {qtn_median} {symb_med} {med_threh} [email protected]{med_lowb}*sd"), | ||
error_type = last.qt$error_type, | ||
did = ifelse(last.qt$flag_sd | last.qt$flag_med, "NOT ", ""), | ||
sd = signif(last.qt$std, 1), | ||
symb_sd = ifelse(last.qt$flag_sd, ">", "<="), | ||
sd_threh = threshold_sd, | ||
quantile = round(100/n_cuts), | ||
qtn_median = temp.df$median[n_cuts], | ||
symb_med = ifelse(last.qt$flag_med, ">", "<="), | ||
med_threh = signif(temp.df$median[1] - 2 * temp.df$std[1], 2) | ||
did = ifelse(last.qt$flag_sd & last.qt$flag_med, "", "NOT "), | ||
sd = signif(last.qt$last_sd, 2), | ||
symb_sd = ifelse(last.qt$flag_sd, "<", ">="), | ||
sd_threh = signif(last.qt$first_sd_avg, 2), | ||
quantile = n_cuts, | ||
qtn_median = signif(last.qt$last_med, 2), | ||
symb_med = ifelse(last.qt$flag_med, "<", ">="), | ||
med_threh = signif(last.qt$med_thres, 2), | ||
med_lowb = med_lowb | ||
) | ||
conv_msg <- c(conv_msg, temp) | ||
} | ||
|
@@ -162,7 +178,8 @@ robyn_converge <- function(OutputModels, n_cuts = 20, threshold_sd = 0.025, ...) | |
errors = errors, | ||
conv_msg = conv_msg | ||
) | ||
attr(cvg_out, "threshold_sd") <- threshold_sd | ||
attr(cvg_out, "sd_qtref") <- sd_qtref | ||
attr(cvg_out, "med_lowb") <- med_lowb | ||
|
||
return(invisible(cvg_out)) | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters