Skip to content

Commit 442affb

Browse files
committed
Merge remote-tracking branch 'origin/xgboost_to_rpf' into interpolation-fix
2 parents 380ece5 + 1dcda25 commit 442affb

9 files changed

+158
-5
lines changed

DESCRIPTION

+3-2
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,8 @@ Imports:
2222
hardhat,
2323
methods,
2424
Rcpp,
25-
utils
25+
utils,
26+
xgboost
2627
Suggests:
2728
knitr,
2829
mvtnorm,
@@ -38,4 +39,4 @@ Config/Needs/website: patchwork, ggplot2
3839
Config/testthat/edition: 3
3940
Encoding: UTF-8
4041
Roxygen: list(markdown = TRUE)
41-
RoxygenNote: 7.2.3
42+
RoxygenNote: 7.3.0

NAMESPACE

+2-1
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ S3method(rpf,formula)
1111
S3method(rpf,matrix)
1212
S3method(rpf,recipe)
1313
S3method(str,rpf_forest)
14+
export(convert_xgboost_rpf)
1415
export(is_purified)
1516
export(predict_components)
1617
export(purify)
@@ -19,7 +20,6 @@ import(checkmate)
1920
importFrom(Rcpp,loadModule)
2021
importFrom(Rcpp,sourceCpp)
2122
importFrom(data.table,":=")
22-
importFrom(data.table,':=')
2323
importFrom(data.table,.BY)
2424
importFrom(data.table,.EACHI)
2525
importFrom(data.table,.GRP)
@@ -40,4 +40,5 @@ importFrom(methods,new)
4040
importFrom(stats,model.matrix)
4141
importFrom(utils,capture.output)
4242
importFrom(utils,combn)
43+
importFrom(xgboost,xgb.model.dt.tree)
4344
useDynLib(randomPlantedForest, .registration = TRUE)

R/convert_xgboost.R

+76
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,76 @@
1+
2+
# Compute the axis-aligned bounds of every leaf of one xgboost tree.
#
# trees: tree table with columns Tree, Node, Feature, Split, Yes, No and
#        Feature_num (integer column index into `x` for split nodes). Yes/No
#        must hold integer node ids, with NA on leaf rows.
# tree:  id of the tree to process (value of the `Tree` column).
# x:     training data; only `ncol(x)` is used.
#
# Returns list(lower, upper). Both are always matrices with one row per leaf
# (in the order the leaves appear in `trees`) and one column per column of `x`.
#
# The needed columns are extracted once instead of re-filtering the table for
# every node, which was the bottleneck of the previous implementation.
get_leaf_bounds <- function(trees, tree, x) {
  in_tree <- which(trees$Tree == tree)
  if (length(in_tree) == 0L) {
    stop("Tree ", tree, " not found in `trees`.")
  }

  node <- trees$Node[in_tree]
  yes <- trees$Yes[in_tree]
  no <- trees$No[in_tree]
  split <- trees$Split[in_tree]
  splitvar <- trees$Feature_num[in_tree]
  is_leaf <- trees$Feature[in_tree] == "Leaf"

  num_nodes <- max(node) + 1
  lb <- matrix(-Inf, nrow = num_nodes, ncol = ncol(x))
  ub <- matrix(Inf, nrow = num_nodes, ncol = ncol(x))

  # Visit split nodes by increasing node id. Like the original 0:max_node loop,
  # this relies on every parent having a smaller id than its children, so the
  # parent's bounds are final before the children copy them.
  internal <- which(!is.na(yes))
  internal <- internal[order(node[internal])]

  for (k in internal) {
    parent <- node[k] + 1
    left_child <- yes[k] + 1
    right_child <- no[k] + 1

    # Children inherit bounds
    ub[left_child, ] <- ub[parent, ]
    ub[right_child, ] <- ub[parent, ]
    lb[left_child, ] <- lb[parent, ]
    lb[right_child, ] <- lb[parent, ]

    # Restrict by new split: the "Yes" child gets the split as upper bound,
    # the "No" child gets it as lower bound.
    ub[left_child, splitvar[k]] <- split[k]
    lb[right_child, splitvar[k]] <- split[k]
  }

  # Return bounds of leaves only. drop = FALSE keeps the matrix shape even for
  # a single leaf or a single feature, so callers can always index [i, ].
  leaves <- node[is_leaf] + 1
  list(lower = lb[leaves, , drop = FALSE],
       upper = ub[leaves, , drop = FALSE])
}
32+
33+
#' Convert xgboost to rpf object
34+
#'
35+
#' @param xg xgboost object
36+
#' @param x data used to train the xgboost model
37+
#' @param y target used to train the xgboost model
38+
#'
39+
#' @return rpf object
40+
#' @importFrom xgboost xgb.model.dt.tree
41+
#' @export
42+
convert_xgboost_rpf <- function(xg, x, y) {
43+
trees <- xgboost::xgb.model.dt.tree(model = xg, use_int_id = TRUE)
44+
trees[, Feature_num := as.integer(factor(Feature, levels = c("Leaf", colnames(x)))) - 1L]
45+
num_trees <- trees[, max(Tree)+1]
46+
47+
# create a dummy rpf
48+
rpfit <- rpf(x = x, y = y, max_interaction = 0, ntrees = num_trees, splits = 1,
49+
purify = FALSE)
50+
51+
# Overwrite rpf trees
52+
for (t in seq_len(num_trees)) {
53+
# xgboost adds 0.5 to prediction
54+
rpfit$forest[[t]]$values[[1]][[1]] <- 0.5
55+
rpfit$forest[[t]]$variables[[2]] <- trees[Tree == t-1 & Feature_num > 0, sort(unique(Feature_num))]
56+
rpfit$forest[[t]]$values[[2]] <- as.list(as.numeric(num_trees)*trees[Tree == t-1 & Feature == "Leaf", Quality])
57+
58+
rpfit$forest[[t]]$intervals[[2]] <- rep(rpfit$forest[[t]]$intervals[[1]], length(rpfit$forest[[t]]$values[[2]]))
59+
60+
# Get leaf bounds
61+
leaf_bounds <- get_leaf_bounds(trees, t-1,x)
62+
leaves <- trees[Tree == t-1 & Feature == "Leaf", Node+1]
63+
for (i in seq_along(leaves)) {
64+
rpfit$forest[[t]]$intervals[[2]][[i]][1, ] <- pmax(rpfit$forest[[t]]$intervals[[2]][[i]][1, ],
65+
leaf_bounds$lower[i, ])
66+
rpfit$forest[[t]]$intervals[[2]][[i]][2, ] <- pmin(rpfit$forest[[t]]$intervals[[2]][[i]][2, ],
67+
leaf_bounds$upper[i, ])
68+
}
69+
}
70+
71+
# Also overwrite C++ forest
72+
rpfit$fit$set_model(rpfit$forest)
73+
74+
# Return manipulated rpf object
75+
rpfit
76+
}

