@@ -100,6 +100,36 @@ def predict(self, test_features):
100
100
101
101
return predictions , prob
102
102
103
+ def predict_class (self , test_features ):
104
+ """Predict the class of the test sample using the trained classifier.
105
+
106
+ Parameters
107
+ ----------
108
+ test_features : array-like
109
+ The features used for testing, [n_samples, m_features].
110
+
111
+ Returns
112
+ -------
113
+ np.array
114
+ The predicted classes for the test sample. [n_samples]
115
+ """
116
+ return self .classifier .predict (test_features )
117
+
118
+
119
+ def predict_probabilities (self , test_features ):
120
+ """Predict the probabilities of the test sample using the trained classifier.
121
+
122
+ Parameters
123
+ ----------
124
+ test_features : array-like
125
+ The features used for testing, [n_samples, m_features].
126
+
127
+ Returns
128
+ -------
129
+ np.array
130
+ The predicted probabilities for the test sample. [n_samples, m_classes]
131
+ """
132
+ return self .classifier .predict_proba (test_features )
103
133
104
134
class RandomForest (ResspectClassifier ):
105
135
"""RESSPECT-specific version of the sklearn RandomForestClassifier."""
@@ -187,7 +217,7 @@ def bootstrap_clf(clf_class, n_ensembles, train_features,
187
217
x_train , y_train = resample (train_features , train_labels )
188
218
clf = clf_class (** kwargs )
189
219
clf .fit (x_train , y_train )
190
- _ , class_prob = clf .predict (test_features )
220
+ class_prob = clf .predict_probabilities (test_features )
191
221
192
222
classifiers .append (clf )
193
223
ensemble_probs [:, i , :] = class_prob
0 commit comments