Skip to content

Commit

Permalink
Make epi_slide* all use autogrouping + make autogrouping temporary
Browse files Browse the repository at this point in the history
  • Loading branch information
brookslogan committed Nov 14, 2024
1 parent 5ffaf2f commit f9a8356
Show file tree
Hide file tree
Showing 5 changed files with 74 additions and 35 deletions.
59 changes: 30 additions & 29 deletions R/slide.R
Original file line number Diff line number Diff line change
Expand Up @@ -186,15 +186,16 @@ epi_slide <- function(

# Validate arguments
assert_class(.x, "epi_df")
if (checkmate::test_class(.x, "grouped_df")) {
.x_orig_groups <- groups(.x)
if (inherits(.x, "grouped_df")) {
expected_group_keys <- .x %>%
key_colnames(exclude = "time_value") %>%
sort()
if (!identical(.x %>% group_vars() %>% sort(), expected_group_keys)) {
cli_abort(
"epi_slide: `.x` must be either grouped by {expected_group_keys}. (Or you can just ungroup
`.x` and we'll do this grouping automatically.) You may need to aggregate your data first,
see aggregate_epi_df().",
"`.x` must be either grouped by {expected_group_keys} or ungrouped; if the latter,
we'll temporarily group by {expected_group_keys} for this operation. You may need
to aggregate your data first, see aggregate_epi_df().",
class = "epiprocess__epi_slide__invalid_grouping"
)
}
Expand Down Expand Up @@ -300,7 +301,6 @@ epi_slide <- function(
# `epi_slide_one_group`.
# - `...` from top of `epi_slide` are forwarded to `.f` here through
# group_modify and through the lambda.
.x_groups <- groups(.x)
result <- group_map(
.x,
.f = function(.data_group, .group_key, ...) {
Expand All @@ -324,7 +324,7 @@ epi_slide <- function(
filter(.real) %>%
select(-.real) %>%
arrange_col_canonical() %>%
group_by(!!!.x_groups)
group_by(!!!.x_orig_groups)

# If every group in epi_slide_one_group takes the
# length(available_ref_time_values) == 0 branch then we end up here.
Expand Down Expand Up @@ -691,12 +691,30 @@ epi_slide_opt <- function(
)
}

assert_class(.x, "epi_df")
.x_orig_groups <- groups(.x)
if (inherits(.x, "grouped_df")) {
expected_group_keys <- .x %>%
key_colnames(exclude = "time_value") %>%
sort()
if (!identical(.x %>% group_vars() %>% sort(), expected_group_keys)) {
cli_abort(
"`.x` must be either grouped by {expected_group_keys} or ungrouped; if the latter,
we'll temporarily group by {expected_group_keys} for this operation. You may need
to aggregate your data first, see aggregate_epi_df().",
class = "epiprocess__epi_slide__invalid_grouping"
)
}
} else {
.x <- group_epi_df(.x, exclude = "time_value")
}
if (nrow(.x) == 0L) {
cli_abort(
c(
"input data `.x` unexpectedly has 0 rows",
"i" = "If this computation is occuring within an `epix_slide` call,
check that `epix_slide` `.versions` argument was set appropriately"
check that `epix_slide` `.versions` argument was set appropriately
so that you don't get any completely-empty snapshots"
),
class = "epiprocess__epi_slide_opt__0_row_input",
epiprocess__x = .x
Expand Down Expand Up @@ -857,27 +875,9 @@ epi_slide_opt <- function(
arrange(.data$time_value)

if (f_from_package == "data.table") {
# If a group contains duplicate time values, `frollmean` will still only
# use the last `k` obs. It isn't looking at dates, it just goes in row
# order. So if the computation is aggregating across multiple obs for the
# same date, `epi_slide_opt` and derivates will produce incorrect results;
# `epi_slide` should be used instead.
if (anyDuplicated(.data_group$time_value) != 0L) {
cli_abort(
c(
"group contains duplicate time values. Using `epi_slide_[opt/mean/sum]` on this
group will result in incorrect results",
"i" = "Please change the grouping structure of the input data so that
each group has non-duplicate time values (e.g. `x %>% group_by(geo_value)
%>% epi_slide_opt(.f = frollmean)`)",
"i" = "Use `epi_slide` to aggregate across groups"
),
class = "epiprocess__epi_slide_opt__duplicate_time_values",
epiprocess__data_group = .data_group,
epiprocess__group_key = .group_key
)
}

# Grouping should ensure that we don't have duplicate time values.
# Completion above should ensure we have at least .window_size rows. Check
# that we don't have more than .window_size rows (or fewer somehow):
if (nrow(.data_group) != length(c(all_dates, pad_early_dates, pad_late_dates))) {
cli_abort(
c(
Expand Down Expand Up @@ -928,7 +928,8 @@ epi_slide_opt <- function(
group_modify(slide_one_grp, ..., .keep = FALSE) %>%
filter(.data$.real) %>%
select(-.real) %>%
arrange_col_canonical()
arrange_col_canonical() %>%
group_by(!!!.x_orig_groups)

if (.all_rows) {
result[!(result$time_value %in% ref_time_values), result_col_names] <- NA
Expand Down
4 changes: 2 additions & 2 deletions man-roxygen/basic-slide-params.R
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
#' @param .x An `epi_df` object. If ungrouped, we group by `geo_value` and any
#' columns in `other_keys`. If grouped, we make sure the grouping is by
#' @param .x An `epi_df` object. If ungrouped, we temporarily group by `geo_value`
#' and any columns in `other_keys`. If grouped, we make sure the grouping is by
#' `geo_value` and `other_keys`.
#' @param .window_size The size of the sliding window. The accepted values
#' depend on the type of the `time_value` column in `.x`:
Expand Down
4 changes: 2 additions & 2 deletions man/epi_slide.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

4 changes: 2 additions & 2 deletions man/epi_slide_opt.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

38 changes: 38 additions & 0 deletions tests/testthat/test-epi_slide.R
Original file line number Diff line number Diff line change
Expand Up @@ -899,3 +899,41 @@ test_that("epi_slide_opt output naming features", {
class = "epiprocess__epi_slide_opt_new_name_duplicated"
)
})

test_that("epi_slide* output grouping matches input grouping", {
toy_edf <- as_epi_df(bind_rows(list(
tibble(geo_value = 1, age_group = 1, time_value = as.Date("2020-01-01") + 1:10 - 1, value = 1:10),
tibble(geo_value = 1, age_group = 2, time_value = as.Date("2020-01-01") + 1:10 - 1, value = 20:11),
tibble(geo_value = 2, age_group = 2, time_value = as.Date("2020-01-01") + 1:10 - 1, value = 31:40)
)), other_keys = "age_group", as_of = as.Date("2020-01-01") + 20)

# Preserving existing grouping:
expect_equal(
toy_edf %>%
group_by(age_group, geo_value) %>%
epi_slide(value_7dsum = sum(value), .window_size = 7) %>%
group_vars(),
c("age_group", "geo_value")
)
expect_equal(
toy_edf %>%
group_by(age_group, geo_value) %>%
epi_slide_sum(value, .window_size = 7) %>%
group_vars(),
c("age_group", "geo_value")
)

# Removing automatic grouping:
expect_equal(
toy_edf %>%
epi_slide(value_7dsum = sum(value), .window_size = 7) %>%
group_vars(),
character(0)
)
expect_equal(
toy_edf %>%
epi_slide_sum(value, .window_size = 7) %>%
group_vars(),
character(0)
)
})

0 comments on commit f9a8356

Please sign in to comment.