Skip to content

Commit

Permalink
Broke ResspectClassifier.predict apart into two methods 1) `predict…
Browse files Browse the repository at this point in the history
…_class` and 2) `predict_probabilities`. (#54)
  • Loading branch information
drewoldag authored Oct 18, 2024
1 parent 8398e9a commit e4266f4
Show file tree
Hide file tree
Showing 2 changed files with 34 additions and 3 deletions.
32 changes: 31 additions & 1 deletion src/resspect/classifiers.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,36 @@ def predict(self, test_features):

return predictions, prob

def predict_class(self, test_features):
"""Predict the class of the test sample using the trained classifier.
Parameters
----------
test_features : array-like
The features used for testing, [n_samples, m_features].
Returns
-------
np.array
The predicted classes for the test sample. [n_samples]
"""
return self.classifier.predict(test_features)


def predict_probabilities(self, test_features):
"""Predict the probabilities of the test sample using the trained classifier.
Parameters
----------
test_features : array-like
The features used for testing, [n_samples, m_features].
Returns
-------
np.array
The predicted probabilities for the test sample. [n_samples, m_classes]
"""
return self.classifier.predict_proba(test_features)

class RandomForest(ResspectClassifier):
"""RESSPECT-specific version of the sklearn RandomForestClassifier."""
Expand Down Expand Up @@ -187,7 +217,7 @@ def bootstrap_clf(clf_class, n_ensembles, train_features,
x_train, y_train = resample(train_features, train_labels)
clf = clf_class(**kwargs)
clf.fit(x_train, y_train)
_, class_prob = clf.predict(test_features)
class_prob = clf.predict_probabilities(test_features)

classifiers.append(clf)
ensemble_probs[:, i, :] = class_prob
Expand Down
5 changes: 3 additions & 2 deletions src/resspect/database.py
Original file line number Diff line number Diff line change
Expand Up @@ -954,10 +954,11 @@ def classify(self, method: str, save_predictions=False, pred_dir=None,

# Fit the classifier and predict with it
clf_instance.fit(self.train_features, self.train_labels)
self.predicted_class, self.classprob = clf_instance.predict(self.pool_features)
self.classprob = clf_instance.predict_probabilities(self.pool_features)

# estimate classification for validation sample
self.validation_class, self.validation_prob = clf_instance.predict(self.validation_features)
self.validation_class = clf_instance.predict_class(self.validation_features)
self.validation_prob = clf_instance.predict_probabilities(self.validation_features)

if save_predictions:
id_name = self.identify_keywords()
Expand Down

0 comments on commit e4266f4

Please sign in to comment.