Skip to content

Commit

Permalink
Merge pull request #44 from LSSTDESC/more_query_strategy_tests
Browse files Browse the repository at this point in the history
Add tests for the ensemble-based query strategies
  • Loading branch information
jeremykubica authored Oct 16, 2024
2 parents a5c71dd + 163b79d commit 89c30c6
Show file tree
Hide file tree
Showing 2 changed files with 51 additions and 7 deletions.
16 changes: 9 additions & 7 deletions src/resspect/query_strategies.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,15 +47,15 @@ def compute_entropy(ps: np.array):

def compute_qbd_mi_entropy(ensemble_probs: np.array):
"""
Calcualte the entropy of the average distribution from an ensemble of
Calculate the entropy of the average distribution from an ensemble of
distributions. Calculate the mutual information between the members in the
ensemble and the average distribution.
Parameters
----------
ensemble_probs: np.array
Probability from ensembles where the first dimension is number of unique
point, the second dimension is the number of ensemble members and the
points, the second dimension is the number of ensemble members and the
third dimension is the number of events.
Returns
Expand Down Expand Up @@ -452,9 +452,9 @@ def qbd_mi(ensemble_probs: np.array, test_ids: np.array,
'from number of objects in the test sample!')

# compute the entropy of the ensemble-averaged prediction and the mutual information for each object
entropies, mis = compute_qbd_mi_entropy(ensemble_probs)
_, mis = compute_qbd_mi_entropy(ensemble_probs)

# get indexes in increasing order
# get indexes in decreasing order
order = mis.argsort()[::-1]

# only allow objects in the query sample to be chosen
Expand Down Expand Up @@ -483,7 +483,9 @@ def qbd_mi(ensemble_probs: np.array, test_ids: np.array,
def qbd_entropy(ensemble_probs: np.array, test_ids: np.array,
queryable_ids: np.array, batch=1,
screen=False, query_thre=1.0) -> list:
"""Search for the sample with highest uncertainty in predicted class.
"""Search for the sample with highest entropy from the average predictions
of the ensembled classifiers. These can be instances where the classifiers
agree (but are uncertain about the class) or disagree.
Parameters
----------
Expand Down Expand Up @@ -519,9 +521,9 @@ def qbd_entropy(ensemble_probs: np.array, test_ids: np.array,
'from number of objects in the test sample!')

# compute the entropy of the ensemble-averaged prediction and the mutual information for each object
entropies, mis = compute_qbd_mi_entropy(ensemble_probs)
entropies, _ = compute_qbd_mi_entropy(ensemble_probs)

# get indexes in increasing order
# get indexes in decreasing order
order = entropies.argsort()[::-1]

# only allow objects in the query sample to be chosen
Expand Down
42 changes: 42 additions & 0 deletions tests/resspect/test_query_strategies.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@
import pytest

from resspect.query_strategies import (
qbd_entropy,
qbd_mi,
random_sampling,
uncertainty_sampling,
uncertainty_sampling_entropy,
Expand Down Expand Up @@ -148,5 +150,45 @@ def test_uncertainty_sampling_margin():
assert np.array_equal(sample, [3, 1, 2])


def test_qbd_entropy():
    """Check that qbd_entropy ranks objects by the entropy of the averaged ensemble prediction."""
    n_objects = 5
    all_ids = np.arange(n_objects)
    candidate_ids = np.arange(n_objects)

    # The ensemble output is a 3-d array indexed as
    # (object, ensemble member, event): here 5 x 3 x 2.
    probs = np.array(
        [
            # very low entropy (high agreement)
            [[1.00, 0.00], [0.95, 0.05], [0.99, 0.01]],
            # high entropy (low agreement)
            [[0.80, 0.20], [0.60, 0.40], [0.20, 0.80]],
            # low entropy (high agreement)
            [[0.10, 0.90], [0.10, 0.90], [0.05, 0.95]],
            # medium entropy (high agreement)
            [[0.75, 0.25], [0.80, 0.20], [0.78, 0.22]],
            # high entropy (high agreement)
            [[0.50, 0.50], [0.50, 0.50], [0.49, 0.51]],
        ]
    )

    # Request every object so the full ranking is checked.
    expected_order = [4, 1, 3, 2, 0]
    chosen = qbd_entropy(probs, all_ids, candidate_ids, batch=n_objects)
    assert np.array_equal(chosen, expected_order)


def test_qbd_mi():
    """Check the ordering returned by the ensemble qbd_mi sampling."""
    n_objects = 5
    all_ids = np.arange(n_objects)
    candidate_ids = np.arange(n_objects)

    # The ensemble output is a 3-d array indexed as
    # (object, ensemble member, event): here 5 x 3 x 2.
    per_object_probs = [
        [[1.00, 0.00], [0.95, 0.05], [0.80, 0.20]],
        [[0.80, 0.20], [0.60, 0.40], [0.20, 0.80]],
        [[0.10, 0.90], [0.10, 0.90], [0.05, 0.95]],
        [[0.75, 0.25], [0.80, 0.20], [0.78, 0.22]],
        [[0.50, 0.50], [0.50, 0.50], [0.49, 0.51]],
    ]
    probs = np.array(per_object_probs)

    # Request every object so the full ranking is checked.
    expected_order = [1, 0, 2, 3, 4]
    chosen = qbd_mi(probs, all_ids, candidate_ids, batch=n_objects)
    assert np.array_equal(chosen, expected_order)


# Run pytest when this test module is executed directly as a script.
if __name__ == '__main__':
    pytest.main()

0 comments on commit 89c30c6

Please sign in to comment.