Skip to content

Commit f141d06

Browse files
authored
Fix incorrect sizing for unconstrain_draws (#983)
1 parent 499aa23 commit f141d06

File tree

3 files changed

+34
-5
lines changed

3 files changed

+34
-5
lines changed

R/fit.R

+8-1
Original file line numberDiff line numberDiff line change
@@ -561,6 +561,9 @@ CmdStanFit$set("public", name = "unconstrain_variables", value = unconstrain_var
561561
#'
562562
unconstrain_draws <- function(files = NULL, draws = NULL,
563563
format = getOption("cmdstanr_draws_format", "draws_array")) {
564+
if (!(format %in% valid_draws_formats())) {
565+
stop("Invalid draws format requested!", call. = FALSE)
566+
}
564567
if (!is.null(files) || !is.null(draws)) {
565568
if (!is.null(files) && !is.null(draws)) {
566569
stop("Either a list of CSV files or a draws object can be passed, not both",
@@ -582,6 +585,8 @@ unconstrain_draws <- function(files = NULL, draws = NULL,
582585
}
583586
draws <- maybe_convert_draws_format(private$draws_, "draws_matrix")
584587
}
588+
589+
chains <- posterior::nchains(draws)
585590

586591
model_par_names <- self$metadata()$stan_variables[self$metadata()$stan_variables != "lp__"]
587592
model_variables <- self$runset$args$model_variables
@@ -598,7 +603,9 @@ unconstrain_draws <- function(files = NULL, draws = NULL,
598603
unconstrained <- private$model_methods_env_$unconstrain_draws(private$model_methods_env_$model_ptr_, draws)
599604
uncon_names <- private$model_methods_env_$unconstrained_param_names(private$model_methods_env_$model_ptr_, FALSE, FALSE)
600605
names(unconstrained) <- repair_variable_names(uncon_names)
601-
maybe_convert_draws_format(unconstrained, format, .nchains = posterior::nchains(draws))
606+
unconstrained$.nchains <- chains
607+
608+
do.call(function(...) { create_draws_format(format, ...) }, unconstrained)
602609
}
603610
CmdStanFit$set("public", name = "unconstrain_draws", value = unconstrain_draws)
604611

R/utils.R

+13
Original file line numberDiff line numberDiff line change
@@ -426,6 +426,19 @@ maybe_convert_draws_format <- function(draws, format, ...) {
426426
)
427427
}
428428

429+
create_draws_format <- function(format, ...) {
430+
format <- sub("^draws_", "", format)
431+
switch(
432+
format,
433+
"array" = posterior::draws_array(...),
434+
"df" = posterior::draws_df(...),
435+
"data.frame" = posterior::draws_df(...),
436+
"list" = posterior::draws_list(...),
437+
"matrix" = posterior::draws_matrix(...),
438+
"rvars" = posterior::draws_rvars(...),
439+
stop("Invalid draws format.", call. = FALSE)
440+
)
441+
}
429442

430443
# convert draws for external packages ------------------------------------------
431444

inst/include/model_methods.cpp

+13-4
Original file line numberDiff line numberDiff line change
@@ -127,15 +127,24 @@ Eigen::VectorXd unconstrain_variables(SEXP ext_model_ptr, Eigen::VectorXd variab
127127
}
128128

129129
// [[Rcpp::export]]
130-
Eigen::MatrixXd unconstrain_draws(SEXP ext_model_ptr, Eigen::MatrixXd variables) {
130+
Rcpp::List unconstrain_draws(SEXP ext_model_ptr, Eigen::MatrixXd variables) {
131131
Rcpp::XPtr<stan::model::model_base> ptr(ext_model_ptr);
132-
Eigen::MatrixXd unconstrained_draws(variables.cols(), variables.rows());
132+
// Need to do this for the first row to get the correct size of the unconstrained draws
133+
Eigen::VectorXd unconstrained_draw1;
134+
ptr->unconstrain_array(variables.row(0).transpose(), unconstrained_draw1, &Rcpp::Rcout);
135+
std::vector<Eigen::VectorXd> unconstrained_draws(unconstrained_draw1.size());
136+
for (auto&& unconstrained_par : unconstrained_draws) {
137+
unconstrained_par.resize(variables.rows());
138+
}
139+
133140
for (int i = 0; i < variables.rows(); i++) {
134141
Eigen::VectorXd unconstrained_variables;
135142
ptr->unconstrain_array(variables.transpose().col(i), unconstrained_variables, &Rcpp::Rcout);
136-
unconstrained_draws.col(i) = unconstrained_variables;
143+
for (int j = 0; j < unconstrained_variables.size(); j++) {
144+
unconstrained_draws[j](i) = unconstrained_variables(j);
145+
}
137146
}
138-
return unconstrained_draws.transpose();
147+
return Rcpp::wrap(unconstrained_draws);
139148
}
140149

141150
// [[Rcpp::export]]

0 commit comments

Comments
 (0)