Skip to content

Commit f9f637f

Browse files
author
Aditya Jain
committed
Working insect order classifier
1 parent ff52f75 commit f9f637f

File tree

3 files changed

+17
-7
lines changed

3 files changed

+17
-7
lines changed

trapdata/api/demo.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99

1010
from .models.classification import (
1111
APIMothClassifier,
12+
InsectOrderClassifier,
1213
MothClassifierBinary,
1314
MothClassifierGlobal,
1415
MothClassifierPanama,
@@ -74,6 +75,12 @@ class ClassifierChoice:
7475
example_images_dir_names=["vermont", "panama", "denmark"],
7576
classifier=MothClassifierGlobal,
7677
),
78+
ClassifierChoice(
79+
key=InsectOrderClassifier.get_key(),
80+
tab_title="Insect Orders",
81+
example_images_dir_names=["vermont", "panama", "denmark"],
82+
classifier=InsectOrderClassifier,
83+
),
7784
]
7885

7986

trapdata/api/models/classification.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
from trapdata.ml.models.classification import (
99
GlobalMothSpeciesClassifier,
1010
InferenceBaseClass,
11+
InsectOrderClassifier2025,
1112
MothNonMothClassifier,
1213
PanamaMothSpeciesClassifier2024,
1314
PanamaMothSpeciesClassifierMixedResolution2023,
@@ -188,3 +189,7 @@ class MothClassifierTuringAnguilla(APIMothClassifier, TuringAnguillaSpeciesClass
188189

189190
class MothClassifierGlobal(APIMothClassifier, GlobalMothSpeciesClassifier):
190191
pass
192+
193+
194+
class InsectOrderClassifier(APIMothClassifier, InsectOrderClassifier2025):
195+
pass

trapdata/ml/models/classification.py

Lines changed: 5 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -108,13 +108,14 @@ def forward(self, x):
108108

109109

110110
class ConvNeXtOrderClassifier(InferenceBaseClass):
111-
"""ConvNeXt based insect order classifier"""
111+
"""ConvNeXt based insect order classifier"""
112+
112113
input_size = 128
113114

114115
def get_model(self):
115116
num_classes = len(self.category_map)
116117
model = timm.create_model(
117-
"convnext_tiny.fb_in22k",
118+
"convnext_tiny_in22k",
118119
weights=None,
119120
num_classes=num_classes,
120121
)
@@ -127,7 +128,6 @@ def get_model(self):
127128
model.eval()
128129
return model
129130

130-
131131
def _pad_to_square(self):
132132
"""Padding transformation to make the image square"""
133133

@@ -139,19 +139,17 @@ def _pad_to_square(self):
139139
else:
140140
return torchvision.transforms.Pad(padding=[0, 0, 0, 0])
141141

142-
143142
def get_transforms(self):
144143
mean, std = [0.485, 0.456, 0.406], [0.229, 0.224, 0.225]
145144
return torchvision.transforms.Compose(
146-
[
147-
self._pad_to_square(),
145+
[
146+
# self._pad_to_square(),
148147
torchvision.transforms.Resize((self.input_size, self.input_size)),
149148
torchvision.transforms.ToTensor(),
150149
torchvision.transforms.Normalize(mean, std),
151150
]
152151
)
153152

154-
155153
def post_process_batch(self, output):
156154
predictions = torch.nn.functional.softmax(output, dim=1)
157155
predictions = predictions.cpu().numpy()

0 commit comments

Comments
 (0)