From d6a4a49f469c751cbfb8c869cbaff920bb4b2efe Mon Sep 17 00:00:00 2001 From: drewoldag <47493171+drewoldag@users.noreply.github.com> Date: Mon, 14 Oct 2024 16:11:05 -0700 Subject: [PATCH 1/6] WIP - Initial commit with base ResspectClassifier class and early RandomForest class. --- src/resspect/classifiers.py | 66 +++++++++++++++++++++++++++++++++---- 1 file changed, 59 insertions(+), 7 deletions(-) diff --git a/src/resspect/classifiers.py b/src/resspect/classifiers.py index 641865b3..6b086646 100644 --- a/src/resspect/classifiers.py +++ b/src/resspect/classifiers.py @@ -17,7 +17,6 @@ import numpy as np from sklearn.ensemble import RandomForestClassifier -#from xgboost.sklearn import XGBClassifier from sklearn.neighbors import KNeighborsClassifier from sklearn.neural_network import MLPClassifier from sklearn.svm import SVC @@ -25,11 +24,64 @@ from sklearn.utils import resample from sklearn.utils.validation import check_is_fitted -__all__ = ['random_forest',#'gradient_boosted_trees', - 'knn', - 'mlp','svm','nbg', 'bootstrap_clf' - ] +__all__ = ['random_forest','knn','mlp','svm','nbg', 'bootstrap_clf'] +class ResspectClassifer(): + def __init__(self, train_features, train_labels, test_features, **kwargs): + + self.train_features = train_features + self.train_labels = train_labels + self.test_features = test_features + self.kwargs = kwargs + + self.n_labels = np.unique(self.train_labels).size + + #! Rename this after answering "Is shape[0] the number of objects or number of features/object?"" + self.num_test_data = self.test_features.shape[0] + self._n_ensembles = 10 + self.ensemble_probs = np.zeros((self.num_test_data, self.n_ensembles, self.n_labels)) + + @property + def n_ensembles(self): + return self._n_ensembles + + @n_ensembles.setter + def n_ensembles(self, value): + self._n_ensembles = value + self.ensemble_probs = np.zeros((self.num_test_data, self._n_ensembles, self.n_labels)) + + def bootstrap_ensemble(self, clf_function): + classifier_list = list() + for i in range(self.n_ensembles): + x_train, y_train = resample(self.train_features, self.train_labels) + _, class_prob, clf = clf_function(x_train, y_train, self.test_features, **self.kwargs) + + classifier_list.append((str(i), clf)) + self.ensemble_probs[:, i, :] = class_prob + + ensemble_clf = PreFitVotingClassifier(classifier_list) + class_prob = self.ensemble_probs.mean(axis=1) + predictions = np.argmax(class_prob, axis=1) + + return predictions, class_prob, self.ensemble_probs, ensemble_clf + + +class RandomForest(ResspectClassifer): + name = 'RandomForest' + def __init__(self, train_features, train_labels, test_features, **kwargs): + super().__init__(train_features, train_labels, test_features, **kwargs) + self.n_estimators = kwargs.get('n_estimators', 100) + + def __call__(self, **kwargs): + clf = RandomForestClassifier(n_estimators=self.n_estimators, **self.kwargs) + clf.fit(self.train_features, self.train_labels) + predictions = clf.predict(self.test_features) + prob = clf.predict_proba(self.test_features) + + return predictions, prob, clf + + def bootstrap(self, **kwargs): + return self.bootstrap_ensemble(self.__call__) def bootstrap_clf(clf_function, n_ensembles, train_features, train_labels, test_features, **kwargs): @@ -79,7 +131,7 @@ def bootstrap_clf(clf_function, n_ensembles, train_features, classifier_list.append((str(i), clf)) ensemble_probs[:, i, :] = class_prob - ensemble_clf = PreFitVotingClassifier(classifier_list, voting='soft') #Must use soft voting + ensemble_clf = PreFitVotingClassifier(classifier_list) class_prob = ensemble_probs.mean(axis=1) predictions = np.argmax(class_prob, axis=1) @@ -294,7 +346,7 @@ def nbg(train_features: np.array, train_labels: np.array, class PreFitVotingClassifier(object): """Stripped-down version of VotingClassifier that uses prefit estimators""" - def __init__(self, estimators, voting='hard', weights=None): + def __init__(self, estimators, voting='soft', weights=None): self.estimators = [e[1] for e in estimators] self.named_estimators = dict(estimators) self.voting = voting From 64df0246e9054a3935f6bd47c215797f251e0310 Mon Sep 17 00:00:00 2001 From: drewoldag <47493171+drewoldag@users.noreply.github.com> Date: Tue, 15 Oct 2024 15:59:44 -0700 Subject: [PATCH 2/6] Basic classifier plugin system working. RandomForestClassifier is in place. Need to create the other built-in classifer classes still. --- src/resspect/__init__.py | 5 +- src/resspect/classifier_registry.py | 44 ++++++++ src/resspect/classifiers.py | 146 +++++++++++++++++---------- src/resspect/database.py | 150 +++++++++++++++++----------- src/resspect/plugin_utils.py | 84 ++++++++++++++++ 5 files changed, 316 insertions(+), 113 deletions(-) create mode 100644 src/resspect/classifier_registry.py create mode 100644 src/resspect/plugin_utils.py diff --git a/src/resspect/__init__.py b/src/resspect/__init__.py index 973020d1..e18dd5de 100644 --- a/src/resspect/__init__.py +++ b/src/resspect/__init__.py @@ -45,6 +45,7 @@ from .query_budget_strategies import * from .bump import * from .feature_extractors.malanchev import * +from .classifier_registry import * import importlib.metadata @@ -105,4 +106,6 @@ 'svm', 'time_domain_loop', 'uncertainty_sampling', - 'update_matrix'] \ No newline at end of file + 'update_matrix'] + +classifier_registry.register_builtin_classifiers() diff --git a/src/resspect/classifier_registry.py b/src/resspect/classifier_registry.py new file mode 100644 index 00000000..f1af2aab --- /dev/null +++ b/src/resspect/classifier_registry.py @@ -0,0 +1,44 @@ +from resspect.plugin_utils import get_or_load_class +from resspect.classifiers import ResspectClassifer + +__all__ = ["CLASSIFIER_REGISTRY", "register_builtin_classifiers", "fetch_classifier_class"] + +CLASSIFIER_REGISTRY = {} + +def register_builtin_classifiers(): + """Add all built-in classifiers to the registry.""" + subclasses = ResspectClassifer.__subclasses__() + for subclass in subclasses: + CLASSIFIER_REGISTRY[subclass.__name__] = subclass + + +def fetch_classifier_class(classifier_name: str) -> type: + """Fetch the classifier class from the registry. + + Parameters + ---------- + classifier_name : str + The name of the classifier class to retrieve. This should either be the + name of the class or the import specification for the class. + + Returns + ------- + type + The classifier class. + + Raises + ------ + ValueError + If a built-in classifier was requested, but not found in the registry. + ValueError + If no classifier was specified in the runtime configuration. + """ + + clf_class = None + + try: + clf_class = get_or_load_class(classifier_name, CLASSIFIER_REGISTRY) + except ValueError as exc: + raise ValueError(f"Error fetching class: {classifier_name}") from exc + + return clf_class diff --git a/src/resspect/classifiers.py b/src/resspect/classifiers.py index 6b086646..e9e41d6b 100644 --- a/src/resspect/classifiers.py +++ b/src/resspect/classifiers.py @@ -24,22 +24,44 @@ from sklearn.utils import resample from sklearn.utils.validation import check_is_fitted -__all__ = ['random_forest','knn','mlp','svm','nbg', 'bootstrap_clf'] +__all__ = [ + 'random_forest', + 'knn', + 'mlp', + 'svm', + 'nbg', + 'bootstrap_clf', + 'ResspectClassifer', + 'RandomForest', + ] class ResspectClassifer(): + """Base class that all built-in RESSPECT classifiers will inherit from.""" + def __init__(self, train_features, train_labels, test_features, **kwargs): + """Base initializer for all RESSPECT classifiers. + Parameters + ---------- + train_features : array-like + _description_ + train_labels : array-like + _description_ + test_features : array-like + _description_ + """ self.train_features = train_features self.train_labels = train_labels self.test_features = test_features self.kwargs = kwargs - self.n_labels = np.unique(self.train_labels).size - #! Rename this after answering "Is shape[0] the number of objects or number of features/object?"" self.num_test_data = self.test_features.shape[0] self._n_ensembles = 10 - self.ensemble_probs = np.zeros((self.num_test_data, self.n_ensembles, self.n_labels)) + self.n_labels = np.unique(self.train_labels).size + self.ensemble_probs = np.zeros((self.num_test_data, self._n_ensembles, self.n_labels)) + + self.classifier = None @property def n_ensembles(self): @@ -50,11 +72,73 @@ def n_ensembles(self, value): self._n_ensembles = value self.ensemble_probs = np.zeros((self.num_test_data, self._n_ensembles, self.n_labels)) + def __call__(self): + """Allows the user to call the class instance as a function. + e.g. clf = SomeClassifier() + predictions, _, _ = clf() + """ + return self.predict(self.train_features, self.train_labels, self.test_features) + + def predict(self, train_features, train_labels, test_features): + """Train and predict using the classifier. + + Parameters + ---------- + train_features : array-like + The features used for training, [n_samples, m_features]. + train_labels : array-like + The training labels, [n_samples]. + test_features : array-like + The features used for testing, [n_samples, m_features]. + + Returns + ------- + tuple(predictions, prob, classifier_instance) + The classes and probabilities for the test sample. + """ + self.classifier.fit(train_features, train_labels) + predictions = self.classifier.predict(test_features) + prob = self.classifier.predict_proba(test_features) + + return predictions, prob, self.classifier + + def bootstrap(self): + """Convenience method that can be overridden by subclasses. Calls the + bootstrap_ensemble method with the predict method as an argument. + + Returns + ------- + tuple(predictions, prob, ensemble_probs, ensemble_clf) + The classes and probabilities for the test sample. + """ + return self.bootstrap_ensemble(self.predict) + def bootstrap_ensemble(self, clf_function): + """Create an ensemble of predictions by resampling the training data used + to instantiate the classifier. Define the ensemble size by specifying the + value for `n_ensembles`. + + e.g.: + ``` + clf = SomeClassifier() + clf.n_ensembles = 10 + clf.bootstrap_ensemble(clf.predict) + ``` + + Parameters + ---------- + clf_function : Callable + The function used to and predict with the classifier. + + Returns + ------- + tuple(predictions, prob, ensemble_probs, ensemble_clf) + The classes and probabilities for the test sample. + """ classifier_list = list() for i in range(self.n_ensembles): x_train, y_train = resample(self.train_features, self.train_labels) - _, class_prob, clf = clf_function(x_train, y_train, self.test_features, **self.kwargs) + _, class_prob, clf = clf_function(x_train, y_train, self.test_features) classifier_list.append((str(i), clf)) self.ensemble_probs[:, i, :] = class_prob @@ -67,21 +151,14 @@ def bootstrap_ensemble(self, clf_function): class RandomForest(ResspectClassifer): - name = 'RandomForest' + """RESSPECT-specific version of the sklearn RandomForestClassifier.""" + def __init__(self, train_features, train_labels, test_features, **kwargs): super().__init__(train_features, train_labels, test_features, **kwargs) - self.n_estimators = kwargs.get('n_estimators', 100) - def __call__(self, **kwargs): - clf = RandomForestClassifier(n_estimators=self.n_estimators, **self.kwargs) - clf.fit(self.train_features, self.train_labels) - predictions = clf.predict(self.test_features) - prob = clf.predict_proba(self.test_features) - - return predictions, prob, clf + self.n_estimators = kwargs.get('n_estimators', 100) + self.classifier = RandomForestClassifier(n_estimators=self.n_estimators, **self.kwargs) - def bootstrap(self, **kwargs): - return self.bootstrap_ensemble(self.__call__) def bootstrap_clf(clf_function, n_ensembles, train_features, train_labels, test_features, **kwargs): @@ -171,43 +248,6 @@ def random_forest(train_features: np.array, train_labels: np.array, prob = clf.predict_proba(test_features) # get probabilities return predictions, prob, clf - -####################################################################### -###### we need to find a non-bugged version of xgboost ############## - -#def gradient_boosted_trees(train_features: np.array, -# train_labels: np.array, -# test_features: np.array, **kwargs): - """Gradient Boosted Trees classifier. - - Parameters - ---------- - train_features : np.array - Training sample features. - train_labels: np.array - Training sample classes. - test_features: np.array - Test sample features. - kwargs: extra parameters - All parameters allowed by sklearn.XGBClassifier - - Returns - ------- - predictions: np.array - Predicted classes. - prob: np.array - Classification probability for all objects, [pIa, pnon-Ia]. - """ - - #create classifier instance -# clf = XGBClassifier(**kwargs) - -# clf.fit(train_features, train_labels) # train -# predictions = clf.predict(test_features) # predict -# prob = clf.predict_proba(test_features) # get probabilities - -# return predictions, prob, clf -######################################################################### def knn(train_features: np.array, train_labels: np.array, test_features: np.array, **kwargs): diff --git a/src/resspect/database.py b/src/resspect/database.py index c415db88..25510136 100644 --- a/src/resspect/database.py +++ b/src/resspect/database.py @@ -27,7 +27,7 @@ from resspect.query_strategies import * from resspect.query_budget_strategies import * from resspect.metrics import get_snpcc_metric - +from resspect.classifier_registry import fetch_classifier_class __all__ = ['DataBase'] @@ -946,37 +946,50 @@ def classify(self, method: str, save_predictions=False, pred_dir=None, print(' ... train_labels: ', self.train_labels.shape) print(' ... pool_features: ', self.pool_features.shape) - if method == 'RandomForest': - self.predicted_class, self.classprob, self.classifier = \ - random_forest(self.train_features, self.train_labels, - self.pool_features, **kwargs) - elif method == 'GradientBoostedTrees': - raise ValueError("GradientBoostedTrees is currently unimplemented.") - # TODO: Restore once GradientBoostedTrees is fixed. - # self.predicted_class, self.classprob, self.classifier = \ - # gradient_boosted_trees(self.train_features, self.train_labels, - # self.pool_features, **kwargs) - elif method == 'KNN': - self.predicted_class, self.classprob, self.classifier = \ - knn(self.train_features, self.train_labels, - self.pool_features, **kwargs) - elif method == 'MLP': - self.predicted_class, self.classprob, self.classifier = \ - mlp(self.train_features, self.train_labels, - self.pool_features, **kwargs) - elif method == 'SVM': - self.predicted_class, self.classprob, self.classifier = \ - svm(self.train_features, self.train_labels, - self.pool_features, **kwargs) - elif method == 'NB': - self.predicted_class, self.classprob, self.classifier = \ - nbg(self.train_features, self.train_labels, - self.pool_features, **kwargs) - else: - raise ValueError( - "The only classifiers implemented are 'RandomForest', 'KNN', 'MLP', " - "'SVM' and 'NB'.\nFeel free to add other options." - ) + clf_class = fetch_classifier_class(method) + if clf_class is None: + raise ValueError(f'Classifier, {method} not recognized!') + + clf_instance = clf_class( + self.train_features, + self.train_labels, + self.pool_features, + **kwargs + ) + + self.predicted_class, self.classprob, self.classifier = clf_instance() + + # if method == 'RandomForest': + # self.predicted_class, self.classprob, self.classifier = \ + # random_forest(self.train_features, self.train_labels, + # self.pool_features, **kwargs) + # elif method == 'GradientBoostedTrees': + # raise ValueError("GradientBoostedTrees is currently unimplemented.") + # # TODO: Restore once GradientBoostedTrees is fixed. + # # self.predicted_class, self.classprob, self.classifier = \ + # # gradient_boosted_trees(self.train_features, self.train_labels, + # # self.pool_features, **kwargs) + # elif method == 'KNN': + # self.predicted_class, self.classprob, self.classifier = \ + # knn(self.train_features, self.train_labels, + # self.pool_features, **kwargs) + # elif method == 'MLP': + # self.predicted_class, self.classprob, self.classifier = \ + # mlp(self.train_features, self.train_labels, + # self.pool_features, **kwargs) + # elif method == 'SVM': + # self.predicted_class, self.classprob, self.classifier = \ + # svm(self.train_features, self.train_labels, + # self.pool_features, **kwargs) + # elif method == 'NB': + # self.predicted_class, self.classprob, self.classifier = \ + # nbg(self.train_features, self.train_labels, + # self.pool_features, **kwargs) + # else: + # raise ValueError( + # "The only classifiers implemented are 'RandomForest', 'KNN', 'MLP', " + # "'SVM' and 'NB'.\nFeel free to add other options." + # ) # estimate classification for validation sample self.validation_class = \ @@ -1028,39 +1041,58 @@ def classify_bootstrap(self, method: str, save_predictions=False, pred_dir=None, print(' ... train_labels: ', self.train_labels.shape) print(' ... pool_features: ', self.pool_features.shape) + clf_class = fetch_classifier_class(method) + if clf_class is None: + raise ValueError(f'Classifier, {method} not recognized!') + + clf_instance = clf_class( + self.train_features, + self.train_labels, + self.pool_features, + **kwargs + ) + + clf_instance.n_ensembles = n_ensembles + + self.predicted_class, self.classprob, self.ensemble_probs, self.classifier = clf_instance.bootstrap() + if method == 'RandomForest': + rf = RandomForest(self.train_features, self.train_labels, self.pool_features, **kwargs) + rf.n_ensembles = n_ensembles + self.predicted_class, self.classprob, self.ensemble_probs, self.classifier = rf.bootstrap() + self.predicted_class, self.classprob, self.ensemble_probs, self.classifier = \ bootstrap_clf(random_forest, n_ensembles, self.train_features, self.train_labels, self.pool_features, **kwargs) - elif method == 'GradientBoostedTrees': - self.predicted_class, self.classprob, self.ensemble_probs, self.classifier = \ - bootstrap_clf(gradient_boosted_trees, n_ensembles, - self.train_features, self.train_labels, - self.pool_features, **kwargs) - elif method == 'KNN': - self.predicted_class, self.classprob, self.ensemble_probs, self.classifier = \ - bootstrap_clf(knn, n_ensembles, - self.train_features, self.train_labels, - self.pool_features, **kwargs) - elif method == 'MLP': - self.predicted_class, self.classprob, self.ensemble_probs, self.classifier = \ - bootstrap_clf(mlp, n_ensembles, - self.train_features, self.train_labels, - self.pool_features, **kwargs) - elif method == 'SVM': - self.predicted_class, self.classprob, self.ensemble_probs, self.classifier = \ - bootstrap_clf(svm, n_ensembles, - self.train_features, self.train_labels, - self.pool_features, **kwargs) - elif method == 'NB': - self.predicted_class, self.classprob, self.ensemble_probs, self.classifier = \ - bootstrap_clf(nbg, n_ensembles, - self.train_features, self.train_labels, - self.pool_features, **kwargs) - else: - raise ValueError('Classifier not recognized!') + # elif method == 'GradientBoostedTrees': + # self.predicted_class, self.classprob, self.ensemble_probs, self.classifier = \ + # bootstrap_clf(gradient_boosted_trees, n_ensembles, + # self.train_features, self.train_labels, + # self.pool_features, **kwargs) + # elif method == 'KNN': + # self.predicted_class, self.classprob, self.ensemble_probs, self.classifier = \ + # bootstrap_clf(knn, n_ensembles, + # self.train_features, self.train_labels, + # self.pool_features, **kwargs) + # elif method == 'MLP': + # self.predicted_class, self.classprob, self.ensemble_probs, self.classifier = \ + # bootstrap_clf(mlp, n_ensembles, + # self.train_features, self.train_labels, + # self.pool_features, **kwargs) + # elif method == 'SVM': + # self.predicted_class, self.classprob, self.ensemble_probs, self.classifier = \ + # bootstrap_clf(svm, n_ensembles, + # self.train_features, self.train_labels, + # self.pool_features, **kwargs) + # elif method == 'NB': + # self.predicted_class, self.classprob, self.ensemble_probs, self.classifier = \ + # bootstrap_clf(nbg, n_ensembles, + # self.train_features, self.train_labels, + # self.pool_features, **kwargs) + # else: + # raise ValueError('Classifier not recognized!') self.validation_class = \ self.classifier.predict(self.validation_features) diff --git a/src/resspect/plugin_utils.py b/src/resspect/plugin_utils.py new file mode 100644 index 00000000..b2925e3a --- /dev/null +++ b/src/resspect/plugin_utils.py @@ -0,0 +1,84 @@ +import importlib + + +def get_or_load_class(class_name: str, registry: dict) -> type: + """Given the name of a class and a registry dictionary, attempt to return + the requested class either from the registry or by dynamically importing it. + + Parameters + ---------- + class_name : str + The name of the class to be returned. + registry : dict + The registry dictionary of : pairs. + + Returns + ------- + type + The returned class to be instantiated + + Raises + ------ + ValueError + User failed to specify a class to load in the runtime configuration. No + `name` key was found in the config. + """ + + if class_name in registry: + returned_class = registry[class_name] + else: + returned_class = import_module_from_string(class_name) + + return returned_class + + +def import_module_from_string(module_path: str) -> type: + """Dynamically import a module from a string. + + Parameters + ---------- + module_path : str + The import specification for the class. Should be of the form: + "module.submodule.class_name" + + Returns + ------- + returned_cls : type + The class to be instantiated. + + Raises + ------ + AttributeError + If the class is not found in the module that is loaded. + ModuleNotFoundError + If the module is not found using the provided import specification. + """ + + module_name, class_name = module_path.rsplit(".", 1) + returned_cls = None + + try: + # Attempt to find the module spec, i.e. `module.submodule.`. + # Will raise exception if `submodule`, 'subsubmodule', etc. is not found. + importlib.util.find_spec(module_name) + + # `importlib.util.find_spec()` will return None if `module_name` is not found. + if (importlib.util.find_spec(module_name)) is not None: + # Import the requested module + imported_module = importlib.import_module(module_name) + + # Check if the requested class is in the imported module + if hasattr(imported_module, class_name): + returned_cls = getattr(imported_module, class_name) + else: + raise AttributeError(f"The class {class_name} was not found in module {module_name}") + + # Raise an exception if the base module of the spec is not found + else: + raise ModuleNotFoundError(f"Module {module_name} not found") + + # Exception raised when a submodule of the spec is not found + except ModuleNotFoundError as exc: + raise ModuleNotFoundError(f"Module {module_name} not found") from exc + + return returned_cls From 2af5c7b78d528e52f63debe090e92cadf3e60316 Mon Sep 17 00:00:00 2001 From: drewoldag <47493171+drewoldag@users.noreply.github.com> Date: Wed, 16 Oct 2024 10:03:15 -0700 Subject: [PATCH 3/6] Switch to using __init_subclass__ instead of the explicit registration function call in the top level __init__.py. --- src/resspect/__init__.py | 2 -- src/resspect/classifier_registry.py | 12 ++---------- src/resspect/classifiers.py | 10 ++++++++++ 3 files changed, 12 insertions(+), 12 deletions(-) diff --git a/src/resspect/__init__.py b/src/resspect/__init__.py index e18dd5de..cd60d27d 100644 --- a/src/resspect/__init__.py +++ b/src/resspect/__init__.py @@ -107,5 +107,3 @@ 'time_domain_loop', 'uncertainty_sampling', 'update_matrix'] - -classifier_registry.register_builtin_classifiers() diff --git a/src/resspect/classifier_registry.py b/src/resspect/classifier_registry.py index f1af2aab..c22b3aab 100644 --- a/src/resspect/classifier_registry.py +++ b/src/resspect/classifier_registry.py @@ -1,15 +1,7 @@ from resspect.plugin_utils import get_or_load_class -from resspect.classifiers import ResspectClassifer +from resspect.classifiers import CLASSIFIER_REGISTRY -__all__ = ["CLASSIFIER_REGISTRY", "register_builtin_classifiers", "fetch_classifier_class"] - -CLASSIFIER_REGISTRY = {} - -def register_builtin_classifiers(): - """Add all built-in classifiers to the registry.""" - subclasses = ResspectClassifer.__subclasses__() - for subclass in subclasses: - CLASSIFIER_REGISTRY[subclass.__name__] = subclass +__all__ = ["fetch_classifier_class"] def fetch_classifier_class(classifier_name: str) -> type: diff --git a/src/resspect/classifiers.py b/src/resspect/classifiers.py index e9e41d6b..d51bf9b3 100644 --- a/src/resspect/classifiers.py +++ b/src/resspect/classifiers.py @@ -33,8 +33,11 @@ 'bootstrap_clf', 'ResspectClassifer', 'RandomForest', + 'CLASSIFIER_REGISTRY', ] +CLASSIFIER_REGISTRY = {} + class ResspectClassifer(): """Base class that all built-in RESSPECT classifiers will inherit from.""" @@ -63,6 +66,13 @@ def __init__(self, train_features, train_labels, test_features, **kwargs): self.classifier = None + def __init_subclass__(cls): + """Register all subclasses of ResspectClassifer in the CLASSIFIER_REGISTRY.""" + if cls.__name__ in CLASSIFIER_REGISTRY: + raise ValueError(f"Duplicate classifier name: {cls.__name__}") + + CLASSIFIER_REGISTRY[cls.__name__] = cls + @property def n_ensembles(self): return self._n_ensembles From 04fb1ab8b22df9a8ef40da270c8d7a3ca3084ce5 Mon Sep 17 00:00:00 2001 From: drewoldag <47493171+drewoldag@users.noreply.github.com> Date: Wed, 16 Oct 2024 10:37:48 -0700 Subject: [PATCH 4/6] Adding remaining built-in classifier classes. --- src/resspect/classifiers.py | 39 ++++++++++++++++++++++++++++++++++++- 1 file changed, 38 insertions(+), 1 deletion(-) diff --git a/src/resspect/classifiers.py b/src/resspect/classifiers.py index d51bf9b3..84603517 100644 --- a/src/resspect/classifiers.py +++ b/src/resspect/classifiers.py @@ -170,6 +170,43 @@ def __init__(self, train_features, train_labels, test_features, **kwargs): self.classifier = RandomForestClassifier(n_estimators=self.n_estimators, **self.kwargs) +class KNN(ResspectClassifer): + """RESSPECT-specific version of the sklearn KNeighborsClassifier.""" + + def __init__(self, train_features, train_labels, test_features, **kwargs): + super().__init__(train_features, train_labels, test_features, **kwargs) + + self.classifier = KNeighborsClassifier(**self.kwargs) + + +class MLP(ResspectClassifer): + """RESSPECT-specific version of the sklearn MLPClassifier.""" + + def __init__(self, train_features, train_labels, test_features, **kwargs): + super().__init__(train_features, train_labels, test_features, **kwargs) + + self.classifier = MLPClassifier(**self.kwargs) + + +class SVM(ResspectClassifer): + """RESSPECT-specific version of the sklearn SVC.""" + + def __init__(self, train_features, train_labels, test_features, **kwargs): + super().__init__(train_features, train_labels, test_features, **kwargs) + + self.probability = kwargs.get('probability', True) + self.classifier = SVC(probability=self.probability, **self.kwargs) + + +class NBG(ResspectClassifer): + """RESSPECT-specific version of the sklearn GaussianNB.""" + + def __init__(self, train_features, train_labels, test_features, **kwargs): + super().__init__(train_features, train_labels, test_features, **kwargs) + + self.classifier = GaussianNB(**self.kwargs) + + def bootstrap_clf(clf_function, n_ensembles, train_features, train_labels, test_features, **kwargs): """ @@ -358,7 +395,7 @@ def svm(train_features: np.array, train_labels: np.array, prob = clf.predict_proba(test_features) # get probabilities return predictions, prob, clf - + def nbg(train_features: np.array, train_labels: np.array, test_features: np.array, **kwargs): From b596e292b146b0bd837dbdc706c07fb3f89b343a Mon Sep 17 00:00:00 2001 From: drewoldag <47493171+drewoldag@users.noreply.github.com> Date: Wed, 16 Oct 2024 10:39:42 -0700 Subject: [PATCH 5/6] Remove deprecated if/else ladders in `database.py:classify` and `database.py:classify_bootstrap`. --- src/resspect/database.py | 70 ---------------------------------------- 1 file changed, 70 deletions(-) diff --git a/src/resspect/database.py b/src/resspect/database.py index 25510136..0859e42b 100644 --- a/src/resspect/database.py +++ b/src/resspect/database.py @@ -959,38 +959,6 @@ def classify(self, method: str, save_predictions=False, pred_dir=None, self.predicted_class, self.classprob, self.classifier = clf_instance() - # if method == 'RandomForest': - # self.predicted_class, self.classprob, self.classifier = \ - # random_forest(self.train_features, self.train_labels, - # self.pool_features, **kwargs) - # elif method == 'GradientBoostedTrees': - # raise ValueError("GradientBoostedTrees is currently unimplemented.") - # # TODO: Restore once GradientBoostedTrees is fixed. - # # self.predicted_class, self.classprob, self.classifier = \ - # # gradient_boosted_trees(self.train_features, self.train_labels, - # # self.pool_features, **kwargs) - # elif method == 'KNN': - # self.predicted_class, self.classprob, self.classifier = \ - # knn(self.train_features, self.train_labels, - # self.pool_features, **kwargs) - # elif method == 'MLP': - # self.predicted_class, self.classprob, self.classifier = \ - # mlp(self.train_features, self.train_labels, - # self.pool_features, **kwargs) - # elif method == 'SVM': - # self.predicted_class, self.classprob, self.classifier = \ - # svm(self.train_features, self.train_labels, - # self.pool_features, **kwargs) - # elif method == 'NB': - # self.predicted_class, self.classprob, self.classifier = \ - # nbg(self.train_features, self.train_labels, - # self.pool_features, **kwargs) - # else: - # raise ValueError( - # "The only classifiers implemented are 'RandomForest', 'KNN', 'MLP', " - # "'SVM' and 'NB'.\nFeel free to add other options." - # ) - # estimate classification for validation sample self.validation_class = \ self.classifier.predict(self.validation_features) @@ -1056,44 +1024,6 @@ def classify_bootstrap(self, method: str, save_predictions=False, pred_dir=None, self.predicted_class, self.classprob, self.ensemble_probs, self.classifier = clf_instance.bootstrap() - if method == 'RandomForest': - rf = RandomForest(self.train_features, self.train_labels, self.pool_features, **kwargs) - rf.n_ensembles = n_ensembles - self.predicted_class, self.classprob, self.ensemble_probs, self.classifier = rf.bootstrap() - - self.predicted_class, self.classprob, self.ensemble_probs, self.classifier = \ - bootstrap_clf(random_forest, n_ensembles, - self.train_features, self.train_labels, - self.pool_features, **kwargs) - - # elif method == 'GradientBoostedTrees': - # self.predicted_class, self.classprob, self.ensemble_probs, self.classifier = \ - # bootstrap_clf(gradient_boosted_trees, n_ensembles, - # self.train_features, self.train_labels, - # self.pool_features, **kwargs) - # elif method == 'KNN': - # self.predicted_class, self.classprob, self.ensemble_probs, self.classifier = \ - # bootstrap_clf(knn, n_ensembles, - # self.train_features, self.train_labels, - # self.pool_features, **kwargs) - # elif method == 'MLP': - # self.predicted_class, self.classprob, self.ensemble_probs, self.classifier = \ - # bootstrap_clf(mlp, n_ensembles, - # self.train_features, self.train_labels, - # self.pool_features, **kwargs) - # elif method == 'SVM': - # self.predicted_class, self.classprob, self.ensemble_probs, self.classifier = \ - # bootstrap_clf(svm, n_ensembles, - # self.train_features, self.train_labels, - # self.pool_features, **kwargs) - # elif method == 'NB': - # self.predicted_class, self.classprob, self.ensemble_probs, self.classifier = \ - # bootstrap_clf(nbg, n_ensembles, - # self.train_features, self.train_labels, - # self.pool_features, **kwargs) - # else: - # raise ValueError('Classifier not recognized!') - self.validation_class = \ self.classifier.predict(self.validation_features) self.validation_prob = \ From ace29f60a9b82388634cee4ccce4c569b4efe787 Mon Sep 17 00:00:00 2001 From: drewoldag <47493171+drewoldag@users.noreply.github.com> Date: Wed, 16 Oct 2024 12:57:16 -0700 Subject: [PATCH 6/6] Consolidated the remaining classifier_registry.py function into plugin_utils. Added tests. --- src/resspect/__init__.py | 2 +- src/resspect/classifier_registry.py | 36 --------------- src/resspect/database.py | 2 +- src/resspect/plugin_utils.py | 34 ++++++++++++++- tests/resspect/test_plugin_utils.py | 68 +++++++++++++++++++++++++++++ 5 files changed, 103 insertions(+), 39 deletions(-) delete mode 100644 src/resspect/classifier_registry.py create mode 100644 tests/resspect/test_plugin_utils.py diff --git a/src/resspect/__init__.py b/src/resspect/__init__.py index cd60d27d..17374e78 100644 --- a/src/resspect/__init__.py +++ b/src/resspect/__init__.py @@ -45,7 +45,7 @@ from .query_budget_strategies import * from .bump import * from .feature_extractors.malanchev import * -from .classifier_registry import * +from .plugin_utils import * import importlib.metadata diff --git a/src/resspect/classifier_registry.py b/src/resspect/classifier_registry.py deleted file mode 100644 index c22b3aab..00000000 --- a/src/resspect/classifier_registry.py +++ /dev/null @@ -1,36 +0,0 @@ -from resspect.plugin_utils import get_or_load_class -from resspect.classifiers import CLASSIFIER_REGISTRY - -__all__ = ["fetch_classifier_class"] - - -def fetch_classifier_class(classifier_name: str) -> type: - """Fetch the classifier class from the registry. - - Parameters - ---------- - classifier_name : str - The name of the classifier class to retrieve. This should either be the - name of the class or the import specification for the class. - - Returns - ------- - type - The classifier class. - - Raises - ------ - ValueError - If a built-in classifier was requested, but not found in the registry. - ValueError - If no classifier was specified in the runtime configuration. - """ - - clf_class = None - - try: - clf_class = get_or_load_class(classifier_name, CLASSIFIER_REGISTRY) - except ValueError as exc: - raise ValueError(f"Error fetching class: {classifier_name}") from exc - - return clf_class diff --git a/src/resspect/database.py b/src/resspect/database.py index 0859e42b..f28c8c9c 100644 --- a/src/resspect/database.py +++ b/src/resspect/database.py @@ -27,7 +27,7 @@ from resspect.query_strategies import * from resspect.query_budget_strategies import * from resspect.metrics import get_snpcc_metric -from resspect.classifier_registry import fetch_classifier_class +from resspect.plugin_utils import fetch_classifier_class __all__ = ['DataBase'] diff --git a/src/resspect/plugin_utils.py b/src/resspect/plugin_utils.py index b2925e3a..3c8690dd 100644 --- a/src/resspect/plugin_utils.py +++ b/src/resspect/plugin_utils.py @@ -1,5 +1,5 @@ import importlib - +from resspect.classifiers import CLASSIFIER_REGISTRY def get_or_load_class(class_name: str, registry: dict) -> type: """Given the name of a class and a registry dictionary, attempt to return @@ -82,3 +82,35 @@ def import_module_from_string(module_path: str) -> type: raise ModuleNotFoundError(f"Module {module_name} not found") from exc return returned_cls + + +def fetch_classifier_class(classifier_name: str) -> type: + """Fetch the classifier class from the registry. + + Parameters + ---------- + classifier_name : str + The name of the classifier class to retrieve. This should either be the + name of the class or the import specification for the class. + + Returns + ------- + type + The classifier class. + + Raises + ------ + ValueError + If a built-in classifier was requested, but not found in the registry. + ValueError + If no classifier was specified in the runtime configuration. + """ + + clf_class = None + + try: + clf_class = get_or_load_class(classifier_name, CLASSIFIER_REGISTRY) + except ValueError as exc: + raise ValueError(f"Error fetching class: {classifier_name}") from exc + + return clf_class diff --git a/tests/resspect/test_plugin_utils.py b/tests/resspect/test_plugin_utils.py new file mode 100644 index 00000000..afd52d4f --- /dev/null +++ b/tests/resspect/test_plugin_utils.py @@ -0,0 +1,68 @@ +import pytest +from resspect.plugin_utils import import_module_from_string, fetch_classifier_class + + +def test_import_module_from_string(): + """Test the import_module_from_string function.""" + module_path = "builtins.BaseException" + + returned_cls = import_module_from_string(module_path) + + assert returned_cls.__name__ == "BaseException" + + +def test_import_module_from_string_no_base_module(): + """Test that the import_module_from_string function raises an error when + the base module is not found.""" + + module_path = "nonexistent.BaseException" + + with pytest.raises(ModuleNotFoundError) as excinfo: + import_module_from_string(module_path) + + assert "Module nonexistent not found" in str(excinfo.value) + + +def test_import_module_from_string_no_submodule(): + """Test that the import_module_from_string function raises an error when + a submodule is not found.""" + + module_path = "builtins.nonexistent.BaseException" + + with pytest.raises(ModuleNotFoundError) as excinfo: + import_module_from_string(module_path) + + assert "Module builtins.nonexistent not found" in str(excinfo.value) + + +def test_import_module_from_string_no_class(): + """Test that the import_module_from_string function raises an error when + a class is not found.""" + + module_path = "builtins.Nonexistent" + + with pytest.raises(AttributeError) as excinfo: + import_module_from_string(module_path) + + assert "The class Nonexistent was not found" in str(excinfo.value) + + +def test_fetch_classifier_class(): + """Test the fetch_classifier_class function.""" + requested_class = "builtins.BaseException" + + returned_cls = fetch_classifier_class(requested_class) + + assert returned_cls.__name__ == "BaseException" + + +def test_fetch_classifier_class_not_in_registry(): + """Test that an exception is raised when a model is requested that is not in the registry.""" + + requested_class = "Nonexistent" + + with pytest.raises(ValueError) as excinfo: + fetch_classifier_class(requested_class) + + assert "Error fetching class: Nonexistent" in str(excinfo.value) +