Skip to content

[Experimental] Convert xgboost to rpf objects #55

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 6 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions DESCRIPTION
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,8 @@ Imports:
hardhat,
methods,
Rcpp,
utils
utils,
xgboost
Suggests:
knitr,
mvtnorm,
Expand All @@ -38,4 +39,4 @@ Config/Needs/website: patchwork, ggplot2
Config/testthat/edition: 3
Encoding: UTF-8
Roxygen: list(markdown = TRUE)
RoxygenNote: 7.2.3
RoxygenNote: 7.3.0
3 changes: 2 additions & 1 deletion NAMESPACE
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ S3method(rpf,formula)
S3method(rpf,matrix)
S3method(rpf,recipe)
S3method(str,rpf_forest)
export(convert_xgboost_rpf)
export(is_purified)
export(predict_components)
export(purify)
Expand All @@ -19,7 +20,6 @@ import(checkmate)
importFrom(Rcpp,loadModule)
importFrom(Rcpp,sourceCpp)
importFrom(data.table,":=")
importFrom(data.table,':=')
importFrom(data.table,.BY)
importFrom(data.table,.EACHI)
importFrom(data.table,.GRP)
Expand All @@ -40,4 +40,5 @@ importFrom(methods,new)
importFrom(stats,model.matrix)
importFrom(utils,capture.output)
importFrom(utils,combn)
importFrom(xgboost,xgb.model.dt.tree)
useDynLib(randomPlantedForest, .registration = TRUE)
76 changes: 76 additions & 0 deletions R/convert_xgboost.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@

# Compute the axis-aligned bounds of every leaf of one xgboost tree.
#
# Arguments:
#   trees  Node table (data.table or data.frame) with columns Tree, Node,
#          Feature, Split, Yes, No and Feature_num. Feature_num is the
#          1-based column index in `x` of the split feature; Yes/No are
#          0-based integer child node ids and are NA for leaves.
#   tree   Id of the tree to process, as stored in the Tree column.
#   x      Training data; only ncol(x) is used.
#
# Returns a list with two matrices, `lower` and `upper`. Each has one row per
# leaf (in the row order of `trees`) and one column per column of `x`.
# Features that are never split on keep the bounds -Inf / Inf.
#
# FIXME: Could still be moved to C++ if this remains a bottleneck.
get_leaf_bounds <- function(trees, tree, x) {
  # Pull this tree's columns out once instead of re-filtering the whole
  # table several times per node.
  in_tree <- which(trees$Tree == tree)
  node <- trees$Node[in_tree]
  yes <- trees$Yes[in_tree]
  no <- trees$No[in_tree]
  split <- trees$Split[in_tree]
  splitvar <- trees$Feature_num[in_tree]
  is_leaf <- trees$Feature[in_tree] == "Leaf"

  num_nodes <- max(node) + 1
  lb <- matrix(-Inf, nrow = num_nodes, ncol = ncol(x))
  ub <- matrix(Inf, nrow = num_nodes, ncol = ncol(x))

  # Visit nodes by ascending id so a parent's bounds are final before its
  # children copy them.
  for (k in order(node)) {
    if (is.na(yes[k])) {
      next
    }
    parent <- node[k] + 1
    left_child <- yes[k] + 1
    right_child <- no[k] + 1

    # Children inherit bounds
    ub[left_child, ] <- ub[parent, ]
    ub[right_child, ] <- ub[parent, ]
    lb[left_child, ] <- lb[parent, ]
    lb[right_child, ] <- lb[parent, ]

    # Restrict by new split: the "Yes" child is bounded above by the split
    # value, the "No" child is bounded below by it.
    ub[left_child, splitvar[k]] <- split[k]
    lb[right_child, splitvar[k]] <- split[k]
  }

  # Return bounds of leaves only. drop = FALSE keeps the result a matrix even
  # for a single leaf or a single feature, so callers can always index [i, ].
  leaves <- node[is_leaf] + 1
  list(
    lower = lb[leaves, , drop = FALSE],
    upper = ub[leaves, , drop = FALSE]
  )
}

