Skip to content

Commit 10251d6

Browse files
committed
feat: let each model set the default taxon rank of results
1 parent 6cc3508 commit 10251d6

File tree

3 files changed

+4
-2
lines changed

3 files changed

+4
-2
lines changed

trapdata/api/api.py

+1-2
Original file line numberDiff line numberDiff line change
@@ -67,15 +67,14 @@ def should_filter_detections(Classifier: type[APIMothClassifier]) -> bool:
6767

6868
def make_category_map_response(
6969
model: APIMothDetector | APIMothClassifier,
70-
default_taxon_rank: str = "SPECIES",
7170
) -> AlgorithmCategoryMapResponse:
7271
categories_sorted_by_index = sorted(model.category_map.items(), key=lambda x: x[0])
7372
# as list of dicts:
7473
categories_sorted_by_index = [
7574
{
7675
"index": index,
7776
"label": label,
78-
"taxon_rank": default_taxon_rank,
77+
"taxon_rank": model.default_taxon_rank,
7978
}
8079
for index, label in categories_sorted_by_index
8180
]

trapdata/ml/models/base.py

+1
Original file line numberDiff line numberDiff line change
@@ -76,6 +76,7 @@ class InferenceBaseClass:
7676
category_map = {}
7777
num_classes: Union[int, None] = None # Will use len(category_map) if None
7878
lookup_gbif_names: bool = False
79+
default_taxon_rank: str = "SPECIES"
7980
model: torch.nn.Module
8081
normalization = tensorflow_normalization
8182
transforms: torchvision.transforms.Compose

trapdata/ml/models/classification.py

+2
Original file line numberDiff line numberDiff line change
@@ -293,6 +293,7 @@ class BinaryClassifier(Resnet50ClassifierLowRes):
293293
type = "binary_classification"
294294
positive_binary_label: str = constants.POSITIVE_BINARY_LABEL
295295
negative_binary_label: str = constants.NEGATIVE_BINARY_LABEL
296+
default_taxon_rank = "SUPERFAMILY"
296297

297298
def get_queue(self) -> DetectedObjectQueue:
298299
return DetectedObjectQueue(self.db_path, self.image_base_path)
@@ -564,3 +565,4 @@ class InsectOrderClassifier2025(SpeciesClassifier, ConvNeXtOrderClassifier):
564565
description = "ConvNeXt-T based insect order classifier for 16 classes trained by Mila in January 2025"
565566
weights_path = "https://object-arbutus.cloud.computecanada.ca/ami-models/insect_orders/convnext_tiny_in22k_worder0.5_wbinary0.5_run2_checkpoint.pt"
566567
labels_path = "https://object-arbutus.cloud.computecanada.ca/ami-models/insect_orders/insect_order_category_map.json"
568+
default_taxon_rank = "ORDER"

0 commit comments

Comments
 (0)