Skip to content

Commit 79afe8e

Browse files
committed
Refactor parameter handling in RcppRPF and ClassificationRPF
- Removed `set_parameters` method from both RcppRPF and ClassificationRPF classes. - Updated `get_parameters` to return a structured RPFParams object instead of printing to stdout. - Removed obsolete test related to parameter setting from test suite.
1 parent 4e5b8bc commit 79afe8e

File tree

8 files changed

+70
-219
lines changed

8 files changed

+70
-219
lines changed

src/include/cpf.hpp

Lines changed: 28 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,32 @@
33

44
#include <vector>
55
#include "rpf.hpp"
6+
enum LossType
7+
{
8+
L1,
9+
L2,
10+
median,
11+
logit,
12+
logit_2,
13+
logit_3,
14+
logit_4,
15+
exponential,
16+
exponential_2,
17+
exponential_3
18+
};
19+
struct CPFParams
20+
{
21+
int max_interaction;
22+
int n_trees;
23+
int n_splits;
24+
int split_try;
25+
double t_try;
26+
bool purify_forest;
27+
bool deterministic;
28+
int nthreads;
29+
bool cross_validate;
30+
LossType loss;
31+
};
632

733
class ClassificationRPF : public RandomPlantedForest
834
{
@@ -11,25 +37,12 @@ class ClassificationRPF : public RandomPlantedForest
1137
using RandomPlantedForest::calcOptimalSplit;
1238
ClassificationRPF(const std::vector<std::vector<double>> &samples_Y, const std::vector<std::vector<double>> &samples_X,
1339
const std::string loss = "L2", const std::vector<double> parameters = {1, 50, 30, 10, 0.4, 0, 0, 0, 0, 0, 0.1});
14-
void set_parameters(std::vector<std::string> keys, std::vector<double> values);
15-
~ClassificationRPF(){};
40+
~ClassificationRPF() {};
41+
CPFParams get_parameters();
1642

1743
private:
1844
double delta;
1945
double epsilon;
20-
enum LossType
21-
{
22-
L1,
23-
L2,
24-
median,
25-
logit,
26-
logit_2,
27-
logit_3,
28-
logit_4,
29-
exponential,
30-
exponential_2,
31-
exponential_3
32-
};
3346
LossType loss;
3447
void (ClassificationRPF::*calcLoss)(Split &);
3548
void create_tree_family(std::vector<Leaf> initial_leaves, size_t n) override;
@@ -48,5 +61,4 @@ class ClassificationRPF : public RandomPlantedForest
4861
void exponential_loss_3(Split &split);
4962
};
5063

51-
5264
#endif

src/include/rcpp_interface.hpp

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,6 @@ using namespace Rcpp;
1010
class RcppInterface
1111
{
1212
public:
13-
virtual void set_parameters(StringVector keys, NumericVector values) = 0;
1413
virtual NumericMatrix predict_matrix(const NumericMatrix &X, const NumericVector components) = 0;
1514
virtual NumericMatrix predict_vector(const NumericVector &X, const NumericVector components) = 0;
1615
virtual void cross_validation(int n_sets, IntegerVector splits, NumericVector t_tries, IntegerVector split_tries) = 0;
@@ -29,12 +28,11 @@ class RcppRPF : public RandomPlantedForest, public RcppInterface
2928
NumericMatrix predict_vector(const NumericVector &X, const NumericVector components = {0}) override;
3029
void cross_validation(int n_sets, IntegerVector splits, NumericVector t_tries, IntegerVector split_tries) override;
3130
double MSE(const NumericMatrix &Y_predicted, const NumericMatrix &Y_true) override;
32-
void set_parameters(StringVector keys, NumericVector values) override;
3331
List get_model() override;
3432

3533
void purify_3();
3634
void print();
37-
void get_parameters();
35+
List get_parameters();
3836
bool is_purified();
3937

4038
protected:
@@ -46,7 +44,6 @@ class RcppCPF : public ClassificationRPF, public RcppInterface
4644
public:
4745
RcppCPF(const NumericMatrix &samples_Y, const NumericMatrix &samples_X,
4846
const std::string loss = "L2", const NumericVector parameters = {1, 50, 30, 10, 0.4, 0, 0, 0, 0});
49-
void set_parameters(StringVector keys, NumericVector values) override;
5047
NumericMatrix predict_matrix(const NumericMatrix &X, const NumericVector components = {0}) override;
5148
NumericMatrix predict_vector(const NumericVector &X, const NumericVector components = {0}) override;
5249
void cross_validation(int n_sets, IntegerVector splits, NumericVector t_tries, IntegerVector split_tries) override;
@@ -55,7 +52,7 @@ class RcppCPF : public ClassificationRPF, public RcppInterface
5552

5653
void purify_3();
5754
void print();
58-
void get_parameters();
55+
List get_parameters();
5956
bool is_purified();
6057
};
6158

