Skip to content

Commit

Permalink
Simulate infections (#557)
Browse files Browse the repository at this point in the history
  • Loading branch information
sbfnk authored Feb 20, 2024
1 parent 3c1ad81 commit 4c1f460
Show file tree
Hide file tree
Showing 10 changed files with 454 additions and 44 deletions.
1 change: 1 addition & 0 deletions NAMESPACE
Original file line number Diff line number Diff line change
Expand Up @@ -117,6 +117,7 @@ importFrom(checkmate,assert_names)
importFrom(checkmate,assert_numeric)
importFrom(checkmate,assert_path_for_output)
importFrom(checkmate,assert_string)
importFrom(checkmate,assert_subset)
importFrom(checkmate,test_data_frame)
importFrom(checkmate,test_numeric)
importFrom(data.table,":=")
Expand Down
1 change: 1 addition & 0 deletions NEWS.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
* The `fixed` argument to `dist_spec` has been deprecated and replaced by a `fix_dist()` function. By @sbfnk in #503 and reviewed by @seabbs.
* Updated `estimate_infections()` so that rather than imputing missing data, it now skips these data points in the likelihood. This is a breaking change as it alters the behaviour of the model when dates are missing from a time series but are known to be zero. We recommend that users check their results when updating to this version but expect this to in most cases improve performance. By @seabbs in #528 and reviewed by @sbfnk.
* `simulate_infections` has been renamed to `forecast_infections` in line with `simulate_secondary` and `forecast_secondary`. The terminology is: a forecast is done from a fit to existing data, a simulation from first principles. By @sbfnk in #544 and reviewed by @seabbs.
* A new `simulate_infections` function has been added that can be used to simulate from the model from given initial conditions and parameters. By @sbfnk in #557 and reviewed by @jamesmbaazam.

## Documentation

Expand Down
43 changes: 23 additions & 20 deletions R/extract.R
Original file line number Diff line number Diff line change
Expand Up @@ -181,28 +181,31 @@ extract_parameter_samples <- function(stan_fit, data, reported_dates,
samples,
reported_dates
)
if (data$estimate_r == 1) {
out$R <- extract_parameter(
"R",
samples,
reported_dates
)
if (data$bp_n > 0) {
out$breakpoints <- extract_parameter(
"bp_effects",
if ("estimate_r" %in% names(data)) {
if (data$estimate_r == 1) {
out$R <- extract_parameter(
"R",
samples,
1:data$bp_n
reported_dates
)
if (data$bp_n > 0) {
out$breakpoints <- extract_parameter(
"bp_effects",
samples,
1:data$bp_n
)
out$breakpoints <- out$breakpoints[
,
strat := date
][, c("time", "date") := NULL]
}
} else {
out$R <- extract_parameter(
"gen_R",
samples,
reported_dates
)
out$breakpoints <- out$breakpoints[,
strat := date][, c("time", "date") := NULL
]
}
} else {
out$R <- extract_parameter(
"gen_R",
samples,
reported_dates
)
}
out$growth_rate <- extract_parameter(
"r",
Expand Down Expand Up @@ -243,7 +246,7 @@ extract_parameter_samples <- function(stan_fit, data, reported_dates,
value.V1 := NULL
]
}
if (data$obs_scale_sd > 0) {
if ("obs_scale_sd" %in% names(data) && data$obs_scale_sd > 0) {
out$fraction_observed <- extract_static_parameter("frac_obs", samples)
out$fraction_observed <- out$fraction_observed[, value := value.V1][,
value.V1 := NULL
Expand Down
209 changes: 199 additions & 10 deletions R/simulate_infections.R
Original file line number Diff line number Diff line change
@@ -1,19 +1,208 @@
#' Deprecated; use [forecast_infections()] instead
#' Simulate infections using the renewal equation
#'
#' Calling this function passes all arguments to [forecast_infections()]
#' @description `r lifecycle::badge("deprecated")`
#' @param ... Arguments to be passed to [forecast_infections()]
#' @return the result of [forecast_infections()]
#' Simulations are done from given initial infections and, potentially
#' time-varying, reproduction numbers. Delays and parameters of the observation
#' model can be specified using the same options as in [estimate_infections()].
#'
#' In order to simulate, all parameters that are specified such as the mean and
#' standard deviation of delays or observation scaling, must be fixed.
#' Uncertain parameters are not allowed.
#'
#' A previous function called [simulate_infections()] that simulates from a
#' given model fit has been renamed [forecast_infections()]. Using
#' [simulate_infections()] with existing estimates is now deprecated. This
#' option will be removed in version 2.1.0.
#' @param R a data frame of reproduction numbers (column `R`) by date (column
#' `date`). Column `R` must be numeric and `date` must be in date format. If
#' not all days between the first and last day in the `date` are present,
#' it will be assumed that R stays the same until the next given date.
#' @param initial_infections numeric; the initial number of infections.
#' @param day_of_week_effect either `NULL` (no day of the week effect) or a
#' numerical vector of length specified in [obs_opts()] as `week_length`
#' (default: 7) if `week_effect` is set to TRUE. Each element of the vector
#' gives the weight given to reporting on this day (normalised to 1).
#' The default is `NULL`.
#' @param estimates deprecated; use [forecast_infections()] instead
#' @param ... deprecated; only included for backward compatibility
#' @inheritParams estimate_infections
#' @inheritParams rt_opts
#' @inheritParams stan_opts
#' @importFrom lifecycle deprecate_warn
#' @importFrom checkmate assert_data_frame assert_date assert_numeric
#' assert_subset
#' @importFrom data.table data.table merge.data.table nafill rbindlist
#' @return A data.table of simulated infections (variable `infections`) and
#' reported cases (variable `reported_cases`) by date.
#' @author Sebastian Funk
#' @export
simulate_infections <- function(...) {
#' @examples
#' \donttest{
#' R <- data.frame(
#' date = seq.Date(as.Date("2023-01-01"), length.out = 14, by = "day"),
#' R = c(rep(1.2, 7), rep(0.8, 7))
#' )
#' sim <- simulate_infections(
#' R = R,
#' initial_infections = 100,
#' generation_time = generation_time_opts(
#' fix_dist(example_generation_time)
#' ),
#' delays = delay_opts(fix_dist(example_reporting_delay)),
#' obs = obs_opts(family = "poisson")
#' )
#' }
simulate_infections <- function(estimates, R, initial_infections,
day_of_week_effect = NULL,
generation_time = generation_time_opts(),
delays = delay_opts(),
truncation = trunc_opts(),
obs = obs_opts(),
CrIs = c(0.2, 0.5, 0.9),
backend = "rstan",
pop = 0, ...) {
## deprecated usage
if (!missing(estimates)) {
deprecate_warn(
"2.0.0",
"simulate_infections()",
"simulate_infections(estimates)",
"forecast_infections()",
"A new [simulate_infections()] function for simulating from given ",
"parameters is planned for implementation in the future."
details = paste0(
"This `estimates` option will be removed from [simulate_infections()] ",
"in version 2.1.0."
)
)
return(forecast_infections(estimates = estimates, ...))
}

## check inputs
assert_data_frame(R, any.missing = FALSE)
assert_subset(colnames(R), c("date", "R"))
assert_date(R$date)
assert_numeric(R$R, lower = 0)
assert_numeric(initial_infections, lower = 0)
assert_numeric(day_of_week_effect, lower = 0, null.ok = TRUE)
assert_numeric(pop, lower = 0)
assert_class(delays, "delay_opts")
assert_class(obs, "obs_opts")
assert_class(generation_time, "generation_time_opts")

## create R for all dates modelled
all_dates <- data.table(date = seq.Date(min(R$date), max(R$date), by = "day"))
R <- merge.data.table(all_dates, R, by = "date", all.x = TRUE)
R <- R[, R := nafill(R, type = "locf")]
## remove any initial NAs
R <- R[!is.na(R)]

seeding_time <- get_seeding_time(delays, generation_time)
if (seeding_time > 1) {
## estimate initial growth from initial reproduction number if seeding time
## is greater than 1
initial_growth <- (R$R[1] - 1) / mean(generation_time)
} else {
initial_growth <- numeric(0)
}

data <- list(
n = 1,
t = nrow(R) + seeding_time,
seeding_time = seeding_time,
future_time = 0,
initial_infections = array(log(initial_infections), dim = c(1, 1)),
initial_growth = array(initial_growth, dim = c(1, length(initial_growth))),
R = array(R$R, dim = c(1, nrow(R))),
pop = pop
)

data <- c(data, create_stan_delays(
gt = generation_time,
delay = delays,
trunc = truncation
))

if ((length(data$delay_mean_sd) > 0 && any(data$delay_mean_sd > 0)) ||
(length(data$delay_sd_sd) > 0 && any(data$delay_sd_sd > 0))) {
stop(
"Cannot simulate from uncertain parameters. Use the [fix_dist()] ",
"function to set the parameters of uncertain distributions either the ",
"mean or a randomly sampled value"
)
forecast_infections(...)
}
data$delay_mean <- array(
data$delay_mean_mean, dim = c(1, length(data$delay_mean_mean))
)
data$delay_sd <- array(
data$delay_sd_mean, dim = c(1, length(data$delay_sd_mean))
)
data$delay_mean_sd <- NULL
data$delay_sd_sd <- NULL

data <- c(data, create_obs_model(
obs, dates = R$date
))

if (data$obs_scale_sd > 0) {
stop(
"Cannot simulate from uncertain observation scaling; use fixed scaling ",
"instead."
)
}
if (data$obs_scale) {
data$frac_obs <- array(data$obs_scale_mean, dim = c(1, 1))
} else {
data$frac_obs <- array(dim = c(1, 0))
}
data$obs_scale_mean <- NULL
data$obs_scale_sd <- NULL

if (obs$family == "negbin") {
if (data$phi_sd > 0) {
stop(
"Cannot simulate from uncertain overdispersion; use fixed ",
"overdispersion instead."
)
}
data$rep_phi <- array(data$phi_mean, dim = c(1, 1))
} else {
data$rep_phi <- array(dim = c(1, 0))
}
data$phi_mean <- NULL
data$phi_sd <- NULL

## day of week effect
if (is.null(day_of_week_effect)) {
day_of_week_effect <- rep(1, data$week_effect)
}

day_of_week_effect <- day_of_week_effect / sum(day_of_week_effect)
data$day_of_week_simplex <- array(
day_of_week_effect, dim = c(1, data$week_effect)
)

# Create stan arguments
stan <- stan_opts(backend = backend, chains = 1, samples = 1, warmup = 1)
args <- create_stan_args(
stan, data = data, fixed_param = TRUE, model = "simulate_infections",
verbose = FALSE
)

## simulate
sim <- fit_model(args, id = "simulate_infections")

## join batches
dates <- c(
seq(min(R$date) - seeding_time, min(R$date) - 1, by = "day"),
R$date
)
out <- extract_parameter_samples(sim, data,
reported_inf_dates = dates,
reported_dates = dates[-(1:seeding_time)],
drop_length_1 = TRUE
)

out <- rbindlist(out[c("infections", "reported_cases")], idcol = "variable")
out <- out[, c("sample", "parameter", "time") := NULL]

return(out[])
}

#' Forecast infections from a given fit and trajectory of the time-varying
Expand Down
4 changes: 2 additions & 2 deletions inst/stan/data/simulation_rt.stan
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
array[seeding_time ? n : 0, 1] real initial_infections; // initial logged infections
array[seeding_time > 1 ? n : 0, 1] real initial_growth; //initial growth
array[n, 1] real initial_infections; // initial logged infections
array[n, seeding_time > 1 ? 1 : 0] real initial_growth; //initial growth

matrix[n, t - seeding_time] R; // reproduction number
int pop; // susceptible population
Expand Down
20 changes: 17 additions & 3 deletions inst/stan/simulate_infections.stan
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,9 @@ generated quantities {
to_vector(infections[i]), delay_rev_pmf, seeding_time)
);
} else {
reports[i] = to_row_vector(infections[(seeding_time + 1):t]);
reports[i] = to_row_vector(
infections[i, (seeding_time + 1):t]
);
}

// weekly reporting effect
Expand All @@ -72,6 +74,18 @@ generated quantities {
day_of_week_effect(to_vector(reports[i]), day_of_week,
to_vector(day_of_week_simplex[i])));
}
// truncate near time cases to observed reports
if (trunc_id) {
vector[delay_type_max[trunc_id] + 1] trunc_rev_cmf = get_delay_rev_pmf(
trunc_id, delay_type_max[trunc_id] + 1, delay_types_p, delay_types_id,
delay_types_groups, delay_max, delay_np_pmf,
delay_np_pmf_groups, delay_mean[i], delay_sd[i], delay_dist,
0, 1, 1
);
reports[i] = to_row_vector(truncate(
to_vector(reports[i]), trunc_rev_cmf, 0)
);
}
// scale observations
if (obs_scale) {
reports[i] = to_row_vector(scale_obs(to_vector(reports[i]), frac_obs[i, 1]));
Expand All @@ -81,8 +95,8 @@ generated quantities {
to_vector(reports[i]), rep_phi[i], model_type
);
{
real gt_mean = rev_pmf_mean(gt_rev_pmf, 1);
real gt_var = rev_pmf_var(gt_rev_pmf, 1, gt_mean);
real gt_mean = rev_pmf_mean(gt_rev_pmf, 0);
real gt_var = rev_pmf_var(gt_rev_pmf, 0, gt_mean);
r[i] = R_to_growth(to_vector(R[i]), gt_mean, gt_var);
}
}
Expand Down
Loading

0 comments on commit 4c1f460

Please sign in to comment.