Skip to content

Commit 477015a

Browse files
authored
Initial commit for pluggable query_strategies. (#59)
1 parent 71af541 commit 477015a

File tree

5 files changed

+427
-70
lines changed

5 files changed

+427
-70
lines changed

src/resspect/database.py

Lines changed: 21 additions & 54 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@
2727
from resspect.query_strategies import *
2828
from resspect.query_budget_strategies import *
2929
from resspect.metrics import get_snpcc_metric
30-
from resspect.plugin_utils import fetch_classifier_class
30+
from resspect.plugin_utils import fetch_classifier_class, fetch_query_strategy_class
3131

3232
__all__ = ['DataBase']
3333

@@ -1265,7 +1265,7 @@ def make_query_budget(self, budgets, strategy='UncSampling', screen=False) -> li
12651265
return query_indx
12661266

12671267
def make_query(self, strategy='UncSampling', batch=1,
1268-
screen=False, queryable=False, query_thre=1.0) -> list:
1268+
screen=False, queryable=False, query_threshold=1.0) -> list:
12691269
"""Identify new object to be added to the training sample.
12701270
12711271
Parameters
@@ -1282,7 +1282,7 @@ def make_query(self, strategy='UncSampling', batch=1,
12821282
queryable: bool (optional)
12831283
If True, consider only queryable objects.
12841284
Default is False.
1285-
query_thre: float (optional)
1285+
query_threshold: float (optional)
12861286
Percentile threshold where a query is considered worth it.
12871287
Default is 1 (no limit).
12881288
screen: bool (optional)
@@ -1305,57 +1305,24 @@ def make_query(self, strategy='UncSampling', batch=1,
13051305

13061306
id_name = self.identify_keywords()
13071307

1308-
if strategy == 'UncSampling':
1309-
query_indx = uncertainty_sampling(class_prob=self.classprob,
1310-
queryable_ids=self.queryable_ids,
1311-
test_ids=self.pool_metadata[id_name].values,
1312-
batch=batch, screen=screen,
1313-
query_thre=query_thre)
1314-
1315-
1316-
elif strategy == 'UncSamplingEntropy':
1317-
query_indx = uncertainty_sampling_entropy(class_prob=self.classprob,
1318-
queryable_ids=self.queryable_ids,
1319-
test_ids=self.pool_metadata[id_name].values,
1320-
batch=batch, screen=screen,
1321-
query_thre=query_thre)
1322-
1323-
elif strategy == 'UncSamplingLeastConfident':
1324-
query_indx = uncertainty_sampling_least_confident(class_prob=self.classprob,
1325-
queryable_ids=self.queryable_ids,
1326-
test_ids=self.pool_metadata[id_name].values,
1327-
batch=batch, screen=screen,
1328-
query_thre=query_thre)
1329-
1330-
elif strategy == 'UncSamplingMargin':
1331-
query_indx = uncertainty_sampling_margin(class_prob=self.classprob,
1332-
queryable_ids=self.queryable_ids,
1333-
test_ids=self.pool_metadata[id_name].values,
1334-
batch=batch, screen=screen,
1335-
query_thre=query_thre)
1336-
return query_indx
1337-
elif strategy == 'QBDMI':
1338-
query_indx = qbd_mi(ensemble_probs=self.ensemble_probs,
1339-
queryable_ids=self.queryable_ids,
1340-
test_ids=self.pool_metadata[id_name].values,
1341-
batch=batch, screen=screen,
1342-
query_thre=query_thre)
1343-
1344-
elif strategy =='QBDEntropy':
1345-
query_indx = qbd_entropy(ensemble_probs=self.ensemble_probs,
1346-
queryable_ids=self.queryable_ids,
1347-
test_ids=self.pool_metadata[id_name].values,
1348-
batch=batch, screen=screen,
1349-
query_thre=query_thre)
1350-
1351-
elif strategy == 'RandomSampling':
1352-
query_indx = random_sampling(queryable_ids=self.queryable_ids,
1353-
test_ids=self.pool_metadata[id_name].values,
1354-
queryable=queryable, batch=batch,
1355-
query_thre=query_thre, screen=screen)
1356-
1357-
else:
1358-
raise ValueError('Invalid strategy.')
1308+
# retrieve and instantiate the query strategy class
1309+
query_strategy_class = fetch_query_strategy_class(strategy)
1310+
query_strategy = query_strategy_class(
1311+
queryable_ids=self.queryable_ids,
1312+
test_ids=self.pool_metadata[id_name].values,
1313+
batch=batch,
1314+
screen=screen,
1315+
query_threshold=query_threshold,
1316+
queryable=queryable
1317+
)
1318+
1319+
# Use the `requires_ensemble` flag to determine which probabilities to use
1320+
input_probabilities = self.classprob
1321+
if query_strategy.requires_ensemble:
1322+
input_probabilities = self.ensemble_probs
1323+
1324+
# get the query index from the strategy
1325+
query_indx = query_strategy.sample(input_probabilities)
13591326

13601327
if screen:
13611328
print(' ... queried obj id: ', self.pool_metadata[id_name].values[query_indx[0]])

src/resspect/plugin_utils.py

Lines changed: 38 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import importlib
2-
from resspect.classifiers import CLASSIFIER_REGISTRY
2+
from resspect.classifiers import CLASSIFIER_REGISTRY, ResspectClassifier
3+
from resspect.query_strategies import QUERY_STRATEGY_REGISTRY, QueryStrategy
34

45
def get_or_load_class(class_name: str, registry: dict) -> type:
56
"""Given the name of a class and a registry dictionary, attempt to return
@@ -24,10 +25,15 @@ def get_or_load_class(class_name: str, registry: dict) -> type:
2425
`name` key was found in the config.
2526
"""
2627

27-
if class_name in registry:
28-
returned_class = registry[class_name]
29-
else:
30-
returned_class = import_module_from_string(class_name)
28+
returned_class = None
29+
30+
try:
31+
if class_name in registry:
32+
returned_class = registry[class_name]
33+
else:
34+
returned_class = import_module_from_string(class_name)
35+
except ValueError as exc:
36+
raise ValueError(f"Error fetching class: {class_name}") from exc
3137

3238
return returned_class
3339

@@ -84,7 +90,7 @@ def import_module_from_string(module_path: str) -> type:
8490
return returned_cls
8591

8692

87-
def fetch_classifier_class(classifier_name: str) -> type:
93+
def fetch_classifier_class(classifier_name: str) -> ResspectClassifier:
8894
"""Fetch the classifier class from the registry.
8995
9096
Parameters
@@ -95,8 +101,8 @@ def fetch_classifier_class(classifier_name: str) -> type:
95101
96102
Returns
97103
-------
98-
type
99-
The classifier class.
104+
ResspectClassifier
105+
The subclass of ResspectClassifier.
100106
101107
Raises
102108
------
@@ -106,11 +112,29 @@ def fetch_classifier_class(classifier_name: str) -> type:
106112
If no classifier was specified in the runtime configuration.
107113
"""
108114

109-
clf_class = None
115+
return get_or_load_class(classifier_name, CLASSIFIER_REGISTRY)
110116

111-
try:
112-
clf_class = get_or_load_class(classifier_name, CLASSIFIER_REGISTRY)
113-
except ValueError as exc:
114-
raise ValueError(f"Error fetching class: {classifier_name}") from exc
115117

116-
return clf_class
118+
def fetch_query_strategy_class(query_strategy_name: str) -> "type[QueryStrategy]":
    """Fetch the query strategy class from the registry.

    The lookup is delegated to ``get_or_load_class`` using
    ``QUERY_STRATEGY_REGISTRY``. The class object itself is returned, not an
    instance; the caller is responsible for instantiating it.

    Parameters
    ----------
    query_strategy_name : str
        The name of the query strategy class to retrieve. This should either be the
        name of the class or the import specification for the class.

    Returns
    -------
    type[QueryStrategy]
        The query strategy class (expected to be a subclass of QueryStrategy;
        this function does not itself verify the subclass relationship).

    Raises
    ------
    ValueError
        Re-raised by ``get_or_load_class`` (as "Error fetching class: <name>")
        when resolving the name through the registry or the import
        specification fails with a ValueError.
    """
    # The return annotation is a string so that it is not evaluated when the
    # function is defined; it documents that a class, not an instance, is returned.
    return get_or_load_class(query_strategy_name, QUERY_STRATEGY_REGISTRY)

0 commit comments

Comments
 (0)