Commit 1c7d6ce

OBB based cropping for centered instance
1 parent d946ec3 commit 1c7d6ce

File tree: 8 files changed, +418 -80 lines

sleap_nn/data/augmentation.py

Lines changed: 1 addition & 1 deletion
@@ -154,7 +154,7 @@ def apply_geometric_augmentation(
     if affine_p > 0:
         aug_stack.append(
             K.augmentation.RandomAffine(
-                degrees=(rotation_min, rotation_max),
+                degrees=(rotation_min, rotation_min),
                 translate=(translate_width, translate_height),
                 scale=(scale_min, scale_max),
                 p=affine_p,
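
Kornia samples the affine rotation angle uniformly from the (min, max) tuple, so the degenerate (rotation_min, rotation_min) range above pins the rotation to a single value rather than a random one. A minimal standalone sketch of that behavior (the rotation_min value here is hypothetical, not from this config):

import torch
import kornia.augmentation as K

rotation_min = -15.0  # hypothetical value, for illustration only
aug = K.RandomAffine(degrees=(rotation_min, rotation_min), p=1.0)
img = torch.rand(1, 1, 64, 64)
out = aug(img)  # every sample is rotated by exactly rotation_min degrees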

sleap_nn/data/custom_datasets.py

Lines changed: 60 additions & 30 deletions
@@ -20,7 +20,9 @@
 from sleap_nn.config.utils import get_backbone_type_from_cfg, get_model_type_from_cfg
 from sleap_nn.data.identity import generate_class_maps, make_class_vectors
 from sleap_nn.data.instance_centroids import generate_centroids
-from sleap_nn.data.instance_cropping import generate_crops
+from sleap_nn.data.instance_cropping import generate_crops, find_instance_crop_size
+
+
 from sleap_nn.data.normalization import (
     apply_normalization,
     convert_to_grayscale,
@@ -34,7 +36,7 @@
 )
 from sleap_nn.data.confidence_maps import generate_confmaps, generate_multiconfmaps
 from sleap_nn.data.edge_maps import generate_pafs
-from sleap_nn.data.instance_cropping import make_centered_bboxes
+from sleap_nn.data.instance_cropping import make_centered_bboxes, get_cropped_img
 from sleap_nn.training.utils import is_distributed_initialized
 from sleap_nn.config.get_config import get_aug_config
 
@@ -738,6 +740,7 @@ def __init__(
         self.confmap_head_config = confmap_head_config
         self.instance_idx_list = self._get_instance_idx_list(labels)
         self.cache_lf = [None, None]
+        # self.max_crop_size = find_instance_crop_size(self.labels, maximum_stride=self.max_stride)
 
     def _get_instance_idx_list(self, labels: List[sio.Labels]) -> List[Tuple[int]]:
         """Return list of tuples with indices of labelled frames and instances."""
@@ -840,24 +843,41 @@ def __getitem__(self, index) -> Dict:
             scale=self.scale,
         )
 
-        # get the centroids based on the anchor idx
-        centroids = generate_centroids(instances, anchor_ind=self.anchor_ind)
+        instance = instances[0]
 
-        instance, centroid = instances[0], centroids[0]  # (n_samples=1)
+        sample = {}
 
-        crop_size = np.array([self.crop_size, self.crop_size]) * np.sqrt(
-            2
-        )  # crop extra for rotation augmentation
-        crop_size = crop_size.astype(np.int32).tolist()
+        # Get the head index
+        head_idx = self.anchor_ind
 
-        sample = generate_crops(image, instance, centroid, crop_size)
+        # Determine if the instance has enough valid points
+        valid_points = instance[~torch.isnan(instance).any(dim=1)]
+        if valid_points.shape[0] < 3:
+            return self.__getitem__((index + 1) % len(self))  # safely retry next sample
+
+        # crop image
+        sample_image, sample_instance, src_pts, dst_pts, rotated = get_cropped_img(
+            image[0], instance, head_idx
+        )
+        sample_image, sample_instance = sample_image.unsqueeze(
+            0
+        ), sample_instance.unsqueeze(0)
+
+        sample["instance_image"] = sample_image
+        sample["instance"] = sample_instance
+        sample["src_pts"] = src_pts.unsqueeze(0)
+        sample["dst_pts"] = dst_pts.unsqueeze(0)
+        sample["rotated"] = torch.tensor([rotated], dtype=torch.bool)
 
         sample["frame_idx"] = torch.tensor(lf_frame_idx, dtype=torch.int32)
         sample["video_idx"] = torch.tensor(video_idx, dtype=torch.int32)
         sample["num_instances"] = num_instances
         sample["orig_size"] = torch.Tensor([orig_img_height, orig_img_width]).unsqueeze(
             0
         )
+        height, width = sample_image.shape[-2:]
+        sample["height"] = [height]
+        sample["width"] = [width]
         sample["eff_scale"] = torch.tensor(eff_scale, dtype=torch.float32)
 
         # apply augmentation
@@ -883,27 +903,32 @@ def __getitem__(self, index) -> Dict:
             )
 
         # re-crop to original crop size
-        sample["instance_bbox"] = torch.unsqueeze(
-            make_centered_bboxes(sample["centroid"][0], self.crop_size, self.crop_size),
-            0,
-        )  # (n_samples=1, 4, 2)
-
-        sample["instance_image"] = crop_and_resize(
+        # sample["instance_bbox"] = torch.unsqueeze(
+        #     make_centered_bboxes(sample["centroid"][0], self.crop_size, self.crop_size),
+        #     0,
+        # )  # (n_samples=1, 4, 2)
+
+        # sample["instance_image"] = crop_and_resize(
+        #     sample["instance_image"],
+        #     boxes=sample["instance_bbox"],
+        #     size=(self.crop_size, self.crop_size),
+        # )
+        # size matcher
+        sample_image, eff_scale = apply_sizematcher(
             sample["instance_image"],
-            boxes=sample["instance_bbox"],
-            size=(self.crop_size, self.crop_size),
+            max_height=self.crop_size,
+            max_width=self.crop_size,
         )
-        point = sample["instance_bbox"][0][0]
-        center_instance = sample["instance"] - point
-        centered_centroid = sample["centroid"] - point
-
-        sample["instance"] = center_instance  # (n_samples=1, n_nodes, 2)
-        sample["centroid"] = centered_centroid  # (n_samples=1, 2)
+        # point = sample["instance_bbox"][0][0]
+        # center_instance = sample["instance"] - point
+        # centered_centroid = sample["centroid"] - point
 
-        # Pad the image (if needed) according max stride
-        sample["instance_image"] = apply_pad_to_stride(
-            sample["instance_image"], max_stride=self.max_stride
-        )
+        # sample["instance"] = center_instance  # (n_samples=1, n_nodes, 2)
+        # sample["centroid"] = centered_centroid  # (n_samples=1, 2)
+        sample_instance = sample["instance"] * eff_scale
+        sample["instance"] = sample_instance
+        sample["instance_image"] = sample_image
+        sample["scale"] = torch.tensor(eff_scale, dtype=torch.float32).unsqueeze(dim=0)
 
         img_hw = sample["instance_image"].shape[-2:]
 
@@ -1831,7 +1856,12 @@ def get_train_val_datasets(
         ),
         scale=config.data_config.preprocessing.scale,
         apply_aug=config.data_config.use_augmentations_train,
-        crop_size=config.data_config.preprocessing.crop_size,
+        crop_size=find_instance_crop_size(
+            train_labels,
+            maximum_stride=config.model_config.backbone_config[f"{backbone_type}"][
+                "max_stride"
+            ],
+        ),
         max_hw=(
             config.data_config.preprocessing.max_height,
             config.data_config.preprocessing.max_width,
@@ -1855,7 +1885,7 @@
         geometric_aug=None,
         scale=config.data_config.preprocessing.scale,
         apply_aug=False,
-        crop_size=config.data_config.preprocessing.crop_size,
+        crop_size=train_dataset.crop_size,
         max_hw=(
             config.data_config.preprocessing.max_height,
             config.data_config.preprocessing.max_width,
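
The dataset now carries the forward transform (src_pts, dst_pts, rotated) alongside each crop, which is what keeps the OBB warp invertible downstream. A standalone sketch of the new per-sample flow on a synthetic frame and instance (shapes follow the diff above; illustrative, not a verbatim excerpt from the repo):

import torch
from sleap_nn.data.instance_cropping import get_cropped_img

image = torch.rand(1, 1, 384, 384)  # (n_samples=1, C, H, W)
instance = torch.tensor(
    [[150.0, 200.0], [180.0, 190.0], [220.0, 205.0], [float("nan"), float("nan")]]
)  # (n_nodes, 2), with one missing node

# Same guard as in __getitem__: the convex hull needs at least 3 valid points.
valid_points = instance[~torch.isnan(instance).any(dim=1)]
assert valid_points.shape[0] >= 3

crop, kpts, src_pts, dst_pts, rotated = get_cropped_img(image[0], instance, head_idx=0)
sample = {
    "instance_image": crop.unsqueeze(0),
    "instance": kpts.unsqueeze(0),
    "src_pts": src_pts.unsqueeze(0),
    "dst_pts": dst_pts.unsqueeze(0),
    "rotated": torch.tensor([rotated], dtype=torch.bool),
}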

sleap_nn/data/instance_cropping.py

Lines changed: 135 additions & 0 deletions
@@ -6,6 +6,141 @@
 import sleap_io as sio
 import torch
 from kornia.geometry.transform import crop_and_resize
+from sleap_nn.data.utils import rotating_calipers
+import kornia
+
+
+def get_cropped_img(image: torch.Tensor, instance: torch.Tensor, head_idx: int):
+    """Crop and rotate an image using the oriented bounding box (OBB) of a given instance.
+
+    This function performs a padding-aware crop around the instance keypoints using the
+    minimum-area rotating-calipers OBB. It then aligns the longest edge with the x-axis,
+    warps the image and keypoints accordingly, and applies a conditional 180° rotation if
+    the head is facing left. The output is a torch-native equivalent of OpenCV's
+    getAffineTransform + warpAffine behavior.
+
+    Args:
+        image (torch.Tensor): A float tensor of shape (C, H, W), representing an RGB image.
+        instance (torch.Tensor): A float tensor of shape (N, 2), representing keypoint
+            coordinates of one instance.
+        head_idx (int): Index of the head keypoint, used to determine leftward orientation.
+
+    Returns:
+        cropped_image (torch.Tensor): Cropped and rotated image of shape (C, H, W), aligned to face +x.
+        adjusted_kpts (torch.Tensor): Keypoints of shape (N, 2), transformed to match the cropped image coordinates.
+        src_pts (torch.Tensor): Three source points from the padded OBB used for the affine transformation (3, 2).
+        dst_pts (torch.Tensor): Three target points in the crop destination space used for affine warping (3, 2).
+        rotated (bool): True if the instance was rotated 180° to face the positive x-axis, otherwise False.
+    """
+    # Define padding
+    pad = 32
+
+    # Ensure dtype and device consistency
+    image = image.float()
+    device = image.device
+    instance = instance.to(device)
+
+    # Get OBB from keypoints
+    obb_coords = rotating_calipers(instance)
+
+    # Find longest edge and roll OBB
+    dists = torch.norm(obb_coords - torch.roll(obb_coords, shifts=-1, dims=0), dim=1)
+    max_index = torch.argmax(dists)
+    obb_coords = torch.roll(obb_coords, shifts=max_index.item(), dims=0)
+
+    # Compute padded OBB by expanding each corner outward from center
+    center = obb_coords.mean(dim=0, keepdims=True)
+    vecs = obb_coords - center
+    norms = torch.norm(vecs, dim=1).unsqueeze(1)  # shape: (4, 1)
+    norms = torch.where(
+        norms == 0, torch.ones_like(norms), norms
+    )  # avoid division by zero
+
+    # Find the OBB edge closest to the x-axis (smallest absolute angle)
+    best_idx = 0
+    min_abs_angle = float("inf")
+    for i in range(4):
+        edge = obb_coords[(i + 1) % 4] - obb_coords[i]
+        angle = torch.atan2(edge[1], edge[0])
+        if abs(angle) < min_abs_angle:
+            min_abs_angle = abs(angle)
+            best_idx = i
+
+    # Roll so this edge is [0] -> [1]
+    obb_coords = torch.roll(obb_coords, shifts=-best_idx, dims=0)
+    edge = obb_coords[1] - obb_coords[0]
+    angle = torch.atan2(edge[1], edge[0])
+
+    # If the edge points left, reverse the corner order
+    # (torch.flip, since torch tensors do not support negative-step slicing)
+    if edge[0] < 0:
+        obb_coords = torch.flip(obb_coords, dims=[0])
+        edge = obb_coords[1] - obb_coords[0]
+        angle = torch.atan2(edge[1], edge[0])
+
+    # Define the width/height based on the OBB coordinates
+    width = torch.norm(obb_coords[1] - obb_coords[0])
+    height = torch.norm(obb_coords[3] - obb_coords[0])
+
+    # If the crop is taller than wide, rotate OBB by 90 deg to make it horizontal
+    if height > width:
+        obb_coords = torch.roll(obb_coords, shifts=-1, dims=0)  # rotate OBB 90 degrees
+        edge = obb_coords[1] - obb_coords[0]
+        angle = torch.atan2(edge[1], edge[0])
+        if edge[0] < 0:
+            obb_coords = torch.flip(obb_coords, dims=[0])
+            edge = obb_coords[1] - obb_coords[0]
+            angle = torch.atan2(edge[1], edge[0])
+        width = torch.norm(obb_coords[1] - obb_coords[0])
+        height = torch.norm(obb_coords[3] - obb_coords[0])
+
+    # Add padding to the final crop dimensions
+    width += pad * 2
+    height += pad * 2
+
+    # Build affine from OBB -> crop box
+    src_pts = (
+        obb_coords[:3].clone().to(dtype=torch.float32, device=device)
+    )  # using corners of OBB
+
+    # Rectangular region we want to map the OBB onto
+    dst_pts = torch.tensor(
+        [[pad, pad], [width - pad, pad], [width - pad, height - pad]],
+        dtype=torch.float32,
+        device=device,
+    )
+
+    ones = torch.ones((3, 1), device=device)
+    src = torch.cat(
+        [src_pts, ones], dim=1
+    )  # append 1s to the source points to compute the affine transformation
+
+    # Solve the least-squares system giving the affine that best maps src_pts -> dst_pts
+    affine_matrix = torch.linalg.lstsq(src, dst_pts).solution.T
+
+    # Warp the image with the affine transform
+    cropped_image = kornia.geometry.transform.warp_affine(
+        image.unsqueeze(0), affine_matrix.unsqueeze(0), dsize=(int(height), int(width))
+    )[0]
+
+    # Warp the keypoints with the same affine
+    kp_homo = torch.cat(
+        [instance.to(device), torch.ones((instance.shape[0], 1), device=device)], dim=1
+    )
+    adjusted_kpts = (affine_matrix @ kp_homo.T).T
+
+    # Head x-coordinate and mean body x-coordinate
+    head_x = adjusted_kpts[head_idx, 0]
+    body_center_x = adjusted_kpts[:, 0][~torch.isnan(adjusted_kpts[:, 0])].mean()
+
+    # Rotate 180° if facing left (by comparing the head keypoint to the body center)
+    rotated = False
+    if head_x < body_center_x:
+        rotated = True
+        # Rotate image 180° and flip keypoints to match
+        cropped_image = torch.rot90(cropped_image, k=2, dims=[1, 2])
+
+        adjusted_kpts[:, 0] = cropped_image.shape[2] - adjusted_kpts[:, 0]
+        adjusted_kpts[:, 1] = cropped_image.shape[1] - adjusted_kpts[:, 1]
+
+    return cropped_image, adjusted_kpts, src_pts, dst_pts, rotated
 
 
 def find_instance_crop_size(
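
A quick sanity check of the new helper: crop a synthetic instance and confirm that src_pts land on dst_pts under the affine recovered by the same least-squares solve the function uses internally. This is an illustrative sketch, not a test from the repo:

import torch
from sleap_nn.data.instance_cropping import get_cropped_img

image = torch.rand(3, 256, 256)  # (C, H, W)
instance = torch.tensor(
    [[60.0, 100.0], [120.0, 130.0], [180.0, 105.0], [130.0, 80.0]]
)  # (N, 2) keypoints

crop, kpts, src_pts, dst_pts, rotated = get_cropped_img(image, instance, head_idx=0)

# Re-solve [src | 1] @ A.T = dst: with three point pairs the system is exact,
# so the OBB corners must map onto their crop-space targets.
src_h = torch.cat([src_pts, torch.ones(3, 1)], dim=1)
A = torch.linalg.lstsq(src_h, dst_pts).solution.T  # (2, 3) affine matrix
assert torch.allclose(src_h @ A.T, dst_pts, atol=1e-3)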

sleap_nn/data/utils.py

Lines changed: 57 additions & 0 deletions
@@ -8,6 +8,7 @@
 import psutil
 import numpy as np
 from sleap_nn.data.providers import get_max_instances
+from scipy.spatial import ConvexHull
 
 
 def ensure_list(x: Any) -> List[Any]:
@@ -147,3 +148,59 @@ def check_cache_memory(
     if total_cache_memory > available_memory:
         return False
     return True
+
+
+def rotating_calipers(points: torch.Tensor):
+    """Compute the minimum-area oriented bounding box of a set of points via rotating calipers.
+
+    Args:
+        points (torch.Tensor): (N, 2) tensor of 2D coordinates.
+
+    Returns:
+        torch.Tensor: (4, 2) tensor of the minimum-area bounding box corners.
+    """
+    # Remove NaN values (the caller ensures there are enough valid points)
+    valid_points = points[~torch.isnan(points).any(dim=1)]
+
+    # Determine the convex hull using scipy's ConvexHull
+    hull = ConvexHull(valid_points)
+    hull_points = valid_points[hull.vertices]
+
+    min_area = float("inf")  # initialize minimum area to infinity
+    best_box = None  # to store the best bounding box found
+
+    # Iterate through each edge of the convex hull
+    for i in range(len(hull_points)):
+        p1 = hull_points[i]
+        p2 = hull_points[(i + 1) % len(hull_points)]
+
+        # Compute the angle of the edge
+        edge = p2 - p1
+        angle = -torch.atan2(edge[1], edge[0])
+
+        # Build rotation matrix
+        cos_a = torch.cos(angle)
+        sin_a = torch.sin(angle)
+        R = torch.stack(
+            [torch.stack([cos_a, -sin_a]), torch.stack([sin_a, cos_a])]
+        )  # shape: (2, 2)
+
+        # Rotate points
+        rotated = (hull_points - p1) @ R.T
+
+        # Compute the bounding box of the rotated points
+        xmin = torch.min(rotated[:, 0])
+        xmax = torch.max(rotated[:, 0])
+        ymin = torch.min(rotated[:, 1])
+        ymax = torch.max(rotated[:, 1])
+        area = (xmax - xmin) * (ymax - ymin)
+
+        # Update the best bounding box if the area is smaller
+        if area < min_area:
+            min_area = area
+            # rectangle corners in rotated coordinates
+            box = torch.tensor([[xmin, ymin], [xmax, ymin], [xmax, ymax], [xmin, ymax]])
+            # rotate back to original coordinates
+            best_box = (box @ R) + p1
+
+    return best_box
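
As a small illustrative check (not from the repo), the minimum-area box returned by rotating_calipers can never be larger than the axis-aligned bounding box of the same points:

import torch
from sleap_nn.data.utils import rotating_calipers

pts = torch.tensor([[0.0, 0.0], [1.0, 1.2], [2.0, 1.8], [3.0, 3.0], [1.5, 1.0]])
box = rotating_calipers(pts)  # (4, 2) corners of the minimum-area OBB

# Adjacent rectangle edges give the side lengths; their product is the area.
edges = box - torch.roll(box, shifts=1, dims=0)
obb_area = torch.norm(edges[0]) * torch.norm(edges[1])
aabb_area = (pts[:, 0].max() - pts[:, 0].min()) * (pts[:, 1].max() - pts[:, 1].min())
assert obb_area <= aabb_area + 1e-5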

sleap_nn/inference/predictors.py

Lines changed: 3 additions & 3 deletions
@@ -1189,23 +1189,23 @@ def _make_labeled_frames_from_generator(
         for (
             video_idx,
             frame_idx,
-            bbox,
+            # bbox,
             pred_instances,
             pred_values,
             instance_score,
             org_size,
         ) in zip(
             ex["video_idx"],
             ex["frame_idx"],
-            ex["instance_bbox"],
+            # ex["instance_bbox"],
             ex["pred_instance_peaks"],
             ex["pred_peak_values"],
             ex["centroid_val"],
             ex["orig_size"],
         ):
             if np.isnan(pred_instances).all():
                 continue
-            pred_instances = pred_instances + bbox.squeeze(axis=0)[0, :]
+            # pred_instances = pred_instances + bbox.squeeze(axis=0)[0, :]
             preds[(int(video_idx), int(frame_idx))].append(
                 sio.PredictedInstance.from_numpy(
                     points_data=pred_instances,
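
With the axis-aligned bbox offset gone, crop-space peaks presumably have to be mapped back to frame coordinates by inverting the OBB affine (undoing the 180° flip first whenever rotated is set), using the src_pts/dst_pts/rotated fields the dataset now emits. A hypothetical sketch of that inverse mapping; this helper does not exist in the commit:

import torch

def crop_to_frame(peaks, src_pts, dst_pts, rotated, crop_hw):
    """Map (N, 2) crop-space peaks back to frame coordinates (hypothetical helper)."""
    if rotated:  # undo the 180-degree flip applied when the head faced left
        peaks = torch.stack(
            [crop_hw[1] - peaks[:, 0], crop_hw[0] - peaks[:, 1]], dim=1
        )
    # Solve the inverse affine (dst -> src) from the same three point pairs.
    dst_h = torch.cat([dst_pts, torch.ones(3, 1)], dim=1)
    A_inv = torch.linalg.lstsq(dst_h, src_pts).solution.T  # (2, 3)
    peaks_h = torch.cat([peaks, torch.ones(peaks.shape[0], 1)], dim=1)
    return peaks_h @ A_inv.T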
