diff --git a/NAMESPACE b/NAMESPACE index 3195e8d35..cdb8eac23 100644 --- a/NAMESPACE +++ b/NAMESPACE @@ -1,6 +1,7 @@ # Generated by roxygen2: do not edit by hand S3method("+",bform) +S3method("+",brmsinits) S3method("+",brmsprior) S3method("+",stanvars) S3method(.compute_point_estimate,brmsprep) @@ -36,6 +37,7 @@ S3method(bridge_sampler,brmsfit) S3method(brmsterms,brmsformula) S3method(brmsterms,default) S3method(brmsterms,mvbrmsformula) +S3method(c,brmsinits) S3method(c,brmsprior) S3method(c,stanvars) S3method(coef,brmsfit) @@ -159,6 +161,7 @@ S3method(nsamples,brmsfit) S3method(nuts_params,brmsfit) S3method(nvariables,brmsfit) S3method(pairs,brmsfit) +S3method(par_info,brmsterms) S3method(parnames,brmsfit) S3method(parnames,default) S3method(plot,brmsMarginalEffects) @@ -585,6 +588,7 @@ export(rwiener) export(s) export(sar) export(save_pars) +export(set_inits) export(set_mecor) export(set_nl) export(set_prior) diff --git a/R/brm.R b/R/brm.R index 43c86dc1a..70dbf23aa 100644 --- a/R/brm.R +++ b/R/brm.R @@ -544,6 +544,22 @@ brm <- function(formula, data, family = gaussian(), prior = NULL, normalize = normalize ) + # generate Stan data before compiling the model to avoid + # unnecessary compilations in case of invalid data + sdata <- .standata( + bterms, data = data, prior = prior, data2 = data2, + stanvars = stanvars, threads = threads + ) + + # generate inits + if (is.brmsinits(init)) { + init <- replicate( + chains, + .inits_fun(init, bterms = bterms, data = data, sdata = sdata), + simplify = FALSE + ) + } + # initialize S3 object x <- brmsfit( formula = formula, data = data, data2 = data2, prior = prior, @@ -554,12 +570,6 @@ brm <- function(formula, data, family = gaussian(), prior = NULL, stan_args = nlist(init, silent, control, stan_model_args, ...) ) exclude <- exclude_pars(x) - # generate Stan data before compiling the model to avoid - # unnecessary compilations in case of invalid data - sdata <- .standata( - bterms, data = data, prior = prior, data2 = data2, - stanvars = stanvars, threads = threads - ) if (empty) { # return the brmsfit object with an empty 'fit' slot diff --git a/R/inits.R b/R/inits.R new file mode 100644 index 000000000..15aa550bd --- /dev/null +++ b/R/inits.R @@ -0,0 +1,213 @@ +#' Init definitions for **brms** models +#' +#' Define how initial values for specific parameters are generated. +#' +#' @inheritParams set_prior +#' @param distribution A character string specifying the distribution of the initial values +#' +#' @return An object of class `brmsinits` to be used in the `init` argument of [brm] +#' @export +#' +#' @examples +#' \dontrun{ +#' inits <- set_inits("normal(0, 1)", class = "Intercept", coef = "mu") + +#' set_inits("uniform(-1, 1)", class = "b", coef = "mu") +#' # use the inits in a brm call +#' fit <- brm(count ~ Trt + zAge, epilepsy, poisson(), init = inits) +#' } +set_inits <- function(distribution, class = "b", coef = "", group = "", + dpar = "", nlpar = "") { + input <- nlist(distribution, class, coef, group, dpar, nlpar) + input <- try(as.data.frame(input), silent = TRUE) + if (is_try_error(input)) { + stop2("Processing arguments of 'set_inits' has failed:\n", input) + } + out <- vector("list", nrow(input)) + for (i in seq_along(out)) { + out[[i]] <- do_call(.set_inits, input[i, ]) + } + Reduce("+", out) +} + + +# validate arguments passed to 'set_inits' +.set_inits <- function(distribution, class, coef, group, + dpar, nlpar) { + distribution <- as_one_character(distribution) + class <- as_one_character(class) + group <- as_one_character(group) + coef <- as_one_character(coef) + dpar <- as_one_character(dpar) + nlpar <- as_one_character(nlpar) + if (dpar == "mu") { + # distributional parameter 'mu' is currently implicit + dpar <- "" + } + out <- data.frame(distribution, class, coef, group, dpar, nlpar) + class(out) <- c("brmsinits", "data.frame") + out +} + +# Internal function for generating a list of inits to pass to stan from a +# brmsinits object created from set_inits() +# @param binits A brmsinits object +# @param bterms A brmsterms object +# @param data The data used in the model +# @param sdata The stan data list +# @return A list of inits to pass to stan +.inits_fun <- function(binits, bterms, data, sdata) { + # TODO: check if inits are properly specified (similar to how the priors are checked) + pars <- paste0(binits$dpar, binits$nlpar) + sep <- ifelse(pars == "", "", "_") + # temporary - works for Intercept and b, but not for sd, z, etc; needs to be generalized by using code from .stancode + binits$stanpars <- paste0(binits$class, sep, pars) + # get the information typically used in the parameters block of stancode + info <- par_info(bterms, data) + + dims <- sdata[info$b_dim_name] + dims <- ifelse(is.na(info$b_dim_name), 1, dims) + prefixes <- ifelse(info$b_type == "real", "", "array(") + suffixes <- ifelse(info$b_type == "real", "", ")") # here we would add dimensions as well + + # construct the call for generating inits for each row of binits + out <- list() + for (i in 1:nrow(binits)) { + idx <- which(info$b_par == binits$stanpars[[i]]) + pinfo <- info[idx, ] + dist <- parse_dist(binits$distribution[[i]]) + args <- paste0(dist$args, collapse = ", ") + prefix <- prefixes[idx] + suffix <- suffixes[idx] + dim <- dims[[idx]] + call <- glue('{prefix}{dist$fun}({dim}, {args}){suffix}') + call <- parse(text = call) + out[[binits$stanpars[[i]]]] <- eval(call) + } + + out +} + + +# combine multiple brmsinits objects into one brmsinits (code almost identical to +# c.brmsprior) +#' @export +c.brmsinits <- function(x, ..., replace = FALSE) { + dots <- list(...) + if (all(sapply(dots, is.brmsinits))) { + replace <- as_one_logical(replace) + # don't use 'c()' here to avoid creating a recursion + out <- do_call(rbind, list(x, ...)) + if (replace) { + # update duplicated inits + out <- unique(out, fromLast = TRUE) + } + } else { + if (length(dots)) { + stop2("Cannot add '", class(dots[[1]])[1], "' objects to the inits") + } + out <- c(as.data.frame(x)) + } + out +} + +#' @export +"+.brmsinits" <- function(e1, e2) { + if (is.null(e2)) { + return(e1) + } + if (!is.brmsinits(e2)) { + stop2("Cannot add '", class(e2)[1], "' objects to the inits") + } + c(e1, e2) +} + +is.brmsinits <- function(x) { + inherits(x, "brmsinits") +} + + +# takes a character vector like 'normal(0, 1)' and returns a list with the +# r* function and its arguments +# to do - more careful checks of the passed format? +parse_dist <- function(x) { + x <- as_one_character(x) + x <- parse(text = x)[[1]] + dist <- as.character(x[[1]]) + args <- as.list(x[-1]) + args <- lapply(args, function(x) { + tmp <- as.character(x) + as.numeric(collapse(tmp)) + }) + fun <- to_rfun(dist) + nlist(fun, args) +} + +# takes a character string and returns the corresponding r random generation +# function +to_rfun <- function(x) { + x <- as_one_character(x) + # TODO expandlist + dists <- c(normal = 'norm', poisson = 'pois', binomial = 'binom', + inv_gamma = 'invgamma', lognormal = 'lnorm', exponential = 'exp', + uniform = 'unif') + out <- dists[x] + if (is.null(out) || is.na(out)) { + out <- x + } + paste0("r", out) +} + +par_info <- function(bterms, data, ...) { + UseMethod("par_info") +} + +#' @export +par_info.brmsterms <- function(bterms, data, ...) { + out <- list() + for (par in names(bterms$dpars)) { + info <- par_info_fe(bterms$dpars[[par]], data) + info <- as.data.frame(info) + out <- rbind(out, info) + } + out +} + +# internal function for extracting information about the population-effects parameters +# that is typically part of the parameters block of the stan code +# @param bterms A brmsterms object +# @param data The data used in the model +# @return A list with the following elements: +# - b_type: the type of the parameter (real, vector, array) +# - b_dim_name: the name of the dimension of the parameter (should match in standata) +# - b_par: the name of the parameter in stan +# @details +# if a parameter is described as vector[Kc_sigma] b_sigma, the output will be: +# list(b_type = "vector", b_dim_name = "Kc_sigma", b_par = "b_sigma") +par_info_fe <- function(bterms, data) { + out <- list() + family <- bterms$family + fixef <- colnames(data_fe(bterms, data)$X) + center_X <- stan_center_X(bterms) + ct <- str_if(center_X, "c") + # remove the intercept from the design matrix? + if (center_X) { + fixef <- setdiff(fixef, "Intercept") + } + px <- check_prefix(bterms) + p <- usc(combine_prefix(px)) + resp <- usc(px$resp) + + out <- list() + if (length(fixef)) { + out$b_type <- "vector" + out$b_dim_name <- glue("K{ct}{p}") + out$b_par <- glue("b{p}") + } + + if (center_X) { + c(out$b_type) <- "real" + c(out$b_dim_name) <- NA + c(out$b_par) <- glue("Intercept{p}") + } + out +} \ No newline at end of file diff --git a/man/set_inits.Rd b/man/set_inits.Rd new file mode 100644 index 000000000..2ba3e70d8 --- /dev/null +++ b/man/set_inits.Rd @@ -0,0 +1,46 @@ +% Generated by roxygen2: do not edit by hand +% Please edit documentation in R/inits.R +\name{set_inits} +\alias{set_inits} +\title{Init definitions for **brms** models} +\usage{ +set_inits( + distribution, + class = "b", + coef = "", + group = "", + dpar = "", + nlpar = "" +) +} +\arguments{ +\item{distribution}{A character string specifying the distribution of the initial values} + +\item{class}{The parameter class. Defaults to \code{"b"} +(i.e. population-level effects). +See 'Details' for other valid parameter classes.} + +\item{coef}{Name of the coefficient within the parameter class.} + +\item{group}{Grouping factor for group-level parameters.} + +\item{dpar}{Name of a distributional parameter. +Only used in distributional models.} + +\item{nlpar}{Name of a non-linear parameter. +Only used in non-linear models.} +} +\value{ +An object of class `brmsinits` to be used in the `init` argument of [brm] +} +\description{ +Define how initial values for specific parameters are generated. +} +\examples{ +\dontrun{ +inits <- set_inits("normal(0, 1)", class = "Intercept", coef = "mu") + + set_inits("uniform(-1, 1)", class = "b", coef = "mu") +# use the inits in a brm call +fit <- brm(count ~ Trt + zAge, epilepsy, poisson(), init = inits) +} +} diff --git a/tests/local/tests.set_inits.R b/tests/local/tests.set_inits.R new file mode 100644 index 000000000..d2e27ed94 --- /dev/null +++ b/tests/local/tests.set_inits.R @@ -0,0 +1,38 @@ +source("tests/local/setup_tests_local.R") + +data <- epilepsy +data$cat <- factor(sample(1:5, nrow(data), replace = TRUE)) + +test_that('gaussian model runs with set_inits', { + formula <- bf(count ~ cat + Age, + sigma ~ cat + Age) + + inits <- set_inits('normal(0, 1)', class = "Intercept", dpar = "mu") + + set_inits('uniform(-0.1, 0.1)', class = "b", dpar = "sigma") + + out <- capture_messages(fit <- brm(formula, data = data, init = inits, refresh = 0, chains = 2)) + out <- paste0(out, collapse = "\n") + expect_true(grepl("Missing init values for the following parameters:", out)) + expect_true(grepl(": b, Intercept_sigma\n", out)) + fit_init <- fit$stan_args$init + expect_length(fit_init, 2) + expect_equal(names(fit_init[[1]]), c("Intercept", "b_sigma")) + expect_range(fit_init[[1]]$b_sigma, -0.1, 0.1) +}) + + +test_that('poisson model runs with set_inits', { + formula <- bf(count ~ cat + Age) + + inits <- set_inits('normal(0, 1)', class = "Intercept", dpar = "mu") + + set_inits('uniform(-0.1, 0.1)', class = "b", dpar = "mu") + + out <- capture_messages(fit <- brm(formula, data = data, family = poisson(), + init = inits, refresh = 0, chains = 2)) + out <- paste0(out, collapse = "\n") + expect_false(grepl("Missing init values for the following parameters:", out)) + fit_init <- fit$stan_args$init + expect_length(fit_init, 2) + expect_equal(names(fit_init[[1]]), c("Intercept", "b")) + expect_range(fit_init[[1]]$b, -0.1, 0.1) +}) diff --git a/tests/testthat/test-inits.R b/tests/testthat/test-inits.R new file mode 100644 index 000000000..8d37512c5 --- /dev/null +++ b/tests/testthat/test-inits.R @@ -0,0 +1,54 @@ +test_that('set_inits produces the correct format', { + res <- set_inits('normal(0, 1)', class = "Intercept", dpar = "mu") + expect_s3_class(res, "data.frame") + expect_s3_class(res, "brmsinits") + out <- as.data.frame(res) + expect_equal(out, data.frame(distribution = "normal(0, 1)", + class = "Intercept", + coef = "", + group = "", + dpar = "", + nlpar = "")) + + res2 <- set_inits('normal(0, 1)', class = "sd", dpar = "sigma") + out <- res + res2 + out <- as.data.frame(out) + expect_equal(out, data.frame(distribution = c("normal(0, 1)", "normal(0, 1)"), + class = c("Intercept", "sd"), + coef = c("", ""), + group = c("", ""), + dpar = c("", "sigma"), + nlpar = c("", ""))) +}) + + +test_that('parse_dist works', { + d <- 'normal(0, 1)' + res <- parse_dist(d) + expect_equal(res, list(fun = "rnorm", args = list(0, 1))) + + d <- 'uniform(-1, 1.5)' + res <- parse_dist(d) + expect_equal(res, list(fun = "runif", args = list(-1, 1.5))) +}) + + +test_that('.inits_fun works', { + data <- epilepsy + data$cat <- factor(sample(1:5, nrow(data), replace = TRUE)) + formula <- formula <- bf(count ~ cat + Age, + sigma ~ cat + Age) + bterms <- brmsterms(formula) + sdata <- standata(formula, data = data) + + inits <- set_inits('normal(0, 1)', class = "Intercept", dpar = "mu") + + set_inits('uniform(-1, 1)', class = "b", dpar = "sigma") + + out <- .inits_fun(inits, bterms = bterms, data = data, sdata = sdata) + expect_type(out, "list") + expect_equal(names(out), c("Intercept", "b_sigma")) + expect_length(out$Intercept, 1) + expect_equal(class(out$Intercept), "numeric") + expect_length(out$b_sigma, 5) + expect_equal(class(out$b_sigma), "array") +})