Skip to content

Commit

Permalink
Fix a couple of classifier-related bugs. (#69)
Browse files Browse the repository at this point in the history
  • Loading branch information
drewoldag authored Nov 1, 2024
1 parent 444e421 commit 305d6ec
Show file tree
Hide file tree
Showing 2 changed files with 3 additions and 3 deletions.
4 changes: 2 additions & 2 deletions src/resspect/classifiers.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,7 +141,7 @@ class RandomForest(ResspectClassifier):
def __init__(self, **kwargs):
super().__init__(**kwargs)

self.n_estimators = kwargs.get('n_estimators', 100)
self.n_estimators = self.kwargs.pop('n_estimators', 100)
self.classifier = RandomForestClassifier(n_estimators=self.n_estimators, **self.kwargs)


Expand Down Expand Up @@ -169,7 +169,7 @@ class SVM(ResspectClassifier):
def __init__(self, **kwargs):
super().__init__(**kwargs)

self.probability = kwargs.get('probability', True)
self.probability = self.kwargs.pop('probability', True)
self.classifier = SVC(probability=self.probability, **self.kwargs)


Expand Down
2 changes: 1 addition & 1 deletion src/resspect/database.py
Original file line number Diff line number Diff line change
Expand Up @@ -944,7 +944,7 @@ def classify(self, method: str, save_predictions=False, pred_dir=None,

# if a pretrained model is available, load it, otherwise fit the model
if pretrained_model_path is not None:
clf_instance.load(pretrained_model_path)
clf_instance.load_classifier(pretrained_model_path)
else:
clf_instance.fit(self.train_features, self.train_labels)

Expand Down

0 comments on commit 305d6ec

Please sign in to comment.