#' Convert xgboost to rpf object
#'
#' Fits a placeholder `rpf` model with one tree per xgboost tree and then
#' overwrites the variables, leaf values and leaf intervals of every tree
#' with those read from the xgboost model.
#'
#' @param xg xgboost object
#' @param x data used to train the xgboost model. Its column names must
#'   contain every feature name used by `xg`; otherwise an error is raised.
#' @param y target used to train the xgboost model
#'
#' @return rpf object
#' @importFrom xgboost xgb.model.dt.tree
#' @export
convert_xgboost_rpf <- function(xg, x, y) {
  trees <- xgboost::xgb.model.dt.tree(model = xg, use_int_id = TRUE)

  # Map feature names to 1-based column indices of x ("Leaf" becomes 0).
  trees[, Feature_num := as.integer(factor(Feature, levels = c("Leaf", colnames(x)))) - 1L]

  # A feature that is not a column of x maps to NA, which would silently
  # yield wrong split bounds further down. Fail loudly instead.
  if (anyNA(trees$Feature_num)) {
    unknown <- unique(trees$Feature[is.na(trees$Feature_num)])
    stop(
      "Feature(s) used by the xgboost model not found in colnames(x): ",
      paste(unknown, collapse = ", "),
      call. = FALSE
    )
  }

  num_trees <- trees[, max(Tree) + 1]

  # create a dummy rpf
  rpfit <- rpf(
    x = x, y = y, max_interaction = 0, ntrees = num_trees, splits = 1,
    purify = FALSE
  )

  # Overwrite rpf trees
  for (tree_idx in seq_len(num_trees)) {
    tree_id <- tree_idx - 1 # xgboost tree ids are 0-based
    leaf_rows <- trees[Tree == tree_id & Feature == "Leaf"]
    n_leaves <- nrow(leaf_rows)

    # xgboost adds 0.5 to prediction
    # NOTE(review): presumably xgboost's default base_score - confirm for
    # models trained with a different base_score.
    rpfit$forest[[tree_idx]]$values[[1]][[1]] <- 0.5
    rpfit$forest[[tree_idx]]$variables[[2]] <-
      trees[Tree == tree_id & Feature_num > 0, sort(unique(Feature_num))]
    # Leaf values are scaled by the number of trees.
    # NOTE(review): presumably because rpf averages trees while xgboost sums
    # them - confirm.
    rpfit$forest[[tree_idx]]$values[[2]] <-
      as.list(as.numeric(num_trees) * leaf_rows$Quality)

    # Get leaf bounds. Re-shape to n_leaves x p so that row indexing also
    # works if the helper returned a dimension-dropped vector (single leaf or
    # single feature).
    leaf_bounds <- get_leaf_bounds(trees, tree_id, x)
    lower <- matrix(leaf_bounds$lower, nrow = n_leaves)
    upper <- matrix(leaf_bounds$upper, nrow = n_leaves)

    # Start every leaf from the root interval and shrink it to the leaf's
    # bounds: row 1 holds lower limits, row 2 holds upper limits.
    leaf_intervals <- rep(rpfit$forest[[tree_idx]]$intervals[[1]], n_leaves)
    for (i in seq_len(n_leaves)) {
      leaf_intervals[[i]][1, ] <- pmax(leaf_intervals[[i]][1, ], lower[i, ])
      leaf_intervals[[i]][2, ] <- pmin(leaf_intervals[[i]][2, ], upper[i, ])
    }
    rpfit$forest[[tree_idx]]$intervals[[2]] <- leaf_intervals
  }

  # Also overwrite C++ forest
  rpfit$fit$set_model(rpfit$forest)

  # Return manipulated rpf object
  rpfit
}
2 changes: 1 addition & 1 deletion R/utils.R
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@ pca_order <- function(x, y) {
# Sort factor predictors by outcome and re-encode as integer
# save original factor levels for prediction step
# Used in rpf_bridge()
#' @importFrom data.table .SD ':=' as.data.table
#' @importFrom data.table .SD as.data.table
preprocess_predictors_fit <- function(processed) {
predictors <- as.data.table(processed$predictors)

Expand Down
21 changes: 21 additions & 0 deletions man/convert_xgboost_rpf.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions randomPlantedForest.Rproj
Original file line number Diff line number Diff line change
Expand Up @@ -18,5 +18,6 @@ LineEndingConversion: Posix

BuildType: Package
PackageUseDevtools: Yes
PackageCleanBeforeInstall: No
PackageInstallArgs: --no-multiarch --with-keep.source
PackageRoxygenize: rd,collate,namespace
3 changes: 2 additions & 1 deletion src/include/rpf.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ class RandomPlantedForest
void get_parameters();
void set_parameters(StringVector keys, NumericVector values);
List get_model();
void set_model(List& model);
virtual ~RandomPlantedForest(){};
bool is_purified();

Expand Down Expand Up @@ -57,4 +58,4 @@ class RandomPlantedForest
std::multimap<int, std::shared_ptr<DecisionTree>> &possible_splits, TreeFamily &curr_family);
};