src/include/rpf.hpp

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,19 @@
55
#include <stdexcept>
66
#include "trees.hpp"
77

8+
struct RPFParams
9+
{
10+
int max_interaction;
11+
int n_trees;
12+
int n_splits;
13+
int split_try;
14+
double t_try;
15+
bool purify_forest;
16+
bool deterministic;
17+
int nthreads;
18+
bool cross_validate;
19+
};
20+
821
class RandomPlantedForest
922
{
1023

@@ -23,8 +36,7 @@ class RandomPlantedForest
2336
void print();
2437
void cross_validation(int n_sets = 4, std::vector<int> splits = {5, 50}, std::vector<double> t_tries = {0.2, 0.5, 0.7, 0.9}, std::vector<int> split_tries = {1, 2, 5, 10});
2538
double MSE(const std::vector<std::vector<double>> &Y_predicted, const std::vector<std::vector<double>> &Y_true);
26-
void get_parameters();
27-
void set_parameters(std::vector<std::string> keys, std::vector<double> values);
39+
RPFParams get_parameters();
2840
virtual ~RandomPlantedForest() {};
2941
bool is_purified();
3042

src/lib/cpf.cpp

Lines changed: 2 additions & 95 deletions
Original file line numberDiff line numberDiff line change
@@ -1434,100 +1434,7 @@ void ClassificationRPF::fit()
14341434
}
14351435
}
14361436

1437-
/* retrospectively change parameters of existing class object,
1438-
updates the model, so far only single valued parameters supported,
1439-
for replacing training data use 'set_data',
1440-
note that changing cv does not trigger cross validation */
1441-
void ClassificationRPF::set_parameters(std::vector<std::string> keys, std::vector<double> values)
1437+
CPFParams ClassificationRPF::get_parameters()
14421438
{
1443-
if (keys.size() != values.size())
1444-
{
1445-
std::cout << "Size of input vectors is not the same. " << std::endl;
1446-
return;
1447-
}
1448-
1449-
for (unsigned int i = 0; i < keys.size(); ++i)
1450-
{
1451-
if (keys[i] == "deterministic")
1452-
{
1453-
this->deterministic = values[i];
1454-
}
1455-
else if (keys[i] == "nthreads")
1456-
{
1457-
this->nthreads = values[i];
1458-
}
1459-
else if (keys[i] == "purify")
1460-
{
1461-
this->purify_forest = values[i];
1462-
}
1463-
else if (keys[i] == "n_trees")
1464-
{
1465-
this->n_trees = values[i];
1466-
}
1467-
else if (keys[i] == "n_splits")
1468-
{
1469-
this->n_splits = values[i];
1470-
}
1471-
else if (keys[i] == "t_try")
1472-
{
1473-
this->t_try = values[i];
1474-
}
1475-
else if (keys[i] == "split_try")
1476-
{
1477-
this->split_try = values[i];
1478-
}
1479-
else if (keys[i] == "max_interaction")
1480-
{
1481-
this->max_interaction = values[i];
1482-
}
1483-
else if (keys[i] == "cv")
1484-
{
1485-
this->cross_validate = values[i];
1486-
}
1487-
else if (keys[i] == "loss")
1488-
{
1489-
if (keys[i] == "L1")
1490-
{
1491-
this->loss = LossType::L1;
1492-
this->calcLoss = &ClassificationRPF::L1_loss;
1493-
}
1494-
else if (keys[i] == "L2")
1495-
{
1496-
this->loss = LossType::L2;
1497-
this->calcLoss = &ClassificationRPF::L2_loss;
1498-
}
1499-
else if (keys[i] == "median")
1500-
{
1501-
this->loss = LossType::median;
1502-
this->calcLoss = &ClassificationRPF::median_loss;
1503-
}
1504-
else if (keys[i] == "logit")
1505-
{
1506-
this->loss = LossType::logit;
1507-
this->calcLoss = &ClassificationRPF::logit_loss;
1508-
}
1509-
else if (keys[i] == "exponential")
1510-
{
1511-
this->loss = LossType::exponential;
1512-
this->calcLoss = &ClassificationRPF::exponential_loss;
1513-
}
1514-
else
1515-
{
1516-
std::cout << "Unkown loss function." << std::endl;
1517-
}
1518-
}
1519-
else if (keys[i] == "delta")
1520-
{
1521-
this->delta = values[i];
1522-
}
1523-
else if (keys[i] == "epsilon")
1524-
{
1525-
this->epsilon = values[i];
1526-
}
1527-
else
1528-
{
1529-
std::cout << "Unkown parameter key '" << keys[i] << "' ." << std::endl;
1530-
}
1531-
}
1532-
this->fit();
1439+
return {max_interaction, n_trees, n_splits, split_try, t_try, purify_forest, deterministic, nthreads, cross_validate, loss};
15331440
}

