Skip to content

Commit e4266f4

Browse files
authored
Broke ResspectClassifier.predict apart into two methods 1) predict_class and 2) predict_probabilities. (#54)
1 parent 8398e9a commit e4266f4

File tree

2 files changed

+34
-3
lines changed

2 files changed

+34
-3
lines changed

src/resspect/classifiers.py

Lines changed: 31 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -100,6 +100,36 @@ def predict(self, test_features):
100100

101101
return predictions, prob
102102

103+
def predict_class(self, test_features):
104+
"""Predict the class of the test sample using the trained classifier.
105+
106+
Parameters
107+
----------
108+
test_features : array-like
109+
The features used for testing, [n_samples, m_features].
110+
111+
Returns
112+
-------
113+
np.array
114+
The predicted classes for the test sample. [n_samples]
115+
"""
116+
return self.classifier.predict(test_features)
117+
118+
119+
def predict_probabilities(self, test_features):
120+
"""Predict the probabilities of the test sample using the trained classifier.
121+
122+
Parameters
123+
----------
124+
test_features : array-like
125+
The features used for testing, [n_samples, m_features].
126+
127+
Returns
128+
-------
129+
np.array
130+
The predicted probabilities for the test sample. [n_samples, m_classes]
131+
"""
132+
return self.classifier.predict_proba(test_features)
103133

104134
class RandomForest(ResspectClassifier):
105135
"""RESSPECT-specific version of the sklearn RandomForestClassifier."""
@@ -187,7 +217,7 @@ def bootstrap_clf(clf_class, n_ensembles, train_features,
187217
x_train, y_train = resample(train_features, train_labels)
188218
clf = clf_class(**kwargs)
189219
clf.fit(x_train, y_train)
190-
_, class_prob = clf.predict(test_features)
220+
class_prob = clf.predict_probabilities(test_features)
191221

192222
classifiers.append(clf)
193223
ensemble_probs[:, i, :] = class_prob

src/resspect/database.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -954,10 +954,11 @@ def classify(self, method: str, save_predictions=False, pred_dir=None,
954954

955955
# Fit the classifier and predict with it
956956
clf_instance.fit(self.train_features, self.train_labels)
957-
self.predicted_class, self.classprob = clf_instance.predict(self.pool_features)
957+
self.classprob = clf_instance.predict_probabilities(self.pool_features)
958958

959959
# estimate classification for validation sample
960-
self.validation_class, self.validation_prob = clf_instance.predict(self.validation_features)
960+
self.validation_class = clf_instance.predict_class(self.validation_features)
961+
self.validation_prob = clf_instance.predict_probabilities(self.validation_features)
961962

962963
if save_predictions:
963964
id_name = self.identify_keywords()

0 commit comments

Comments
 (0)