R/utils.R

+1-1
Original file line numberDiff line numberDiff line change
@@ -76,7 +76,7 @@ pca_order <- function(x, y) {
7676
# Sort factor predictors by outcome and re-encode as integer
7777
# save original factor levels for prediction step
7878
# Used in rpf_bridge()
79-
#' @importFrom data.table .SD ':=' as.data.table
79+
#' @importFrom data.table .SD as.data.table
8080
preprocess_predictors_fit <- function(processed) {
8181
predictors <- as.data.table(processed$predictors)
8282

man/convert_xgboost_rpf.Rd

+21
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

randomPlantedForest.Rproj

+1
Original file line numberDiff line numberDiff line change
@@ -18,5 +18,6 @@ LineEndingConversion: Posix
1818

1919
BuildType: Package
2020
PackageUseDevtools: Yes
21+
PackageCleanBeforeInstall: No
2122
PackageInstallArgs: --no-multiarch --with-keep.source
2223
PackageRoxygenize: rd,collate,namespace

src/include/rpf.hpp

+2-1
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@ class RandomPlantedForest
2525
void get_parameters();
2626
void set_parameters(StringVector keys, NumericVector values);
2727
List get_model();
28+
void set_model(List& model);
2829
virtual ~RandomPlantedForest(){};
2930
bool is_purified();
3031

@@ -60,4 +61,4 @@ class RandomPlantedForest
6061
std::vector<std::vector<double>> get_lim_list(const TreeFamily &curr_family);
6162
};
6263

63-
#endif // RPF_HPP
64+
#endif // RPF_HPP

src/lib/rpf.cpp

+51
Original file line numberDiff line numberDiff line change
@@ -1036,3 +1036,54 @@ List RandomPlantedForest::get_model()
10361036
}
10371037
return (model);
10381038
}
1039+
1040+
// FIXME: Add checks for format of model argument
1041+
void RandomPlantedForest::set_model(List& model) {
1042+
1043+
unsigned int n_families = model.size();
1044+
1045+
for (unsigned int i = 0; i < n_families; ++i) {
1046+
List family = model[i];
1047+
List variables = family["variables"];
1048+
List values = family["values"];
1049+
List intervals = family["intervals"];
1050+
unsigned int n_trees = variables.size();
1051+
1052+
TreeFamily& old_family = tree_families[i];
1053+
1054+
old_family.clear();
1055+
1056+
for (unsigned int j = 0; j < n_trees; ++j) {
1057+
1058+
// Variables
1059+
IntegerVector tree_variables = variables[j];
1060+
std::set<int> tree_variables_set(tree_variables.begin(), tree_variables.end());
1061+
1062+
// Values
1063+
std::vector<std::vector<double>> tree_values = values[j];
1064+
unsigned int n_leaves = tree_values.size();
1065+
1066+
// Intervals
1067+
std::vector<NumericMatrix> tree_intervals = intervals[j];
1068+
1069+
std::vector<Leaf> leaves;
1070+
for (unsigned int k = 0; k < n_leaves; ++k) {
1071+
1072+
Leaf temp;
1073+
temp.value = tree_values[k];
1074+
1075+
NumericMatrix leaf_intervals = tree_intervals[k];
1076+
std::vector<Interval> intervals(leaf_intervals.ncol());
1077+
for (int l = 0; l < feature_size; ++l) {
1078+
intervals[l] = Interval{leaf_intervals(0, l), leaf_intervals(1, l)};
1079+
}
1080+
1081+
temp.intervals = intervals;
1082+
leaves.push_back(temp);
1083+
}
1084+
1085+
old_family.insert(std::make_pair(tree_variables_set, std::make_shared<DecisionTree>(DecisionTree(tree_variables_set, leaves))));
1086+
1087+
}
1088+
}
1089+
}

src/randomPlantedForest.cpp

+1
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@ RCPP_MODULE(mod_rpf)
2020
.method("get_parameters", &RandomPlantedForest::get_parameters)
2121
.method("set_parameters", &RandomPlantedForest::set_parameters)
2222
.method("get_model", &RandomPlantedForest::get_model)
23+
.method("set_model", &RandomPlantedForest::set_model)
2324
.method("is_purified", &RandomPlantedForest::is_purified);
2425

2526
class_<ClassificationRPF>("ClassificationRPF")

0 commit comments

Comments
 (0)