Skip to content

Commit

Permalink
fix: epi_slide_opt window_size validation, fix test, fix bug
Browse files Browse the repository at this point in the history
  • Loading branch information
dshemetov committed Oct 2, 2024
1 parent 34ad569 commit 44d354d
Show file tree
Hide file tree
Showing 3 changed files with 18 additions and 59 deletions.
54 changes: 12 additions & 42 deletions R/slide.R
Original file line number Diff line number Diff line change
Expand Up @@ -678,46 +678,16 @@ epi_slide_opt <- function(
ref_time_values <- sort(.ref_time_values)

# Handle window arguments
align <- rlang::arg_match(.align)
.align <- rlang::arg_match(.align)
time_type <- attr(.x, "metadata")$time_type
validate_slide_window_arg(.window_size, time_type)
if (identical(.window_size, Inf)) {
if (align == "right") {
before <- Inf
if (time_type %in% c("day", "week")) {
after <- as.difftime(0, units = glue::glue("{time_type}s"))
} else {
after <- 0
}
} else {
cli_abort(
"`epi_slide`: center and left alignment are not supported with an infinite window size."
)
}
} else {
if (align == "right") {
before <- .window_size - 1
if (time_type %in% c("day", "week")) {
after <- as.difftime(0, units = glue::glue("{time_type}s"))
} else {
after <- 0
}
} else if (align == "center") {
# For .window_size = 5, before = 2, after = 2. For .window_size = 4, before = 2, after = 1.
before <- floor(.window_size / 2)
after <- .window_size - before - 1
} else if (align == "left") {
if (time_type %in% c("day", "week")) {
before <- as.difftime(0, units = glue::glue("{time_type}s"))
} else {
before <- 0
}
after <- .window_size - 1
}
if (is.null(.window_size)) {
cli_abort("epi_slide: `.window_size` must be specified.")
}
validate_slide_window_arg(.window_size, time_type)
window_args <- get_before_after_from_window(.window_size, .align, time_type)

# Make a complete date sequence between min(.x$time_value) and max(.x$time_value).
date_seq_list <- full_date_seq(.x, before, after, time_type)
date_seq_list <- full_date_seq(.x, window_args$before, window_args$after, time_type)
all_dates <- date_seq_list$all_dates
pad_early_dates <- date_seq_list$pad_early_dates
pad_late_dates <- date_seq_list$pad_late_dates
Expand Down Expand Up @@ -786,16 +756,16 @@ epi_slide_opt <- function(
# `before` and `after` params. Right-aligned `frollmean` results'
# `ref_time_value`s will be `after` timesteps ahead of where they should
# be; shift results to the left by `after` timesteps.
if (before != Inf) {
window_size <- before + after + 1L
if (window_args$before != Inf) {
window_size <- window_args$before + window_args$after + 1L
roll_output <- .f(x = .data_group[, col_names_chr], n = window_size, ...)
} else {
window_size <- list(seq_along(.data_group$time_value))
roll_output <- .f(x = .data_group[, col_names_chr], n = window_size, adaptive = TRUE, ...)
}
if (after >= 1) {
if (window_args$after >= 1) {
.data_group[, result_col_names] <- purrr::map(roll_output, function(.x) {
c(.x[(after + 1L):length(.x)], rep(NA, after))
c(.x[(window_args$after + 1L):length(.x)], rep(NA, window_args$after))
})
} else {
.data_group[, result_col_names] <- roll_output
Expand All @@ -805,8 +775,8 @@ epi_slide_opt <- function(
for (i in seq_along(col_names_chr)) {
.data_group[, result_col_names[i]] <- .f(
x = .data_group[[col_names_chr[i]]],
before = as.numeric(before),
after = as.numeric(after),
before = as.numeric(window_args$before),
after = as.numeric(window_args$after),
...
)
}
Expand Down
11 changes: 4 additions & 7 deletions R/utils.R
Original file line number Diff line number Diff line change
Expand Up @@ -999,10 +999,7 @@ test_sensible_int <- function(x, na.ok = FALSE, lower = -Inf, upper = Inf, # nol
if (null.ok && is.null(x)) {
TRUE
} else {
is.numeric(x) && test_int(x,
na.ok = FALSE, lower = -Inf, upper = Inf,
tol = sqrt(.Machine$double.eps)
)
is.numeric(x) && test_int(x, na.ok = na.ok, lower = lower, upper = upper, tol = tol)
}
}

Expand All @@ -1024,7 +1021,7 @@ validate_slide_window_arg <- function(arg, time_type, lower = 1, allow_inf = TRU

# nolint start: indentation_linter.
if (time_type == "day") {
if (!(test_sensible_int(arg, lower = 0L) ||
if (!(test_sensible_int(arg, lower = lower) ||
inherits(arg, "difftime") && length(arg) == 1L && units(arg) == "days" ||
allow_inf && identical(arg, Inf)
)) {
Expand All @@ -1037,13 +1034,13 @@ validate_slide_window_arg <- function(arg, time_type, lower = 1, allow_inf = TRU
msg <- glue::glue_collapse(c("length-1 difftime with units in weeks", inf_if_okay), " or ")
}
} else if (time_type == "yearmonth") {
if (!(test_sensible_int(arg, lower = 0L) ||
if (!(test_sensible_int(arg, lower = lower) ||
allow_inf && identical(arg, Inf)
)) {
msg <- glue::glue_collapse(c("non-negative integer", inf_if_okay), " or ")
}
} else if (time_type == "integer") {
if (!(test_sensible_int(arg, lower = 0L) ||
if (!(test_sensible_int(arg, lower = lower) ||
allow_inf && identical(arg, Inf)
)) {
msg <- glue::glue_collapse(c("non-negative integer", inf_if_okay), " or ")
Expand Down
12 changes: 2 additions & 10 deletions tests/testthat/test-epi_slide.R
Original file line number Diff line number Diff line change
Expand Up @@ -711,17 +711,9 @@ test_that("epi_slide_opt helper `full_date_seq` returns expected date values", {

test_that("`epi_slide_opt` errors when passed non-`data.table`, non-`slider` functions", {
reexport_frollmean <- data.table::frollmean
expect_no_error(
epi_slide_opt(
test_data,
.col_names = value, .f = reexport_frollmean
)
)
expect_no_error(epi_slide_opt(test_data, .col_names = value, .f = reexport_frollmean, .window_size = 7))
expect_error(
epi_slide_opt(
test_data,
.col_names = value, .f = mean
),
epi_slide_opt(test_data, .col_names = value, .f = mean),
class = "epiprocess__epi_slide_opt__unsupported_slide_function"
)
})
Expand Down

0 comments on commit 44d354d

Please sign in to comment.