src/lib/rcpp_interface.cpp

Lines changed: 22 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -114,12 +114,6 @@ double RcppRPF::MSE(const NumericMatrix &Y_predicted, const NumericMatrix &Y_tru
114114
return RandomPlantedForest::MSE(toStd2D(Y_predicted), toStd2D(Y_true));
115115
}
116116

117-
void RcppRPF::set_parameters(StringVector keys, NumericVector values)
118-
{
119-
RandomPlantedForest::set_parameters(std::vector<std::string>(keys.begin(), keys.end()),
120-
std::vector<double>(values.begin(), values.end()));
121-
}
122-
123117
double RcppRPF::MSE_vec(const NumericVector &Y_predicted, const NumericVector &Y_true)
124118
{
125119
return RandomPlantedForest::MSE_vec(toStd1D(Y_predicted), toStd1D(Y_true));
@@ -135,9 +129,18 @@ void RcppRPF::print()
135129
RandomPlantedForest::print();
136130
}
137131

138-
void RcppRPF::get_parameters()
132+
List RcppRPF::get_parameters()
139133
{
140-
RandomPlantedForest::get_parameters();
134+
RPFParams params = RandomPlantedForest::get_parameters();
135+
return List::create(Named("max_interaction") = params.max_interaction,
136+
Named("n_trees") = params.n_trees,
137+
Named("n_splits") = params.n_splits,
138+
Named("split_try") = params.split_try,
139+
Named("t_try") = params.t_try,
140+
Named("purify_forest") = params.purify_forest,
141+
Named("deterministic") = params.deterministic,
142+
Named("nthreads") = params.nthreads,
143+
Named("cross_validate") = params.cross_validate);
141144
}
142145

143146
List RcppRPF::get_model()
@@ -163,12 +166,6 @@ RcppCPF::RcppCPF(const NumericMatrix &samples_Y, const NumericMatrix &samples_X,
163166
}
164167
}
165168

166-
void RcppCPF::set_parameters(StringVector keys, NumericVector values)
167-
{
168-
ClassificationRPF::set_parameters(std::vector<std::string>(keys.begin(), keys.end()),
169-
std::vector<double>(values.begin(), values.end()));
170-
}
171-
172169
NumericMatrix RcppCPF::predict_matrix(const NumericMatrix &X, const NumericVector components)
173170
{
174171
auto result = ClassificationRPF::predict_matrix(toStd2D(X), toStd1D(components));
@@ -219,9 +216,18 @@ void RcppCPF::print()
219216
RandomPlantedForest::print();
220217
}
221218

222-
void RcppCPF::get_parameters()
219+
List RcppCPF::get_parameters()
223220
{
224-
RandomPlantedForest::get_parameters();
221+
CPFParams params = ClassificationRPF::get_parameters();
222+
return List::create(Named("max_interaction") = params.max_interaction,
223+
Named("n_trees") = params.n_trees,
224+
Named("n_splits") = params.n_splits,
225+
Named("split_try") = params.split_try,
226+
Named("t_try") = params.t_try,
227+
Named("purify_forest") = params.purify_forest,
228+
Named("deterministic") = params.deterministic,
229+
Named("nthreads") = params.nthreads,
230+
Named("cross_validate") = params.cross_validate);
225231
}
226232

