From 305d6ecc31ee1d41e2639cb4f38904925fceebfe Mon Sep 17 00:00:00 2001 From: Drew Oldag <47493171+drewoldag@users.noreply.github.com> Date: Fri, 1 Nov 2024 13:31:33 -0700 Subject: [PATCH] Fix a couple of classifier-related bugs. (#69) --- src/resspect/classifiers.py | 4 ++-- src/resspect/database.py | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/src/resspect/classifiers.py b/src/resspect/classifiers.py index 7420f73d..5df38085 100644 --- a/src/resspect/classifiers.py +++ b/src/resspect/classifiers.py @@ -141,7 +141,7 @@ class RandomForest(ResspectClassifier): def __init__(self, **kwargs): super().__init__(**kwargs) - self.n_estimators = kwargs.get('n_estimators', 100) + self.n_estimators = self.kwargs.pop('n_estimators', 100) self.classifier = RandomForestClassifier(n_estimators=self.n_estimators, **self.kwargs) @@ -169,7 +169,7 @@ class SVM(ResspectClassifier): def __init__(self, **kwargs): super().__init__(**kwargs) - self.probability = kwargs.get('probability', True) + self.probability = self.kwargs.pop('probability', True) self.classifier = SVC(probability=self.probability, **self.kwargs) diff --git a/src/resspect/database.py b/src/resspect/database.py index 2f76c47b..b3ec7f78 100644 --- a/src/resspect/database.py +++ b/src/resspect/database.py @@ -944,7 +944,7 @@ def classify(self, method: str, save_predictions=False, pred_dir=None, # if a pretrained model is available, load it, otherwise fit the model if pretrained_model_path is not None: - clf_instance.load(pretrained_model_path) + clf_instance.load_classifier(pretrained_model_path) else: clf_instance.fit(self.train_features, self.train_labels)