From 163b79d9c6f44c17b9e7c0a9772c88fcd82ce229 Mon Sep 17 00:00:00 2001 From: Jeremy Kubica <104161096+jeremykubica@users.noreply.github.com> Date: Thu, 10 Oct 2024 16:10:32 -0400 Subject: [PATCH] Add tests for the ensemble-based query strategies --- src/resspect/query_strategies.py | 16 +++++----- tests/resspect/test_query_strategies.py | 42 +++++++++++++++++++++++++ 2 files changed, 51 insertions(+), 7 deletions(-) diff --git a/src/resspect/query_strategies.py b/src/resspect/query_strategies.py index 9356736d..b8460285 100644 --- a/src/resspect/query_strategies.py +++ b/src/resspect/query_strategies.py @@ -47,7 +47,7 @@ def compute_entropy(ps: np.array): def compute_qbd_mi_entropy(ensemble_probs: np.array): """ - Calcualte the entropy of the average distribution from an ensemble of + Calculate the entropy of the average distribution from an ensemble of distributions. Calculate the mutual information between the members in the ensemble and the average distribution. @@ -55,7 +55,7 @@ def compute_qbd_mi_entropy(ensemble_probs: np.array): ---------- ensemble_probs: np.array Probability from ensembles where the first dimension is number of unique - point, the second dimension is the number of ensemble members and the + points, the second dimension is the number of ensemble members and the third dimension is the number of events. Returns @@ -452,9 +452,9 @@ def qbd_mi(ensemble_probs: np.array, test_ids: np.array, 'from number of objects in the test sample!') # calculate distance to the decision boundary - only binary classification - entropies, mis = compute_qbd_mi_entropy(ensemble_probs) + _, mis = compute_qbd_mi_entropy(ensemble_probs) - # get indexes in increasing order + # get indexes in decreasing order order = mis.argsort()[::-1] # only allow objects in the query sample to be chosen @@ -483,7 +483,9 @@ def qbd_mi(ensemble_probs: np.array, test_ids: np.array, def qbd_entropy(ensemble_probs: np.array, test_ids: np.array, queryable_ids: np.array, batch=1, screen=False, query_thre=1.0) -> list: - """Search for the sample with highest uncertainty in predicted class. + """Search for the sample with highest entropy from the average predictions + of the ensembled classifiers. These can be instances where the classifiers + agree (but are uncertain about the class) or disagree. Parameters ---------- @@ -519,9 +521,9 @@ def qbd_entropy(ensemble_probs: np.array, test_ids: np.array, 'from number of objects in the test sample!') # calculate distance to the decision boundary - only binary classification - entropies, mis = compute_qbd_mi_entropy(ensemble_probs) + entropies, _ = compute_qbd_mi_entropy(ensemble_probs) - # get indexes in increasing order + # get indexes in decreasing order order = entropies.argsort()[::-1] # only allow objects in the query sample to be chosen diff --git a/tests/resspect/test_query_strategies.py b/tests/resspect/test_query_strategies.py index a9f0fb87..c7f82d6e 100644 --- a/tests/resspect/test_query_strategies.py +++ b/tests/resspect/test_query_strategies.py @@ -5,6 +5,8 @@ import pytest from resspect.query_strategies import ( + qbd_entropy, + qbd_mi, random_sampling, uncertainty_sampling, uncertainty_sampling_entropy, @@ -148,5 +150,45 @@ def test_uncertainty_sampling_margin(): assert np.array_equal(sample, [3, 1, 2]) +def test_qbd_entropy(): + """Test the ensemble average prediction entropy sampling.""" + test_ids = np.arange(0, 5) + queryable_ids = np.arange(0, 5) + + # Probabilities coming out of the ensembles are 3-d matrices with dimensions: + # number of points (5), the number of ensemble members (3), and the number of events (2). + ensemble_probs = np.array( + [ + [[1.00, 0.00], [0.95, 0.05], [0.99, 0.01]], # very low entropy (high agreement) + [[0.80, 0.20], [0.60, 0.40], [0.20, 0.80]], # high entropy (low agreement) + [[0.10, 0.90], [0.10, 0.90], [0.05, 0.95]], # low entropy (high agreement) + [[0.75, 0.25], [0.80, 0.20], [0.78, 0.22]], # medium entropy (high agreement) + [[0.50, 0.50], [0.50, 0.50], [0.49, 0.51]], # high entropy (high agreement) + ] + ) + sample = qbd_entropy(ensemble_probs, test_ids, queryable_ids, batch=5) + assert np.array_equal(sample, [4, 1, 3, 2, 0]) + + +def test_qbd_mi(): + """Test the ensemble qbd_mi sampling.""" + test_ids = np.arange(0, 5) + queryable_ids = np.arange(0, 5) + + # Probabilities coming out of the ensembles are 3-d matrices with dimensions: + # number of points (5), the number of ensemble members (3), and the number of events (2). + ensemble_probs = np.array( + [ + [[1.00, 0.00], [0.95, 0.05], [0.80, 0.20]], + [[0.80, 0.20], [0.60, 0.40], [0.20, 0.80]], + [[0.10, 0.90], [0.10, 0.90], [0.05, 0.95]], + [[0.75, 0.25], [0.80, 0.20], [0.78, 0.22]], + [[0.50, 0.50], [0.50, 0.50], [0.49, 0.51]], + ] + ) + sample = qbd_mi(ensemble_probs, test_ids, queryable_ids, batch=5) + assert np.array_equal(sample, [1, 0, 2, 3, 4]) + + if __name__ == '__main__': pytest.main()