227233
List RcppCPF::get_model()

src/lib/rpf.cpp

Lines changed: 2 additions & 63 deletions
Original file line numberDiff line numberDiff line change
@@ -1983,68 +1983,7 @@ void RandomPlantedForest::print()
19831983
}
19841984
}
19851985

1986-
// print parameters of the model to the console
1987-
void RandomPlantedForest::get_parameters()
1986+
RPFParams RandomPlantedForest::get_parameters()
19881987
{
1989-
std::cout << "Parameters: n_trees=" << n_trees << ", n_splits=" << n_splits << ", max_interaction=" << max_interaction << ", t_try=" << t_try
1990-
<< ", split_try=" << split_try << ", purified=" << purified << ", deterministic=" << deterministic << ", nthreads=" << nthreads
1991-
<< ", feature_size=" << feature_size << ", sample_size=" << sample_size << std::endl;
1992-
}
1993-
1994-
/* retrospectively change parameters of existing class object,
1995-
updates the model, so far only single valued parameters supported,
1996-
for replacing training data use 'set_data',
1997-
note that changing cv does not trigger cross validation */
1998-
void RandomPlantedForest::set_parameters(std::vector<std::string> keys, std::vector<double> values)
1999-
{
2000-
if (keys.size() != values.size())
2001-
{
2002-
std::cout << "Size of input vectors is not the same. " << std::endl;
2003-
return;
2004-
}
2005-
2006-
for (unsigned int i = 0; i < keys.size(); ++i)
2007-
{
2008-
if (keys[i] == "deterministic")
2009-
{
2010-
this->deterministic = values[i];
2011-
}
2012-
else if (keys[i] == "nthreads")
2013-
{
2014-
this->nthreads = values[i];
2015-
}
2016-
else if (keys[i] == "purify")
2017-
{
2018-
this->purify_forest = values[i];
2019-
}
2020-
else if (keys[i] == "n_trees")
2021-
{
2022-
this->n_trees = values[i];
2023-
}
2024-
else if (keys[i] == "n_splits")
2025-
{
2026-
this->n_splits = values[i];
2027-
}
2028-
else if (keys[i] == "t_try")
2029-
{
2030-
this->t_try = values[i];
2031-
}
2032-
else if (keys[i] == "split_try")
2033-
{
2034-
this->split_try = values[i];
2035-
}
2036-
else if (keys[i] == "max_interaction")
2037-
{
2038-
this->max_interaction = values[i];
2039-
}
2040-
else if (keys[i] == "cv")
2041-
{
2042-
this->cross_validate = values[i];
2043-
}
2044-
else
2045-
{
2046-
std::cout << "Unkown parameter key '" << keys[i] << "' ." << std::endl;
2047-
}
2048-
}
2049-
this->fit();
1988+
return {max_interaction, n_trees, n_splits, split_try, t_try, purify_forest, deterministic, nthreads, cross_validate};
20501989
}

src/randomPlantedForest.cpp

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -16,21 +16,18 @@ RCPP_MODULE(mod_rpf)
1616
.method("purify", &RcppRPF::purify_3)
1717
.method("print", &RcppRPF::print)
1818
.method("get_parameters", &RcppRPF::get_parameters)
19-
.method("set_parameters", &RcppRPF::set_parameters)
2019
.method("get_model", &RcppRPF::get_model)
2120
.method("is_purified", &RcppRPF::is_purified);
2221

2322
class_<RcppCPF>("ClassificationRPF")
2423
.constructor<const NumericMatrix, const NumericMatrix, const String, const NumericVector>()
25-
.method("set_parameters", &RcppCPF::set_parameters)
2624
.method("predict_matrix", &RcppCPF::predict_matrix)
2725
.method("predict_vector", &RcppCPF::predict_vector)
2826
.method("cross_validation", &RcppCPF::cross_validation)
2927
.method("MSE", &RcppCPF::MSE)
3028
.method("purify", &RcppCPF::purify_3)
3129
.method("print", &RcppCPF::print)
3230
.method("get_parameters", &RcppCPF::get_parameters)
33-
.method("set_parameters", &RcppCPF::set_parameters)
3431
.method("get_model", &RcppCPF::get_model)
3532
.method("is_purified", &RcppCPF::is_purified);
3633
}

0 commit comments

Comments
 (0)