Skip to content

Commit 30d48a5

Browse files
talmoclaude
andcommitted
Fix device mismatch in crop_bboxes for cross-device inference
Ensure bboxes are moved to the same device as images before computing crop indices. This fixes RuntimeError when bboxes are on GPU but images are on CPU (or vice versa) during top-down inference. Co-Authored-By: Claude Opus 4.5 <[email protected]>
1 parent 460f6b7 commit 30d48a5

File tree

1 file changed

+5
-3
lines changed

1 file changed

+5
-3
lines changed

sleap_nn/inference/peak_finding.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -74,8 +74,10 @@ def crop_bboxes(
7474
# Get crop centers from bboxes.
7575
# The bbox top-left is at index 0, with (x, y) coordinates.
7676
# We need the center of the crop (peak location), which is top-left + half_size.
77-
crop_x = (bboxes[:, 0, 0] + half_w).to(torch.long)
78-
crop_y = (bboxes[:, 0, 1] + half_h).to(torch.long)
77+
# Ensure bboxes are on the same device as images for index computation.
78+
bboxes_on_device = bboxes.to(device)
79+
crop_x = (bboxes_on_device[:, 0, 0] + half_w).to(torch.long)
80+
crop_y = (bboxes_on_device[:, 0, 1] + half_h).to(torch.long)
7981

8082
# Clamp indices to valid bounds to handle edge cases where centroids
8183
# might be at or beyond image boundaries.
@@ -86,7 +88,7 @@ def crop_bboxes(
8688
# Convert sample_inds to tensor if it's a list.
8789
if not isinstance(sample_inds, torch.Tensor):
8890
sample_inds = torch.tensor(sample_inds, device=device)
89-
sample_inds_long = sample_inds.to(torch.long)
91+
sample_inds_long = sample_inds.to(device=device, dtype=torch.long)
9092
crops = patches[sample_inds_long, :, crop_y, crop_x]
9193
# Shape: (n_crops, channels, height, width)
9294

0 commit comments

Comments
 (0)