diff --git a/DESCRIPTION b/DESCRIPTION
index f2f5774..e681ae4 100644
--- a/DESCRIPTION
+++ b/DESCRIPTION
@@ -22,7 +22,8 @@ Imports:
     hardhat,
     methods,
     Rcpp,
-    utils
+    utils,
+    xgboost
 Suggests:
     knitr,
     mvtnorm,
@@ -38,4 +39,4 @@ Config/Needs/website: patchwork, ggplot2
 Config/testthat/edition: 3
 Encoding: UTF-8
 Roxygen: list(markdown = TRUE)
-RoxygenNote: 7.2.3
+RoxygenNote: 7.3.0
diff --git a/NAMESPACE b/NAMESPACE
index b14b57c..80d091b 100644
--- a/NAMESPACE
+++ b/NAMESPACE
@@ -11,6 +11,7 @@ S3method(rpf,formula)
 S3method(rpf,matrix)
 S3method(rpf,recipe)
 S3method(str,rpf_forest)
+export(convert_xgboost_rpf)
 export(is_purified)
 export(predict_components)
 export(purify)
@@ -19,7 +20,6 @@ import(checkmate)
 importFrom(Rcpp,loadModule)
 importFrom(Rcpp,sourceCpp)
 importFrom(data.table,":=")
-importFrom(data.table,':=')
 importFrom(data.table,.BY)
 importFrom(data.table,.EACHI)
 importFrom(data.table,.GRP)
@@ -40,4 +40,5 @@ importFrom(methods,new)
 importFrom(stats,model.matrix)
 importFrom(utils,capture.output)
 importFrom(utils,combn)
+importFrom(xgboost,xgb.model.dt.tree)
 useDynLib(randomPlantedForest, .registration = TRUE)
diff --git a/R/convert_xgboost.R b/R/convert_xgboost.R
new file mode 100644
index 0000000..97b1dbd
--- /dev/null
+++ b/R/convert_xgboost.R
@@ -0,0 +1,76 @@
+
+# Function to get leaf bounds from xgboost
+# FIXME: This seems to be the bottleneck, could do it in C++
+get_leaf_bounds <- function(trees, tree, x) {
+  max_node <- trees[Tree == tree, max(Node)]
+  num_nodes <- max_node + 1
+  lb <- matrix(-Inf, nrow = num_nodes, ncol = ncol(x))
+  ub <- matrix(Inf, nrow = num_nodes, ncol = ncol(x))
+  for (nn in 0:max_node) {
+    if (trees[Tree == tree & Node == nn, !is.na(Yes)]) {
+      left_child <- trees[Tree == tree & Node == nn, Yes]
+      right_child <- trees[Tree == tree & Node == nn, No]
+      splitvar <- trees[Tree == tree & Node == nn, Feature_num]
+
+      # Children inherit bounds
+      ub[left_child + 1, ] <- ub[nn + 1, ]
+      ub[right_child + 1, ] <- ub[nn + 1, ]
+      lb[left_child + 1, ] <- lb[nn + 1, ]
+      lb[right_child + 1, ] <- lb[nn + 1, ]
+
+      # Restrict by new split
+      ub[left_child + 1, splitvar] <- trees[Tree == tree & Node == nn, Split]
+      lb[right_child + 1, splitvar] <- trees[Tree == tree & Node == nn, Split]
+    }
+  }
+
+  # Return bounds of leaves only
+  leaves <- trees[Tree == tree & Feature == "Leaf", Node + 1]
+  list(lower = lb[leaves, ],
+       upper = ub[leaves, ])
+}
+
+#' Convert xgboost to rpf object
+#'
+#' @param xg xgboost object
+#' @param x data used to train the xgboost model
+#' @param y target used to train the xgboost model
+#'
+#' @return rpf object
+#' @importFrom xgboost xgb.model.dt.tree
+#' @export
+convert_xgboost_rpf <- function(xg, x, y) {
+  trees <- xgboost::xgb.model.dt.tree(model = xg, use_int_id = TRUE)
+  trees[, Feature_num := as.integer(factor(Feature, levels = c("Leaf", colnames(x)))) - 1L]
+  num_trees <- trees[, max(Tree) + 1]
+
+  # create a dummy rpf
+  rpfit <- rpf(x = x, y = y, max_interaction = 0, ntrees = num_trees, splits = 1,
+               purify = FALSE)
+
+  # Overwrite rpf trees
+  for (t in seq_len(num_trees)) {
+    # xgboost adds 0.5 to prediction
+    rpfit$forest[[t]]$values[[1]][[1]] <- 0.5
+    rpfit$forest[[t]]$variables[[2]] <- trees[Tree == t - 1 & Feature_num > 0, sort(unique(Feature_num))]
+    rpfit$forest[[t]]$values[[2]] <- as.list(as.numeric(num_trees) * trees[Tree == t - 1 & Feature == "Leaf", Quality])
+
+    rpfit$forest[[t]]$intervals[[2]] <- rep(rpfit$forest[[t]]$intervals[[1]], length(rpfit$forest[[t]]$values[[2]]))
+
+    # Get leaf bounds
+    leaf_bounds <- get_leaf_bounds(trees, t - 1, x)
+    leaves <- trees[Tree == t - 1 & Feature == "Leaf", Node + 1]
+    for (i in seq_along(leaves)) {
+      rpfit$forest[[t]]$intervals[[2]][[i]][1, ] <- pmax(rpfit$forest[[t]]$intervals[[2]][[i]][1, ],
+                                                         leaf_bounds$lower[i, ])
+      rpfit$forest[[t]]$intervals[[2]][[i]][2, ] <- pmin(rpfit$forest[[t]]$intervals[[2]][[i]][2, ],
+                                                         leaf_bounds$upper[i, ])
+    }
+  }
+
+  # Also overwrite C++ forest
+  rpfit$fit$set_model(rpfit$forest)
+
+  # Return manipulated rpf object
+  rpfit
+}
diff --git a/R/utils.R b/R/utils.R
index c04fc14..20fbaf5 100644
--- a/R/utils.R
+++ b/R/utils.R
@@ -76,7 +76,7 @@ pca_order <- function(x, y) {
 # Sort factor predictors by outcome and re-encode as integer
 # save original factor levels for prediction step
 # Used in rpf_bridge()
-#' @importFrom data.table .SD ':=' as.data.table
+#' @importFrom data.table .SD as.data.table
 preprocess_predictors_fit <- function(processed) {
   predictors <- as.data.table(processed$predictors)
 
diff --git a/man/convert_xgboost_rpf.Rd b/man/convert_xgboost_rpf.Rd
new file mode 100644
index 0000000..4dae5a8
--- /dev/null
+++ b/man/convert_xgboost_rpf.Rd
@@ -0,0 +1,21 @@
+% Generated by roxygen2: do not edit by hand
+% Please edit documentation in R/convert_xgboost.R
+\name{convert_xgboost_rpf}
+\alias{convert_xgboost_rpf}
+\title{Convert xgboost to rpf object}
+\usage{
+convert_xgboost_rpf(xg, x, y)
+}
+\arguments{
+\item{xg}{xgboost object}
+
+\item{x}{data used to train the xgboost model}
+
+\item{y}{target used to train the xgboost model}
+}
+\value{
+rpf object
+}
+\description{
+Convert xgboost to rpf object
+}
diff --git a/randomPlantedForest.Rproj b/randomPlantedForest.Rproj
index ef6fa31..e353823 100644
--- a/randomPlantedForest.Rproj
+++ b/randomPlantedForest.Rproj
@@ -18,5 +18,6 @@ LineEndingConversion: Posix
 
 BuildType: Package
 PackageUseDevtools: Yes
+PackageCleanBeforeInstall: No
 PackageInstallArgs: --no-multiarch --with-keep.source
 PackageRoxygenize: rd,collate,namespace
diff --git a/src/include/rpf.hpp b/src/include/rpf.hpp
index 53e8d13..d24b105 100644
--- a/src/include/rpf.hpp
+++ b/src/include/rpf.hpp
@@ -24,6 +24,7 @@ class RandomPlantedForest
   void get_parameters();
   void set_parameters(StringVector keys, NumericVector values);
   List get_model();
+  void set_model(List& model);
   virtual ~RandomPlantedForest(){};
   bool is_purified();
 
@@ -57,4 +58,4 @@ class RandomPlantedForest
                     std::multimap<int, std::shared_ptr<DecisionTree>> &possible_splits, TreeFamily &curr_family);
 };
 
-#endif // RPF_HPP
\ No newline at end of file
+#endif // RPF_HPP
diff --git a/src/lib/rpf.cpp b/src/lib/rpf.cpp
index ef0a5e1..277cdd1 100644
--- a/src/lib/rpf.cpp
+++ b/src/lib/rpf.cpp
@@ -2088,3 +2088,54 @@ List RandomPlantedForest::get_model()
   }
   return (model);
 }
+
+// FIXME: Add checks for format of model argument
+void RandomPlantedForest::set_model(List& model) {
+
+  unsigned int n_families = model.size();
+
+  for (unsigned int i = 0; i < n_families; ++i) {
+    List family = model[i];
+    List variables = family["variables"];
+    List values = family["values"];
+    List intervals = family["intervals"];
+    unsigned int n_trees = variables.size();
+
+    TreeFamily& old_family = tree_families[i];
+
+    old_family.clear();
+
+    for (unsigned int j = 0; j < n_trees; ++j) {
+
+      // Variables
+      IntegerVector tree_variables = variables[j];
+      std::set<int> tree_variables_set(tree_variables.begin(), tree_variables.end());
+
+      // Values
+      std::vector<std::vector<double>> tree_values = values[j];
+      unsigned int n_leaves = tree_values.size();
+
+      // Intervals
+      std::vector<NumericMatrix> tree_intervals = intervals[j];
+
+      std::vector<Leaf> leaves;
+      for (unsigned int k = 0; k < n_leaves; ++k) {
+
+        Leaf temp;
+        temp.value = tree_values[k];
+
+        NumericMatrix leaf_intervals = tree_intervals[k];
+        std::vector<Interval> intervals(leaf_intervals.ncol());
+        for (int l = 0; l < feature_size; ++l) {
+          intervals[l] = Interval{leaf_intervals(0, l), leaf_intervals(1, l)};
+        }
+
+        temp.intervals = intervals;
+        leaves.push_back(temp);
+      }
+
+      old_family.insert(std::make_pair(tree_variables_set, std::make_shared<DecisionTree>(DecisionTree(tree_variables_set, leaves))));
+
+    }
+  }
+}
diff --git a/src/randomPlantedForest.cpp b/src/randomPlantedForest.cpp
index 9fd3630..106fa61 100644
--- a/src/randomPlantedForest.cpp
+++ b/src/randomPlantedForest.cpp
@@ -19,6 +19,7 @@ RCPP_MODULE(mod_rpf)
     .method("get_parameters", &RandomPlantedForest::get_parameters)
    .method("set_parameters", &RandomPlantedForest::set_parameters)
    .method("get_model", &RandomPlantedForest::get_model)
+    .method("set_model", &RandomPlantedForest::set_model)
    .method("is_purified", &RandomPlantedForest::is_purified);

 class_<ClassificationRPF>("ClassificationRPF")
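
Usage sketch (not part of the diff above): convert_xgboost_rpf() rebuilds an xgboost booster as an rpf object so that rpf post-processing such as purify() and predict_components() can be applied to it. The example below is illustrative only. It assumes a regression booster trained with xgboost's default base_score of 0.5 (the value the converter hardcodes as the constant component), and it assumes the hardhat-style predict(object, new_data = ...) interface of rpf objects; data and tuning values are arbitrary.

library(xgboost)
library(randomPlantedForest)

x <- mtcars[, c("cyl", "hp", "wt")]
y <- mtcars$mpg

# Small regression booster (objective reg:squarederror, base_score = 0.5 by default)
xg <- xgboost(data = as.matrix(x), label = y,
              nrounds = 10, max_depth = 3, eta = 0.3, verbose = 0)

# Convert to an rpf object; predictions are expected to match the booster
rpfit <- convert_xgboost_rpf(xg, x = x, y = y)
head(predict(xg, as.matrix(x)))
head(predict(rpfit, new_data = x))

# The converted forest can then be purified and decomposed into main and
# interaction effects like any other rpf fit (see the package documentation
# for the exact return values)
purify(rpfit)
predict_components(rpfit, new_data = x)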