#endif // RPF_HPP
#endif // RPF_HPP
51 changes: 51 additions & 0 deletions src/lib/rpf.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2088,3 +2088,54 @@ List RandomPlantedForest::get_model()
}
return (model);
}

/**
 * Replace the trees of the fitted forest with those described by `model`.
 *
 * `model` is a list with one entry per tree family. Each entry is a list with
 * the elements "variables", "values" and "intervals", each holding one item
 * per tree:
 *   - variables[j]: integer vector of the variables tree j depends on,
 *   - values[j]:    list of numeric vectors, one per leaf,
 *   - intervals[j]: list of numeric matrices, one per leaf, where row 0 holds
 *                   the lower and row 1 the upper bound of each feature.
 *
 * Raises an R error (Rcpp::stop) when the sizes in `model` are inconsistent
 * with each other or with the fitted forest, instead of reading or writing
 * out of bounds.
 *
 * FIXME: Element types/names of `model` are still not validated beyond what
 * the Rcpp conversions enforce.
 */
void RandomPlantedForest::set_model(List& model) {

  unsigned int n_families = model.size();

  // tree_families is indexed by i below; refuse models that would overrun it.
  if (n_families > tree_families.size()) {
    Rcpp::stop("set_model(): model contains more tree families than the fitted forest.");
  }

  for (unsigned int i = 0; i < n_families; ++i) {
    List family = model[i];
    List variables = family["variables"];
    List values = family["values"];
    List intervals = family["intervals"];
    unsigned int n_trees = variables.size();

    // Every tree needs both a values and an intervals entry.
    if (static_cast<unsigned int>(values.size()) < n_trees ||
        static_cast<unsigned int>(intervals.size()) < n_trees) {
      Rcpp::stop("set_model(): 'values' and 'intervals' must have one entry per tree in 'variables'.");
    }

    TreeFamily& old_family = tree_families[i];

    old_family.clear();

    for (unsigned int j = 0; j < n_trees; ++j) {

      // Variables
      IntegerVector tree_variables = variables[j];
      std::set<int> tree_variables_set(tree_variables.begin(), tree_variables.end());

      // Values
      std::vector<std::vector<double>> tree_values = values[j];
      unsigned int n_leaves = tree_values.size();

      // Intervals
      std::vector<NumericMatrix> tree_intervals = intervals[j];
      if (tree_intervals.size() < n_leaves) {
        Rcpp::stop("set_model(): fewer interval matrices than leaf values in a tree.");
      }

      std::vector<Leaf> leaves;
      leaves.reserve(n_leaves);
      for (unsigned int k = 0; k < n_leaves; ++k) {

        Leaf temp;
        temp.value = tree_values[k];

        NumericMatrix leaf_intervals = tree_intervals[k];
        // Rows 0 and 1 are read for each of the feature_size columns below.
        if (leaf_intervals.nrow() < 2 || leaf_intervals.ncol() != feature_size) {
          Rcpp::stop("set_model(): each interval matrix must have 2 rows and one column per feature.");
        }

        // Named leaf_bounds to avoid shadowing the `intervals` list above.
        std::vector<Interval> leaf_bounds(leaf_intervals.ncol());
        for (int l = 0; l < feature_size; ++l) {
          leaf_bounds[l] = Interval{leaf_intervals(0, l), leaf_intervals(1, l)};
        }

        temp.intervals = leaf_bounds;
        leaves.push_back(temp);
      }

      old_family.insert(std::make_pair(tree_variables_set, std::make_shared<DecisionTree>(DecisionTree(tree_variables_set, leaves))));

    }
  }
}
1 change: 1 addition & 0 deletions src/randomPlantedForest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ RCPP_MODULE(mod_rpf)
.method("get_parameters", &RandomPlantedForest::get_parameters)
.method("set_parameters", &RandomPlantedForest::set_parameters)
.method("get_model", &RandomPlantedForest::get_model)
.method("set_model", &RandomPlantedForest::set_model)
.method("is_purified", &RandomPlantedForest::is_purified);

class_<ClassificationRPF>("ClassificationRPF")
Expand Down