Skip to content

Commit

Permalink
Initial commit for pluggable query_strategies. (#59)
Browse files Browse the repository at this point in the history
  • Loading branch information
drewoldag authored Oct 30, 2024
1 parent 71af541 commit 477015a
Show file tree
Hide file tree
Showing 5 changed files with 427 additions and 70 deletions.
75 changes: 21 additions & 54 deletions src/resspect/database.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
from resspect.query_strategies import *
from resspect.query_budget_strategies import *
from resspect.metrics import get_snpcc_metric
from resspect.plugin_utils import fetch_classifier_class
from resspect.plugin_utils import fetch_classifier_class, fetch_query_strategy_class

__all__ = ['DataBase']

Expand Down Expand Up @@ -1265,7 +1265,7 @@ def make_query_budget(self, budgets, strategy='UncSampling', screen=False) -> li
return query_indx

def make_query(self, strategy='UncSampling', batch=1,
screen=False, queryable=False, query_thre=1.0) -> list:
screen=False, queryable=False, query_threshold=1.0) -> list:
"""Identify new object to be added to the training sample.
Parameters
Expand All @@ -1282,7 +1282,7 @@ def make_query(self, strategy='UncSampling', batch=1,
queryable: bool (optional)
If True, consider only queryable objects.
Default is False.
query_thre: float (optional)
query_threshold: float (optional)
Percentile threshold where a query is considered worth it.
Default is 1 (no limit).
screen: bool (optional)
Expand All @@ -1305,57 +1305,24 @@ def make_query(self, strategy='UncSampling', batch=1,

id_name = self.identify_keywords()

if strategy == 'UncSampling':
query_indx = uncertainty_sampling(class_prob=self.classprob,
queryable_ids=self.queryable_ids,
test_ids=self.pool_metadata[id_name].values,
batch=batch, screen=screen,
query_thre=query_thre)


elif strategy == 'UncSamplingEntropy':
query_indx = uncertainty_sampling_entropy(class_prob=self.classprob,
queryable_ids=self.queryable_ids,
test_ids=self.pool_metadata[id_name].values,
batch=batch, screen=screen,
query_thre=query_thre)

elif strategy == 'UncSamplingLeastConfident':
query_indx = uncertainty_sampling_least_confident(class_prob=self.classprob,
queryable_ids=self.queryable_ids,
test_ids=self.pool_metadata[id_name].values,
batch=batch, screen=screen,
query_thre=query_thre)

elif strategy == 'UncSamplingMargin':
query_indx = uncertainty_sampling_margin(class_prob=self.classprob,
queryable_ids=self.queryable_ids,
test_ids=self.pool_metadata[id_name].values,
batch=batch, screen=screen,
query_thre=query_thre)
return query_indx
elif strategy == 'QBDMI':
query_indx = qbd_mi(ensemble_probs=self.ensemble_probs,
queryable_ids=self.queryable_ids,
test_ids=self.pool_metadata[id_name].values,
batch=batch, screen=screen,
query_thre=query_thre)

elif strategy =='QBDEntropy':
query_indx = qbd_entropy(ensemble_probs=self.ensemble_probs,
queryable_ids=self.queryable_ids,
test_ids=self.pool_metadata[id_name].values,
batch=batch, screen=screen,
query_thre=query_thre)

elif strategy == 'RandomSampling':
query_indx = random_sampling(queryable_ids=self.queryable_ids,
test_ids=self.pool_metadata[id_name].values,
queryable=queryable, batch=batch,
query_thre=query_thre, screen=screen)

else:
raise ValueError('Invalid strategy.')
# retrieve and instantiate the query strategy class
query_strategy_class = fetch_query_strategy_class(strategy)
query_strategy = query_strategy_class(
queryable_ids=self.queryable_ids,
test_ids=self.pool_metadata[id_name].values,
batch=batch,
screen=screen,
query_threshold=query_threshold,
queryable=queryable
)

# Use the `requires_ensemble` flag to determine which probabilities to use
input_probabilities = self.classprob
if query_strategy.requires_ensemble:
input_probabilities = self.ensemble_probs

# get the query index from the strategy
query_indx = query_strategy.sample(input_probabilities)

if screen:
print(' ... queried obj id: ', self.pool_metadata[id_name].values[query_indx[0]])
Expand Down
52 changes: 38 additions & 14 deletions src/resspect/plugin_utils.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import importlib
from resspect.classifiers import CLASSIFIER_REGISTRY
from resspect.classifiers import CLASSIFIER_REGISTRY, ResspectClassifier
from resspect.query_strategies import QUERY_STRATEGY_REGISTRY, QueryStrategy

def get_or_load_class(class_name: str, registry: dict) -> type:
"""Given the name of a class and a registry dictionary, attempt to return
Expand All @@ -24,10 +25,15 @@ def get_or_load_class(class_name: str, registry: dict) -> type:
`name` key was found in the config.
"""

if class_name in registry:
returned_class = registry[class_name]
else:
returned_class = import_module_from_string(class_name)
returned_class = None

try:
if class_name in registry:
returned_class = registry[class_name]
else:
returned_class = import_module_from_string(class_name)
except ValueError as exc:
raise ValueError(f"Error fetching class: {class_name}") from exc

return returned_class

Expand Down Expand Up @@ -84,7 +90,7 @@ def import_module_from_string(module_path: str) -> type:
return returned_cls


def fetch_classifier_class(classifier_name: str) -> type:
def fetch_classifier_class(classifier_name: str) -> ResspectClassifier:
"""Fetch the classifier class from the registry.
Parameters
Expand All @@ -95,8 +101,8 @@ def fetch_classifier_class(classifier_name: str) -> type:
Returns
-------
type
The classifier class.
ResspectClassifier
The subclass of ResspectClassifier.
Raises
------
Expand All @@ -106,11 +112,29 @@ def fetch_classifier_class(classifier_name: str) -> type:
If no classifier was specified in the runtime configuration.
"""

clf_class = None
return get_or_load_class(classifier_name, CLASSIFIER_REGISTRY)

try:
clf_class = get_or_load_class(classifier_name, CLASSIFIER_REGISTRY)
except ValueError as exc:
raise ValueError(f"Error fetching class: {classifier_name}") from exc

return clf_class
def fetch_query_strategy_class(query_strategy_name: str) -> "type[QueryStrategy]":
    """Fetch the query strategy class from the registry.

    Parameters
    ----------
    query_strategy_name : str
        The name of the query strategy class to retrieve. This should either
        be the name of the class or the import specification for the class.

    Returns
    -------
    type[QueryStrategy]
        The query strategy class itself, not an instance. The caller is
        expected to instantiate it with the arguments it needs.

    Raises
    ------
    ValueError
        Propagated from `get_or_load_class` when the requested class cannot
        be fetched, i.e. the lookup in `QUERY_STRATEGY_REGISTRY` or the
        fallback import by name fails with a `ValueError`.
    """
    # The return annotation is quoted so it is never evaluated at definition
    # time; it documents that a class object is returned, not an instance.
    return get_or_load_class(query_strategy_name, QUERY_STRATEGY_REGISTRY)
Loading

0 comments on commit 477015a

Please sign in to comment.