Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
56 changes: 39 additions & 17 deletions R/rpf.R
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,9 @@
#' @param epsilon `[0.1]`: Only used if loss = `"logit"` or `"exponential"`.
#' Proportion of class membership is truncated to be smaller than `1 - epsilon` when calculating
#' the fit in a leaf.
#' @param split_decay_rate `[0.1]`: Exponential decay rate `lambda` used to age split candidates. A candidate's weight is `exp(-lambda * age)`.
#' @param max_candidates `[50]`: Maximum number of split candidates to sample at each node (clamped to the range from 1 to the number of possible splits).
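#' @param delete_leaves `[1]`: Whether parents should be deleted if the split is along an existing coordinate.
#' The value is validated with `checkmate::assert_flag()`, so pass a logical `TRUE` or `FALSE`.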
#' @param ... (Unused).
#'
#' @return Object of class `"rpf"` with model object contained in `$fit`.
Expand Down Expand Up @@ -63,14 +66,16 @@ rpf.default <- function(x, ...) {
#' @export
#' @rdname rpf
rpf.data.frame <- function(x, y, max_interaction = 1, ntrees = 50, splits = 30,
split_try = 10, t_try = 0.4, deterministic = FALSE,
split_try = 10, t_try = 0.4, split_decay_rate = 0.1,
max_candidates = 50, delete_leaves = 1,
deterministic = FALSE,
nthreads = 1, purify = FALSE, cv = FALSE,
loss = "L2", delta = 0, epsilon = 0.1, ...) {
blueprint <- hardhat::default_xy_blueprint(intercept = FALSE)
processed <- hardhat::mold(x, y, blueprint = blueprint)
rpf_bridge(
processed, max_interaction, ntrees, splits,
split_try, t_try, deterministic,
split_try, t_try, split_decay_rate, max_candidates, delete_leaves, deterministic,
nthreads, purify, cv,
loss, delta, epsilon
)
Expand All @@ -80,14 +85,16 @@ rpf.data.frame <- function(x, y, max_interaction = 1, ntrees = 50, splits = 30,
#' @export
#' @rdname rpf
rpf.matrix <- function(x, y, max_interaction = 1, ntrees = 50, splits = 30,
split_try = 10, t_try = 0.4, deterministic = FALSE,
split_try = 10, t_try = 0.4, split_decay_rate = 0.1,
max_candidates = 50, delete_leaves = 1,
deterministic = FALSE,
nthreads = 1, purify = FALSE, cv = FALSE,
loss = "L2", delta = 0, epsilon = 0.1, ...) {
blueprint <- hardhat::default_xy_blueprint(intercept = FALSE)
processed <- hardhat::mold(x, y, blueprint = blueprint)
rpf_bridge(
processed, max_interaction, ntrees, splits,
split_try, t_try, deterministic,
split_try, t_try, split_decay_rate, max_candidates, delete_leaves, deterministic,
nthreads, purify, cv,
loss, delta, epsilon
)}
Expand All @@ -96,14 +103,16 @@ rpf.matrix <- function(x, y, max_interaction = 1, ntrees = 50, splits = 30,
#' @export
#' @rdname rpf
rpf.formula <- function(formula, data, max_interaction = 1, ntrees = 50, splits = 30,
split_try = 10, t_try = 0.4, deterministic = FALSE,
split_try = 10, t_try = 0.4, split_decay_rate = 0.1,
max_candidates = 50, delete_leaves = 1,
deterministic = FALSE,
nthreads = 1, purify = FALSE, cv = FALSE,
loss = "L2", delta = 0, epsilon = 0.1, ...) {
blueprint <- hardhat::default_formula_blueprint(intercept = FALSE, indicators = "none")
processed <- hardhat::mold(formula, data, blueprint = blueprint)
rpf_bridge(
processed, max_interaction, ntrees, splits,
split_try, t_try, deterministic,
split_try, t_try, split_decay_rate, max_candidates, delete_leaves, deterministic,
nthreads, purify, cv,
loss, delta, epsilon
)
Expand All @@ -113,14 +122,16 @@ rpf.formula <- function(formula, data, max_interaction = 1, ntrees = 50, splits
#' @export
#' @rdname rpf
rpf.recipe <- function(x, data, max_interaction = 1, ntrees = 50, splits = 30,
split_try = 10, t_try = 0.4, deterministic = FALSE,
split_try = 10, t_try = 0.4, split_decay_rate = 0.1,
max_candidates = 50, delete_leaves = 1,
deterministic = FALSE,
nthreads = 1, purify = FALSE, cv = FALSE,
loss = "L2", delta = 0, epsilon = 0.1, ...) {
blueprint <- hardhat::default_recipe_blueprint(intercept = FALSE)
processed <- hardhat::mold(x, data, blueprint = blueprint)
rpf_bridge(
processed, max_interaction, ntrees, splits,
split_try, t_try, deterministic,
split_try, t_try, split_decay_rate, max_candidates, delete_leaves, deterministic,
nthreads, purify, cv,
loss, delta, epsilon
)
Expand All @@ -131,7 +142,9 @@ rpf.recipe <- function(x, data, max_interaction = 1, ntrees = 50, splits = 30,
#' @param processed Output of `hardhat::mold` from respective rpf methods
#' @importFrom hardhat validate_outcomes_are_univariate
rpf_bridge <- function(processed, max_interaction = 1, ntrees = 50, splits = 30,
split_try = 10, t_try = 0.4, deterministic = FALSE,
split_try = 10, t_try = 0.4, split_decay_rate = 0.1,
max_candidates = 50, delete_leaves = 1,
deterministic = FALSE,
nthreads = 1, purify = FALSE, cv = FALSE,
loss = "L2", delta = 0, epsilon = 0.1) {
hardhat::validate_outcomes_are_univariate(processed$outcomes)
Expand All @@ -141,7 +154,7 @@ rpf_bridge <- function(processed, max_interaction = 1, ntrees = 50, splits = 30,

# Check arguments
checkmate::assert_int(max_interaction, lower = 0)

# rewrite max_interaction so 0 -> "maximum", e.g. ncol(X)
if (max_interaction == 0) {
max_interaction <- p
Expand All @@ -156,10 +169,13 @@ rpf_bridge <- function(processed, max_interaction = 1, ntrees = 50, splits = 30,
checkmate::assert_int(ntrees, lower = 1)
checkmate::assert_int(splits, lower = 1)
checkmate::assert_int(split_try, lower = 1)

checkmate::assert_int(max_candidates, lower = 1)

checkmate::assert_number(t_try, lower = 0, upper = 1)
checkmate::assert_number(delta, lower = 0, upper = 1)
checkmate::assert_number(epsilon, lower = 0, upper = 1)
checkmate::assert_number(split_decay_rate, lower = 0)


# "median" loss is implemented but discarded
loss_functions <- switch(outcomes$mode,
Expand All @@ -172,12 +188,14 @@ rpf_bridge <- function(processed, max_interaction = 1, ntrees = 50, splits = 30,
checkmate::assert_int(nthreads, lower = 1L)
checkmate::assert_flag(purify)
checkmate::assert_flag(cv)
checkmate::assert_flag(delete_leaves)


fit <- rpf_impl(
Y = outcomes$outcomes, X = predictors$predictors_matrix,
mode = outcomes$mode,
max_interaction = max_interaction, ntrees = ntrees, splits = splits,
split_try = split_try, t_try = t_try, deterministic = deterministic,
split_try = split_try, t_try = t_try, split_decay_rate = split_decay_rate, max_candidates = max_candidates, delete_leaves=delete_leaves, deterministic = deterministic,
nthreads = nthreads, purify = purify, cv = cv,
loss = loss, delta = delta, epsilon = epsilon
)
Expand All @@ -195,7 +213,11 @@ rpf_bridge <- function(processed, max_interaction = 1, ntrees = 50, splits = 30,
ntrees = ntrees,
max_interaction = max_interaction,
splits = splits,
split_try = split_try, t_try = t_try,
split_try = split_try,
t_try = t_try,
split_decay_rate = split_decay_rate,
max_candidates = max_candidates,
delete_leaves = delete_leaves,
delta = delta, epsilon = epsilon,
deterministic = deterministic,
nthreads = nthreads, purify = purify, cv = cv
Expand All @@ -217,7 +239,7 @@ new_rpf <- function(fit, blueprint, ...) {
# Main fitting function and interface to C++ implementation
rpf_impl <- function(Y, X, mode = c("regression", "classification"),
max_interaction = 1, ntrees = 50, splits = 30, split_try = 10, t_try = 0.4,
deterministic = FALSE, nthreads = 1, purify = FALSE, cv = FALSE,
deterministic = FALSE, nthreads = 1, purify = FALSE, cv = FALSE, split_decay_rate = 0.1, max_candidates = 50, delete_leaves = 1,
loss = "L2", delta = 0, epsilon = 0.1) {
# Final input validation, should be superfluous
checkmate::assert_matrix(X, mode = "numeric", any.missing = FALSE)
Expand All @@ -226,12 +248,12 @@ rpf_impl <- function(Y, X, mode = c("regression", "classification"),
if (mode == "classification") {
fit <- new(ClassificationRPF, Y, X, loss, c(
max_interaction, ntrees, splits, split_try, t_try,
purify, deterministic, nthreads, cv, delta, epsilon
purify, deterministic, nthreads, cv, split_decay_rate, max_candidates, delete_leaves, delta, epsilon
))
} else if (mode == "regression") {
fit <- new(RandomPlantedForest, Y, X, c(
max_interaction, ntrees, splits, split_try, t_try,
purify, deterministic, nthreads, cv
max_interaction, ntrees, splits, split_try, t_try,
purify, deterministic, nthreads, cv, split_decay_rate, max_candidates, delete_leaves
))
}

Expand Down
13 changes: 8 additions & 5 deletions src/include/cpf.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ class ClassificationRPF : public RandomPlantedForest
public:
using RandomPlantedForest::calcOptimalSplit;
ClassificationRPF(const NumericMatrix &samples_Y, const NumericMatrix &samples_X,
const String loss = "L2", const NumericVector parameters = {1, 50, 30, 10, 0.4, 0, 0, 0, 0, 0, 0.1});
const String loss = "L2", const NumericVector parameters = {1, 50, 30, 10, 0.4, 0, 0, 0, 0, 0.1, 0, 0.1, 50,1});
void set_parameters(StringVector keys, NumericVector values);
~ClassificationRPF(){};

Expand All @@ -33,9 +33,12 @@ class ClassificationRPF : public RandomPlantedForest
void (ClassificationRPF::*calcLoss)(Split &);
void create_tree_family(std::vector<Leaf> initial_leaves, size_t n) override;
void fit() override;
Split calcOptimalSplit(const std::vector<std::vector<double>> &Y, const std::vector<std::vector<double>> &X,
std::multimap<int, std::shared_ptr<DecisionTree>> &possible_splits, TreeFamily &curr_family,
std::vector<std::vector<double>> &weights);
Split calcOptimalSplit(
const std::vector<std::vector<double>>& Y,
const std::vector<std::vector<double>>& X,
std::vector<SplitCandidate>& possible_splits,
TreeFamily& curr_family,
std::vector<std::vector<double>>& weights) ;
void L1_loss(Split &split);
void median_loss(Split &split);
void logit_loss(Split &split);
Expand All @@ -47,4 +50,4 @@ class ClassificationRPF : public RandomPlantedForest
void exponential_loss_3(Split &split);
};

#endif
#endif
32 changes: 27 additions & 5 deletions src/include/rpf.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ class RandomPlantedForest

public:
RandomPlantedForest(const NumericMatrix &samples_Y, const NumericMatrix &samples_X,
const NumericVector parameters = {1, 50, 30, 10, 0.4, 0, 0, 0, 0});
const NumericVector parameters = {1, 50, 30, 10, 0.4, 0, 0, 0, 0, 0.1, 50,1});
RandomPlantedForest(){};
void set_data(const NumericMatrix &samples_Y, const NumericMatrix &samples_X);
NumericMatrix predict_matrix(const NumericMatrix &X, const NumericVector components = {0});
Expand All @@ -26,7 +26,7 @@ class RandomPlantedForest
List get_model();
virtual ~RandomPlantedForest(){};
bool is_purified();

protected:
double MSE_vec(const NumericVector &Y_predicted, const NumericVector &Y_true);
std::vector<std::vector<double>> X; /**< Nested vector feature samples of size (sample_size x feature_size) */
Expand All @@ -53,8 +53,30 @@ class RandomPlantedForest
void L2_loss(Split &split);
virtual void fit();
virtual void create_tree_family(std::vector<Leaf> initial_leaves, size_t n);
virtual Split calcOptimalSplit(const std::vector<std::vector<double>> &Y, const std::vector<std::vector<double>> &X,
std::multimap<int, std::shared_ptr<DecisionTree>> &possible_splits, TreeFamily &curr_family);
struct SplitCandidate;
// Overload of possibleExists that takes a vector of SplitCandidate
static bool possibleExists(
int dim,
const std::vector<SplitCandidate>& possible_splits,
const std::set<int>& resulting_dims
);
virtual Split calcOptimalSplit(const std::vector<std::vector<double>> &Y,
const std::vector<std::vector<double>> &X,
std::vector<SplitCandidate> &possible_splits,
TreeFamily &curr_family);
// exponential-decay rate for split age
double split_decay_rate_;
size_t max_candidates_;
// track each split candidate and how long it has gone unchosen
struct SplitCandidate {
int dim;
std::shared_ptr<DecisionTree> tree;
double age;
// single ctor with default age
SplitCandidate(int d, std::shared_ptr<DecisionTree> t, double a = 0.0)
: dim(d), tree(std::move(t)), age(a) {}
};
bool delete_leaves;
};

#endif // RPF_HPP
#endif // RPF_HPP
Loading
Loading