Creating a BART sampling loop dispatched completely through C++ #67

Open
wants to merge 20 commits into
base: main
308663b
Updated some docs
andrewherren Jul 12, 2024
d450b55
Updated with a C++ only version of the classical versions of BART/XBA…
andrewherren Jul 12, 2024
e607c79
Fixed include issue
andrewherren Jul 12, 2024
861f4bc
Wrapped C++ sampling loop in R function
andrewherren Jul 12, 2024
4249f38
Updated BART classes to use unique pointer to ForestContainer (for ea…
andrewherren Jul 15, 2024
ff32710
Updated R wrapper around templated BARTDispatcher class
andrewherren Jul 15, 2024
5e982ca
Broadening scope of the C++ only sampling loop
andrewherren Jul 15, 2024
75d9480
Updated C++ sampling loop to include more complete feature set
andrewherren Jul 18, 2024
698ec70
Updated BART function
andrewherren Jul 18, 2024
44237db
Added max_depth control to MCMC and GFR samplers
andrewherren Jul 19, 2024
64890f5
Added back the "streamlined" C++ loop for comparison
andrewherren Jul 19, 2024
d90d74f
Debug script to compare implementations of the BART sampling loop
andrewherren Jul 19, 2024
f673cb5
Added functions to inspect tree depth
andrewherren Jul 20, 2024
78606c1
Fixed bug in "specialized" BART loop
andrewherren Jul 20, 2024
f201db6
Fixed max_depth bug
andrewherren Jul 20, 2024
1facf03
Merge main branch into sampler_loop branch
andrewherren Jul 24, 2024
449b3f0
Updated samplers and tests
andrewherren Jul 24, 2024
be90ccc
Updated pybind initializer for ForestSamplerCpp
andrewherren Jul 24, 2024
057d905
Merge branch 'main' into sampler_loop
andrewherren Jul 24, 2024
ee05f22
Updated R package docs
andrewherren Jul 24, 2024
1 change: 1 addition & 0 deletions CMakeLists.txt
@@ -94,6 +94,7 @@ set(LIBRARY_OUTPUT_PATH ${PROJECT_SOURCE_DIR}/build)
file(
GLOB
SOURCES
src/bart.cpp
src/container.cpp
src/cutpoint_candidates.cpp
src/data.cpp
4 changes: 4 additions & 0 deletions NAMESPACE
@@ -4,7 +4,11 @@ S3method(getRandomEffectSamples,bartmodel)
S3method(getRandomEffectSamples,bcf)
S3method(predict,bartmodel)
S3method(predict,bcf)
export(average_max_depth_bart_generalized)
export(average_max_depth_bart_specialized)
export(bart)
export(bart_cpp_loop_generalized)
export(bart_cpp_loop_specialized)
export(bcf)
export(computeForestKernels)
export(computeForestLeafIndices)
646 changes: 645 additions & 1 deletion R/bart.R

Large diffs are not rendered by default.

44 changes: 44 additions & 0 deletions R/cpp11.R
@@ -1,5 +1,49 @@
# Generated by cpp11: do not edit by hand

run_bart_cpp_basis_test_rfx <- function(covariates_train, basis_train, outcome_train, num_rows_train, num_covariates_train, num_basis_train, covariates_test, basis_test, num_rows_test, num_covariates_test, num_basis_test, rfx_basis_train, rfx_group_labels_train, num_rfx_basis_train, num_rfx_groups_train, rfx_basis_test, rfx_group_labels_test, num_rfx_basis_test, num_rfx_groups_test, feature_types, variable_weights, num_trees, output_dimension, is_leaf_constant, alpha, beta, a_leaf, b_leaf, nu, lamb, min_samples_leaf, max_depth, cutpoint_grid_size, leaf_cov_init, global_variance_init, num_gfr, num_burnin, num_mcmc, random_seed, leaf_model_int, sample_global_var, sample_leaf_var, rfx_alpha_init, rfx_xi_init, rfx_sigma_alpha_init, rfx_sigma_xi_init, rfx_sigma_xi_shape, rfx_sigma_xi_scale) {
.Call(`_stochtree_run_bart_cpp_basis_test_rfx`, covariates_train, basis_train, outcome_train, num_rows_train, num_covariates_train, num_basis_train, covariates_test, basis_test, num_rows_test, num_covariates_test, num_basis_test, rfx_basis_train, rfx_group_labels_train, num_rfx_basis_train, num_rfx_groups_train, rfx_basis_test, rfx_group_labels_test, num_rfx_basis_test, num_rfx_groups_test, feature_types, variable_weights, num_trees, output_dimension, is_leaf_constant, alpha, beta, a_leaf, b_leaf, nu, lamb, min_samples_leaf, max_depth, cutpoint_grid_size, leaf_cov_init, global_variance_init, num_gfr, num_burnin, num_mcmc, random_seed, leaf_model_int, sample_global_var, sample_leaf_var, rfx_alpha_init, rfx_xi_init, rfx_sigma_alpha_init, rfx_sigma_xi_init, rfx_sigma_xi_shape, rfx_sigma_xi_scale)
}

run_bart_cpp_basis_test_norfx <- function(covariates_train, basis_train, outcome_train, num_rows_train, num_covariates_train, num_basis_train, covariates_test, basis_test, num_rows_test, num_covariates_test, num_basis_test, feature_types, variable_weights, num_trees, output_dimension, is_leaf_constant, alpha, beta, a_leaf, b_leaf, nu, lamb, min_samples_leaf, max_depth, cutpoint_grid_size, leaf_cov_init, global_variance_init, num_gfr, num_burnin, num_mcmc, random_seed, leaf_model_int, sample_global_var, sample_leaf_var) {
.Call(`_stochtree_run_bart_cpp_basis_test_norfx`, covariates_train, basis_train, outcome_train, num_rows_train, num_covariates_train, num_basis_train, covariates_test, basis_test, num_rows_test, num_covariates_test, num_basis_test, feature_types, variable_weights, num_trees, output_dimension, is_leaf_constant, alpha, beta, a_leaf, b_leaf, nu, lamb, min_samples_leaf, max_depth, cutpoint_grid_size, leaf_cov_init, global_variance_init, num_gfr, num_burnin, num_mcmc, random_seed, leaf_model_int, sample_global_var, sample_leaf_var)
}

run_bart_cpp_basis_notest_rfx <- function(covariates_train, basis_train, outcome_train, num_rows_train, num_covariates_train, num_basis_train, rfx_basis_train, rfx_group_labels_train, num_rfx_basis_train, num_rfx_groups_train, feature_types, variable_weights, num_trees, output_dimension, is_leaf_constant, alpha, beta, a_leaf, b_leaf, nu, lamb, min_samples_leaf, max_depth, cutpoint_grid_size, leaf_cov_init, global_variance_init, num_gfr, num_burnin, num_mcmc, random_seed, leaf_model_int, sample_global_var, sample_leaf_var, rfx_alpha_init, rfx_xi_init, rfx_sigma_alpha_init, rfx_sigma_xi_init, rfx_sigma_xi_shape, rfx_sigma_xi_scale) {
.Call(`_stochtree_run_bart_cpp_basis_notest_rfx`, covariates_train, basis_train, outcome_train, num_rows_train, num_covariates_train, num_basis_train, rfx_basis_train, rfx_group_labels_train, num_rfx_basis_train, num_rfx_groups_train, feature_types, variable_weights, num_trees, output_dimension, is_leaf_constant, alpha, beta, a_leaf, b_leaf, nu, lamb, min_samples_leaf, max_depth, cutpoint_grid_size, leaf_cov_init, global_variance_init, num_gfr, num_burnin, num_mcmc, random_seed, leaf_model_int, sample_global_var, sample_leaf_var, rfx_alpha_init, rfx_xi_init, rfx_sigma_alpha_init, rfx_sigma_xi_init, rfx_sigma_xi_shape, rfx_sigma_xi_scale)
}

run_bart_cpp_basis_notest_norfx <- function(covariates_train, basis_train, outcome_train, num_rows_train, num_covariates_train, num_basis_train, feature_types, variable_weights, num_trees, output_dimension, is_leaf_constant, alpha, beta, a_leaf, b_leaf, nu, lamb, min_samples_leaf, max_depth, cutpoint_grid_size, leaf_cov_init, global_variance_init, num_gfr, num_burnin, num_mcmc, random_seed, leaf_model_int, sample_global_var, sample_leaf_var) {
.Call(`_stochtree_run_bart_cpp_basis_notest_norfx`, covariates_train, basis_train, outcome_train, num_rows_train, num_covariates_train, num_basis_train, feature_types, variable_weights, num_trees, output_dimension, is_leaf_constant, alpha, beta, a_leaf, b_leaf, nu, lamb, min_samples_leaf, max_depth, cutpoint_grid_size, leaf_cov_init, global_variance_init, num_gfr, num_burnin, num_mcmc, random_seed, leaf_model_int, sample_global_var, sample_leaf_var)
}

run_bart_cpp_nobasis_test_rfx <- function(covariates_train, outcome_train, num_rows_train, num_covariates_train, covariates_test, num_rows_test, num_covariates_test, rfx_basis_train, rfx_group_labels_train, num_rfx_basis_train, num_rfx_groups_train, rfx_basis_test, rfx_group_labels_test, num_rfx_basis_test, num_rfx_groups_test, feature_types, variable_weights, num_trees, output_dimension, is_leaf_constant, alpha, beta, a_leaf, b_leaf, nu, lamb, min_samples_leaf, max_depth, cutpoint_grid_size, leaf_cov_init, global_variance_init, num_gfr, num_burnin, num_mcmc, random_seed, leaf_model_int, sample_global_var, sample_leaf_var, rfx_alpha_init, rfx_xi_init, rfx_sigma_alpha_init, rfx_sigma_xi_init, rfx_sigma_xi_shape, rfx_sigma_xi_scale) {
.Call(`_stochtree_run_bart_cpp_nobasis_test_rfx`, covariates_train, outcome_train, num_rows_train, num_covariates_train, covariates_test, num_rows_test, num_covariates_test, rfx_basis_train, rfx_group_labels_train, num_rfx_basis_train, num_rfx_groups_train, rfx_basis_test, rfx_group_labels_test, num_rfx_basis_test, num_rfx_groups_test, feature_types, variable_weights, num_trees, output_dimension, is_leaf_constant, alpha, beta, a_leaf, b_leaf, nu, lamb, min_samples_leaf, max_depth, cutpoint_grid_size, leaf_cov_init, global_variance_init, num_gfr, num_burnin, num_mcmc, random_seed, leaf_model_int, sample_global_var, sample_leaf_var, rfx_alpha_init, rfx_xi_init, rfx_sigma_alpha_init, rfx_sigma_xi_init, rfx_sigma_xi_shape, rfx_sigma_xi_scale)
}

run_bart_cpp_nobasis_test_norfx <- function(covariates_train, outcome_train, num_rows_train, num_covariates_train, covariates_test, num_rows_test, num_covariates_test, feature_types, variable_weights, num_trees, output_dimension, is_leaf_constant, alpha, beta, a_leaf, b_leaf, nu, lamb, min_samples_leaf, max_depth, cutpoint_grid_size, leaf_cov_init, global_variance_init, num_gfr, num_burnin, num_mcmc, random_seed, leaf_model_int, sample_global_var, sample_leaf_var) {
.Call(`_stochtree_run_bart_cpp_nobasis_test_norfx`, covariates_train, outcome_train, num_rows_train, num_covariates_train, covariates_test, num_rows_test, num_covariates_test, feature_types, variable_weights, num_trees, output_dimension, is_leaf_constant, alpha, beta, a_leaf, b_leaf, nu, lamb, min_samples_leaf, max_depth, cutpoint_grid_size, leaf_cov_init, global_variance_init, num_gfr, num_burnin, num_mcmc, random_seed, leaf_model_int, sample_global_var, sample_leaf_var)
}

run_bart_cpp_nobasis_notest_rfx <- function(covariates_train, outcome_train, num_rows_train, num_covariates_train, rfx_basis_train, rfx_group_labels_train, num_rfx_basis_train, num_rfx_groups_train, feature_types, variable_weights, num_trees, output_dimension, is_leaf_constant, alpha, beta, a_leaf, b_leaf, nu, lamb, min_samples_leaf, max_depth, cutpoint_grid_size, leaf_cov_init, global_variance_init, num_gfr, num_burnin, num_mcmc, random_seed, leaf_model_int, sample_global_var, sample_leaf_var, rfx_alpha_init, rfx_xi_init, rfx_sigma_alpha_init, rfx_sigma_xi_init, rfx_sigma_xi_shape, rfx_sigma_xi_scale) {
.Call(`_stochtree_run_bart_cpp_nobasis_notest_rfx`, covariates_train, outcome_train, num_rows_train, num_covariates_train, rfx_basis_train, rfx_group_labels_train, num_rfx_basis_train, num_rfx_groups_train, feature_types, variable_weights, num_trees, output_dimension, is_leaf_constant, alpha, beta, a_leaf, b_leaf, nu, lamb, min_samples_leaf, max_depth, cutpoint_grid_size, leaf_cov_init, global_variance_init, num_gfr, num_burnin, num_mcmc, random_seed, leaf_model_int, sample_global_var, sample_leaf_var, rfx_alpha_init, rfx_xi_init, rfx_sigma_alpha_init, rfx_sigma_xi_init, rfx_sigma_xi_shape, rfx_sigma_xi_scale)
}

run_bart_cpp_nobasis_notest_norfx <- function(covariates_train, outcome_train, num_rows_train, num_covariates_train, feature_types, variable_weights, num_trees, output_dimension, is_leaf_constant, alpha, beta, a_leaf, b_leaf, nu, lamb, min_samples_leaf, max_depth, cutpoint_grid_size, leaf_cov_init, global_variance_init, num_gfr, num_burnin, num_mcmc, random_seed, leaf_model_int, sample_global_var, sample_leaf_var) {
.Call(`_stochtree_run_bart_cpp_nobasis_notest_norfx`, covariates_train, outcome_train, num_rows_train, num_covariates_train, feature_types, variable_weights, num_trees, output_dimension, is_leaf_constant, alpha, beta, a_leaf, b_leaf, nu, lamb, min_samples_leaf, max_depth, cutpoint_grid_size, leaf_cov_init, global_variance_init, num_gfr, num_burnin, num_mcmc, random_seed, leaf_model_int, sample_global_var, sample_leaf_var)
}
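
The eight generated wrappers above follow one naming pattern, `run_bart_cpp_{basis|nobasis}_{test|notest}_{rfx|norfx}`, so a caller has to pick one based on three yes/no conditions. A minimal sketch of that selection, assuming a hypothetical helper `entry_point_name` and hypothetical flag names (neither appears in this PR; the real dispatch lives in the R/bart.R diff, which is not rendered here):

```cpp
#include <string>

// Hypothetical helper, for illustration only: builds the name of one of the
// eight run_bart_cpp_* wrappers from three flags. The flag names are
// assumptions; only the resulting wrapper names come from the diff.
std::string entry_point_name(bool has_basis, bool has_test, bool has_rfx) {
  std::string name = "run_bart_cpp_";
  name += has_basis ? "basis" : "nobasis";  // leaf basis supplied?
  name += has_test ? "_test" : "_notest";   // test set supplied?
  name += has_rfx ? "_rfx" : "_norfx";      // random effects supplied?
  return name;
}
```

Under these assumptions, `entry_point_name(true, false, true)` yields `run_bart_cpp_basis_notest_rfx`, which matches one of the wrappers in the diff.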

run_bart_specialized_cpp <- function(covariates, outcome, feature_types, variable_weights, num_rows, num_covariates, num_trees, output_dimension, is_leaf_constant, alpha, beta, min_samples_leaf, cutpoint_grid_size, a_leaf, b_leaf, nu, lamb, leaf_variance_init, global_variance_init, num_gfr, num_burnin, num_mcmc, random_seed, max_depth) {
.Call(`_stochtree_run_bart_specialized_cpp`, covariates, outcome, feature_types, variable_weights, num_rows, num_covariates, num_trees, output_dimension, is_leaf_constant, alpha, beta, min_samples_leaf, cutpoint_grid_size, a_leaf, b_leaf, nu, lamb, leaf_variance_init, global_variance_init, num_gfr, num_burnin, num_mcmc, random_seed, max_depth)
}

average_max_depth_bart_generalized_cpp <- function(bart_result) {
.Call(`_stochtree_average_max_depth_bart_generalized_cpp`, bart_result)
}

average_max_depth_bart_specialized_cpp <- function(bart_result) {
.Call(`_stochtree_average_max_depth_bart_specialized_cpp`, bart_result)
}
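
For readers unfamiliar with the quantity the two `average_max_depth_*` wrappers report, here is a rough sketch of an "average maximum depth" over a set of trees. The `Node` struct, both function names, and the convention that a root-only tree has depth 0 are all assumptions made for illustration; this is not the library's tree class or its depth convention.

```cpp
#include <algorithm>
#include <memory>
#include <vector>

// Hypothetical binary tree node, not the stochtree Tree class.
struct Node {
  std::unique_ptr<Node> left;
  std::unique_ptr<Node> right;
};

// Maximum depth of a tree. Assumed convention: a root-only tree has depth 0,
// so an empty subtree contributes -1.
int max_depth(const Node* node) {
  if (node == nullptr) return -1;
  return 1 + std::max(max_depth(node->left.get()), max_depth(node->right.get()));
}

// Average of the per-tree maximum depths; returns 0.0 for an empty set.
double average_max_depth(const std::vector<std::unique_ptr<Node>>& trees) {
  if (trees.empty()) return 0.0;
  double total = 0.0;
  for (const auto& tree : trees) total += max_depth(tree.get());
  return total / static_cast<double>(trees.size());
}
```

With this convention, a set holding one root-only tree and one tree with a single split averages to 0.5.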

create_forest_dataset_cpp <- function() {
.Call(`_stochtree_create_forest_dataset_cpp`)
}
6 changes: 3 additions & 3 deletions R/forest.R
@@ -262,7 +262,7 @@ ForestSamples <- R6::R6Class(
dim(output) <- c(n_trees, num_features, n_samples)
return(output)
},

#' @description
#' Maximum depth of a specific tree in a specific ensemble in a `ForestContainer` object
#' @param ensemble_num Ensemble number
@@ -271,15 +271,15 @@ ForestSamples <- R6::R6Class(
ensemble_tree_max_depth = function(ensemble_num, tree_num) {
return(ensemble_tree_max_depth_forest_container_cpp(self$forest_container_ptr, ensemble_num, tree_num))
},

#' @description
#' Average the maximum depth of each tree in a given ensemble in a `ForestContainer` object
#' @param ensemble_num Ensemble number
#' @return Average maximum depth
average_ensemble_max_depth = function(ensemble_num) {
return(ensemble_average_max_depth_forest_container_cpp(self$forest_container_ptr, ensemble_num))
},

#' @description
#' Average the maximum depth of each tree in each ensemble in a `ForestContainer` object
#' @return Average maximum depth
3 changes: 2 additions & 1 deletion R/model.R
@@ -50,7 +50,7 @@ ForestModel <- R6::R6Class(
#' @param alpha Root node split probability in tree prior
#' @param beta Depth prior penalty in tree prior
#' @param min_samples_leaf Minimum number of samples in a tree leaf
-#' @param max_depth Maximum depth that any tree can reach
+#' @param max_depth Maximum depth of any tree in an ensemble. Default: `-1`.
#' @return A new `ForestModel` object.
initialize = function(forest_dataset, feature_types, num_trees, n, alpha, beta, min_samples_leaf, max_depth = -1) {
stopifnot(!is.null(forest_dataset$data_ptr))
@@ -116,6 +116,7 @@ createRNG <- function(random_seed = -1){
#' @param alpha Root node split probability in tree prior
#' @param beta Depth prior penalty in tree prior
#' @param min_samples_leaf Minimum number of samples in a tree leaf
#' @param max_depth Maximum depth of any tree in an ensemble
#'
#' @return `ForestModel` object
#' @export
3 changes: 2 additions & 1 deletion README.md
@@ -1,4 +1,5 @@
# StochTree
# StochTree

[![C++ Tests](https://github.com/StochasticTree/stochtree/actions/workflows/cpp-test.yml/badge.svg)](https://github.com/StochasticTree/stochtree/actions/workflows/cpp-test.yml)
[![Python Tests](https://github.com/StochasticTree/stochtree/actions/workflows/python-test.yml/badge.svg)](https://github.com/StochasticTree/stochtree/actions/workflows/python-test.yml)
@@ -8,7 +9,7 @@ Software for building stochastic tree ensembles (i.e. BART, XBART) for supervise

# Getting Started

`stochtree` is composed of a C++ "core" and R / Python interfaces to that core.
`stochtree` is composed of a C++ "core" and R / Python interfaces to that core.
Details on installation and use are available below:

* [Python](#python-package)
4 changes: 4 additions & 0 deletions _pkgdown.yml
@@ -70,6 +70,10 @@ reference:
- createForestKernel
- CppRNG
- createRNG
- average_max_depth_bart_generalized
- average_max_depth_bart_specialized
- bart_cpp_loop_generalized
- bart_cpp_loop_specialized

- subtitle: Random Effects
desc: >