@@ -114,12 +114,6 @@ double RcppRPF::MSE(const NumericMatrix &Y_predicted, const NumericMatrix &Y_tru
114114 return RandomPlantedForest::MSE (toStd2D (Y_predicted), toStd2D (Y_true));
115115}
116116
117- void RcppRPF::set_parameters (StringVector keys, NumericVector values)
118- {
119- RandomPlantedForest::set_parameters (std::vector<std::string>(keys.begin (), keys.end ()),
120- std::vector<double >(values.begin (), values.end ()));
121- }
122-
123117double RcppRPF::MSE_vec (const NumericVector &Y_predicted, const NumericVector &Y_true)
124118{
125119 return RandomPlantedForest::MSE_vec (toStd1D (Y_predicted), toStd1D (Y_true));
@@ -135,9 +129,18 @@ void RcppRPF::print()
135129 RandomPlantedForest::print ();
136130}
137131
138- void RcppRPF::get_parameters ()
132+ List RcppRPF::get_parameters ()
139133{
140- RandomPlantedForest::get_parameters ();
134+ RPFParams params = RandomPlantedForest::get_parameters ();
135+ return List::create (Named (" max_interaction" ) = params.max_interaction ,
136+ Named (" n_trees" ) = params.n_trees ,
137+ Named (" n_splits" ) = params.n_splits ,
138+ Named (" split_try" ) = params.split_try ,
139+ Named (" t_try" ) = params.t_try ,
140+ Named (" purify_forest" ) = params.purify_forest ,
141+ Named (" deterministic" ) = params.deterministic ,
142+ Named (" nthreads" ) = params.nthreads ,
143+ Named (" cross_validate" ) = params.cross_validate );
141144}
142145
143146List RcppRPF::get_model ()
@@ -163,12 +166,6 @@ RcppCPF::RcppCPF(const NumericMatrix &samples_Y, const NumericMatrix &samples_X,
163166 }
164167}
165168
166- void RcppCPF::set_parameters (StringVector keys, NumericVector values)
167- {
168- ClassificationRPF::set_parameters (std::vector<std::string>(keys.begin (), keys.end ()),
169- std::vector<double >(values.begin (), values.end ()));
170- }
171-
172169NumericMatrix RcppCPF::predict_matrix (const NumericMatrix &X, const NumericVector components)
173170{
174171 auto result = ClassificationRPF::predict_matrix (toStd2D (X), toStd1D (components));
@@ -219,9 +216,18 @@ void RcppCPF::print()
219216 RandomPlantedForest::print ();
220217}
221218
222- void RcppCPF::get_parameters ()
219+ List RcppCPF::get_parameters ()
223220{
224- RandomPlantedForest::get_parameters ();
221+ CPFParams params = ClassificationRPF::get_parameters ();
222+ return List::create (Named (" max_interaction" ) = params.max_interaction ,
223+ Named (" n_trees" ) = params.n_trees ,
224+ Named (" n_splits" ) = params.n_splits ,
225+ Named (" split_try" ) = params.split_try ,
226+ Named (" t_try" ) = params.t_try ,
227+ Named (" purify_forest" ) = params.purify_forest ,
228+ Named (" deterministic" ) = params.deterministic ,
229+ Named (" nthreads" ) = params.nthreads ,
230+ Named (" cross_validate" ) = params.cross_validate );
225231}
226232
227233List RcppCPF::get_model ()
0 commit comments