Skip to content

Commit 4e5b8bc

Browse files
committed
Move fitting away from set_data
1 parent f86881a commit 4e5b8bc

File tree

2 files changed

+13
-7
lines changed

2 files changed

+13
-7
lines changed

src/lib/rcpp_interface.cpp

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -65,6 +65,13 @@ RcppRPF::RcppRPF(const NumericMatrix &samples_Y, const NumericMatrix &samples_X,
6565
toStd1D(parameters))
6666
{
6767
utils::RandomGenerator::use_r_random();
68+
69+
this->fit();
70+
71+
if (cross_validate)
72+
{
73+
RandomPlantedForest::cross_validation();
74+
}
6875
}
6976

7077
NumericMatrix RcppRPF::predict_matrix(const NumericMatrix &X, const NumericVector components)
@@ -148,6 +155,12 @@ RcppCPF::RcppCPF(const NumericMatrix &samples_Y, const NumericMatrix &samples_X,
148155
: ClassificationRPF(toStd2D(samples_Y), toStd2D(samples_X), loss, toStd1D(parameters))
149156
{
150157
utils::RandomGenerator::use_r_random();
158+
RandomPlantedForest::fit();
159+
160+
if (cross_validate)
161+
{
162+
ClassificationRPF::cross_validation();
163+
}
151164
}
152165

153166
void RcppCPF::set_parameters(StringVector keys, NumericVector values)

src/lib/rpf.cpp

Lines changed: 0 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -279,13 +279,6 @@ void RandomPlantedForest::set_data(const std::vector<std::vector<double>> &sampl
279279
this->upper_bounds[i] = maxVal + 2 * eps; // to consider samples at max value
280280
this->lower_bounds[i] = minVal;
281281
}
282-
283-
this->fit();
284-
285-
if (cross_validate)
286-
{
287-
this->cross_validation();
288-
}
289282
}
290283

291284
void RandomPlantedForest::create_tree_family(std::vector<Leaf> initial_leaves, size_t n)

0 commit comments

Comments
 (0)