27
27
from resspect .query_strategies import *
28
28
from resspect .query_budget_strategies import *
29
29
from resspect .metrics import get_snpcc_metric
30
- from resspect .plugin_utils import fetch_classifier_class
30
+ from resspect .plugin_utils import fetch_classifier_class , fetch_query_strategy_class
31
31
32
32
__all__ = ['DataBase' ]
33
33
@@ -1265,7 +1265,7 @@ def make_query_budget(self, budgets, strategy='UncSampling', screen=False) -> li
1265
1265
return query_indx
1266
1266
1267
1267
def make_query (self , strategy = 'UncSampling' , batch = 1 ,
1268
- screen = False , queryable = False , query_thre = 1.0 ) -> list :
1268
+ screen = False , queryable = False , query_threshold = 1.0 ) -> list :
1269
1269
"""Identify new object to be added to the training sample.
1270
1270
1271
1271
Parameters
@@ -1282,7 +1282,7 @@ def make_query(self, strategy='UncSampling', batch=1,
1282
1282
queryable: bool (optional)
1283
1283
If True, consider only queryable objects.
1284
1284
Default is False.
1285
- query_thre : float (optional)
1285
+ query_threshold : float (optional)
1286
1286
Percentile threshold where a query is considered worth it.
1287
1287
Default is 1 (no limit).
1288
1288
screen: bool (optional)
@@ -1305,57 +1305,24 @@ def make_query(self, strategy='UncSampling', batch=1,
1305
1305
1306
1306
id_name = self .identify_keywords ()
1307
1307
1308
- if strategy == 'UncSampling' :
1309
- query_indx = uncertainty_sampling (class_prob = self .classprob ,
1310
- queryable_ids = self .queryable_ids ,
1311
- test_ids = self .pool_metadata [id_name ].values ,
1312
- batch = batch , screen = screen ,
1313
- query_thre = query_thre )
1314
-
1315
-
1316
- elif strategy == 'UncSamplingEntropy' :
1317
- query_indx = uncertainty_sampling_entropy (class_prob = self .classprob ,
1318
- queryable_ids = self .queryable_ids ,
1319
- test_ids = self .pool_metadata [id_name ].values ,
1320
- batch = batch , screen = screen ,
1321
- query_thre = query_thre )
1322
-
1323
- elif strategy == 'UncSamplingLeastConfident' :
1324
- query_indx = uncertainty_sampling_least_confident (class_prob = self .classprob ,
1325
- queryable_ids = self .queryable_ids ,
1326
- test_ids = self .pool_metadata [id_name ].values ,
1327
- batch = batch , screen = screen ,
1328
- query_thre = query_thre )
1329
-
1330
- elif strategy == 'UncSamplingMargin' :
1331
- query_indx = uncertainty_sampling_margin (class_prob = self .classprob ,
1332
- queryable_ids = self .queryable_ids ,
1333
- test_ids = self .pool_metadata [id_name ].values ,
1334
- batch = batch , screen = screen ,
1335
- query_thre = query_thre )
1336
- return query_indx
1337
- elif strategy == 'QBDMI' :
1338
- query_indx = qbd_mi (ensemble_probs = self .ensemble_probs ,
1339
- queryable_ids = self .queryable_ids ,
1340
- test_ids = self .pool_metadata [id_name ].values ,
1341
- batch = batch , screen = screen ,
1342
- query_thre = query_thre )
1343
-
1344
- elif strategy == 'QBDEntropy' :
1345
- query_indx = qbd_entropy (ensemble_probs = self .ensemble_probs ,
1346
- queryable_ids = self .queryable_ids ,
1347
- test_ids = self .pool_metadata [id_name ].values ,
1348
- batch = batch , screen = screen ,
1349
- query_thre = query_thre )
1350
-
1351
- elif strategy == 'RandomSampling' :
1352
- query_indx = random_sampling (queryable_ids = self .queryable_ids ,
1353
- test_ids = self .pool_metadata [id_name ].values ,
1354
- queryable = queryable , batch = batch ,
1355
- query_thre = query_thre , screen = screen )
1356
-
1357
- else :
1358
- raise ValueError ('Invalid strategy.' )
1308
+ # retrieve and instantiate the query strategy class
1309
+ query_strategy_class = fetch_query_strategy_class (strategy )
1310
+ query_strategy = query_strategy_class (
1311
+ queryable_ids = self .queryable_ids ,
1312
+ test_ids = self .pool_metadata [id_name ].values ,
1313
+ batch = batch ,
1314
+ screen = screen ,
1315
+ query_threshold = query_threshold ,
1316
+ queryable = queryable
1317
+ )
1318
+
1319
+ # Use the `requires_ensemble` flag to determine which probabilities to use
1320
+ input_probabilities = self .classprob
1321
+ if query_strategy .requires_ensemble :
1322
+ input_probabilities = self .ensemble_probs
1323
+
1324
+ # get the query index from the strategy
1325
+ query_indx = query_strategy .sample (input_probabilities )
1359
1326
1360
1327
if screen :
1361
1328
print (' ... queried obj id: ' , self .pool_metadata [id_name ].values [query_indx [0 ]])
0 commit comments