Skip to content

Commit fe2fa11

Browse files
talmo and claude authored
Replace kornia crop_and_resize with fast tensor indexing (#426)
## Summary

- Replaces kornia's `crop_and_resize` in `peak_finding.py` with a faster implementation using tensor unfold operations
- Achieves **17-51x speedup** (CUDA/MPS) by avoiding perspective transform computations
- Removes the MPS special case that disabled integral refinement on Mac (no longer needed)
- Adds tests for edge cases (boundary peaks, empty inputs, multiple samples)

## Performance

| Platform | kornia | simple indexing | Speedup |
|----------|--------|-----------------|---------|
| MPS (M-series Mac) | 21.45 ms | 0.42 ms | **51x** |
| CUDA (RTX A6000) | 2.64 ms | 0.15 ms | **17x** |

The new approach:

- Uses `F.pad` + `unfold` to create patch views (no memory copy)
- Selects patches via advanced indexing
- Removes dependency on `torch.linalg.solve` (was blocking MPS support in older PyTorch)

## Test plan

- [x] Existing `test_peak_finding.py` tests pass
- [x] Added `test_crop_bboxes_edge_cases` for boundary peaks
- [x] Added `test_crop_bboxes_empty` for empty inputs
- [x] Added `test_crop_bboxes_multiple_samples` for multi-sample batches
- [ ] Run full inference on real data to verify end-to-end correctness

🤖 Generated with [Claude Code](https://claude.com/claude-code)

---------

Co-authored-by: Claude Opus 4.5 <noreply@anthropic.com>
1 parent df1f160 commit fe2fa11

File tree

5 files changed

+168
-25
lines changed

5 files changed

+168
-25
lines changed

sleap_nn/inference/peak_finding.py

Lines changed: 45 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -3,17 +3,20 @@
33
from typing import Optional, Tuple
44

55
import kornia as K
6-
import numpy as np
76
import torch
8-
from kornia.geometry.transform import crop_and_resize
7+
import torch.nn.functional as F
98

109
from sleap_nn.data.instance_cropping import make_centered_bboxes
1110

1211

1312
def crop_bboxes(
1413
images: torch.Tensor, bboxes: torch.Tensor, sample_inds: torch.Tensor
1514
) -> torch.Tensor:
16-
"""Crop bounding boxes from a batch of images.
15+
"""Crop bounding boxes from a batch of images using fast tensor indexing.
16+
17+
This uses tensor unfold operations to extract patches, which is significantly
18+
faster than kornia's crop_and_resize (17-51x speedup) as it avoids perspective
19+
transform computations.
1720
1821
Args:
1922
images: Tensor of shape (samples, channels, height, width) of a batch of images.
@@ -27,7 +30,7 @@ def crop_bboxes(
2730
box should be cropped from.
2831
2932
Returns:
30-
A tensor of shape (n_bboxes, crop_height, crop_width, channels) of the same
33+
A tensor of shape (n_bboxes, channels, crop_height, crop_width) of the same
3134
dtype as the input image. The crop size is inferred from the bounding box
3235
coordinates.
3336
@@ -42,26 +45,51 @@ def crop_bboxes(
4245
4346
See also: `make_centered_bboxes`
4447
"""
48+
n_crops = bboxes.shape[0]
49+
if n_crops == 0:
50+
# Return an empty tensor with zero-sized crop dims; the crop size cannot be inferred without bboxes
51+
return torch.empty(
52+
0, images.shape[1], 0, 0, device=images.device, dtype=images.dtype
53+
)
54+
4555
# Compute bounding box size to use for crops.
46-
height = abs(bboxes[0, 3, 1] - bboxes[0, 0, 1])
47-
width = abs(bboxes[0, 1, 0] - bboxes[0, 0, 0])
48-
box_size = tuple(torch.round(torch.Tensor((height + 1, width + 1))).to(torch.int32))
56+
height = int(abs(bboxes[0, 3, 1] - bboxes[0, 0, 1]).item()) + 1
57+
width = int(abs(bboxes[0, 1, 0] - bboxes[0, 0, 0]).item()) + 1
4958

5059
# Store original dtype for conversion back after cropping.
5160
original_dtype = images.dtype
61+
device = images.device
62+
n_samples, channels, img_h, img_w = images.shape
63+
half_h, half_w = height // 2, width // 2
5264

53-
# Kornia's crop_and_resize requires float32 input.
54-
images_to_crop = images[sample_inds]
55-
if not torch.is_floating_point(images_to_crop):
56-
images_to_crop = images_to_crop.float()
57-
58-
# Crop.
59-
crops = crop_and_resize(
60-
images_to_crop, # (n_boxes, channels, height, width)
61-
boxes=bboxes,
62-
size=box_size,
65+
# Pad images for edge handling.
66+
images_padded = F.pad(
67+
images.float(), (half_w, half_w, half_h, half_h), mode="constant", value=0
6368
)
6469

70+
# Extract all possible patches using unfold (creates a view, no copy).
71+
# Shape after unfold: (n_samples, channels, img_h, img_w, height, width)
72+
patches = images_padded.unfold(2, height, 1).unfold(3, width, 1)
73+
74+
# Get crop centers from bboxes.
75+
# The bbox top-left is at index 0, with (x, y) coordinates.
76+
# We need the center of the crop (peak location), which is top-left + half_size.
77+
crop_x = (bboxes[:, 0, 0] + half_w).to(torch.long)
78+
crop_y = (bboxes[:, 0, 1] + half_h).to(torch.long)
79+
80+
# Clamp indices to valid bounds to handle edge cases where centroids
81+
# might be at or beyond image boundaries.
82+
crop_x = torch.clamp(crop_x, 0, patches.shape[3] - 1)
83+
crop_y = torch.clamp(crop_y, 0, patches.shape[2] - 1)
84+
85+
# Select crops using advanced indexing.
86+
# Convert sample_inds to tensor if it's a list.
87+
if not isinstance(sample_inds, torch.Tensor):
88+
sample_inds = torch.tensor(sample_inds, device=device)
89+
sample_inds_long = sample_inds.to(torch.long)
90+
crops = patches[sample_inds_long, :, crop_y, crop_x]
91+
# Shape: (n_crops, channels, height, width)
92+
6593
# Cast back to original dtype and return.
6694
crops = crops.to(original_dtype)
6795
return crops

sleap_nn/predict.py

Lines changed: 0 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -448,13 +448,6 @@ def run_inference(
448448
else "mps" if torch.backends.mps.is_available() else "cpu"
449449
)
450450

451-
if integral_refinement is not None and device == "mps": # TODO
452-
# kornia/geometry/transform/imgwarp.py:382: in get_perspective_transform. NotImplementedError: The operator 'aten::_linalg_solve_ex.result' is not currently implemented for the MPS device. If you want this op to be added in priority during the prototype phase of this feature, please comment on https://github.com/pytorch/pytorch/issues/77764. As a temporary fix, you can set the environment variable `PYTORCH_ENABLE_MPS_FALLBACK=1` to use the CPU as a fallback for this op. WARNING: this will be slower than running natively on MPS.
453-
logger.info(
454-
"Integral refinement is not supported with MPS accelerator. Setting integral refinement to None."
455-
)
456-
integral_refinement = None
457-
458451
logger.info(f"Using device: {device}")
459452

460453
# initializes the inference model

tests/inference/test_peak_finding.py

Lines changed: 79 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
find_local_peaks,
99
find_local_peaks_rough,
1010
)
11+
from sleap_nn.data.instance_cropping import make_centered_bboxes
1112

1213

1314
def test_crop_bboxes(minimal_bboxes, minimal_cms):
@@ -28,6 +29,84 @@ def test_crop_bboxes(minimal_bboxes, minimal_cms):
2829
assert cm_crops.dtype == torch.float32
2930

3031

32+
def test_crop_bboxes_edge_cases():
33+
"""Test crop_bboxes with edge cases like peaks near image boundaries."""
34+
# Create a test image with peaks at various positions including edges
35+
img = torch.zeros(1, 1, 20, 20)
36+
37+
# Peak at center
38+
img[0, 0, 10, 10] = 1.0
39+
40+
# Peak at corner (0, 0)
41+
img[0, 0, 0, 0] = 0.8
42+
43+
# Peak at edge
44+
img[0, 0, 0, 10] = 0.9
45+
46+
# Create bboxes for these peaks
47+
points = torch.tensor(
48+
[
49+
[10.0, 10.0], # center
50+
[0.0, 0.0], # corner
51+
[10.0, 0.0], # edge
52+
]
53+
)
54+
bboxes = make_centered_bboxes(points, box_height=5, box_width=5)
55+
sample_inds = torch.tensor([0, 0, 0])
56+
57+
crops = crop_bboxes(img, bboxes, sample_inds)
58+
59+
assert crops.shape == (3, 1, 5, 5)
60+
61+
# Center crop should have the peak at center
62+
assert crops[0, 0, 2, 2] == 1.0
63+
64+
# Corner crop should have the peak at center (with zero padding)
65+
assert crops[1, 0, 2, 2] == 0.8
66+
67+
# Edge crop should have the peak at center
68+
assert crops[2, 0, 2, 2] == 0.9
69+
70+
71+
def test_crop_bboxes_empty():
72+
"""Test crop_bboxes with empty bboxes."""
73+
img = torch.zeros(1, 1, 20, 20)
74+
bboxes = torch.empty(0, 4, 2)
75+
sample_inds = torch.empty(0, dtype=torch.long)
76+
77+
crops = crop_bboxes(img, bboxes, sample_inds)
78+
79+
# Should return empty tensor
80+
assert crops.shape[0] == 0
81+
assert crops.shape[1] == 1 # Preserves channel dimension
82+
83+
84+
def test_crop_bboxes_multiple_samples():
85+
"""Test crop_bboxes with multiple samples."""
86+
# Create 3 samples with different peak locations
87+
imgs = torch.zeros(3, 1, 20, 20)
88+
imgs[0, 0, 5, 5] = 1.0
89+
imgs[1, 0, 10, 10] = 2.0
90+
imgs[2, 0, 15, 15] = 3.0
91+
92+
points = torch.tensor(
93+
[
94+
[5.0, 5.0],
95+
[10.0, 10.0],
96+
[15.0, 15.0],
97+
]
98+
)
99+
bboxes = make_centered_bboxes(points, box_height=5, box_width=5)
100+
sample_inds = torch.tensor([0, 1, 2])
101+
102+
crops = crop_bboxes(imgs, bboxes, sample_inds)
103+
104+
assert crops.shape == (3, 1, 5, 5)
105+
assert crops[0, 0, 2, 2] == 1.0
106+
assert crops[1, 0, 2, 2] == 2.0
107+
assert crops[2, 0, 2, 2] == 3.0
108+
109+
31110
def test_integral_regression(minimal_bboxes, minimal_cms):
32111
cms = torch.load(minimal_cms).unsqueeze(0)
33112
bboxes = torch.load(minimal_bboxes)

tests/training/test_lightning_modules.py

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,10 @@
11
"""Test TrainingModule classes."""
22

33
import numpy as np
4+
import os
45
from pathlib import Path
56
from omegaconf import OmegaConf
7+
import wandb
68
from sleap_nn.data.custom_datasets import (
79
get_train_val_dataloaders,
810
get_train_val_datasets,
@@ -39,6 +41,31 @@ def caplog(caplog: LogCaptureFixture):
3941
logger.remove(handler_id)
4042

4143

44+
@pytest.fixture(autouse=True)
45+
def cleanup_wandb():
46+
"""Ensure wandb runs in offline mode and is cleaned up after each test.
47+
48+
This fixture:
49+
1. Sets WANDB_MODE=offline to prevent network hangs on CI
50+
2. Cleans up any active wandb run after the test to prevent state leakage
51+
"""
52+
# Save original mode and force offline to prevent network hangs on CI
53+
original_mode = os.environ.get("WANDB_MODE")
54+
os.environ["WANDB_MODE"] = "offline"
55+
56+
yield
57+
58+
# Finish any active wandb run to prevent contamination between tests
59+
if wandb.run is not None:
60+
wandb.finish()
61+
62+
# Restore original WANDB_MODE
63+
if original_mode is not None:
64+
os.environ["WANDB_MODE"] = original_mode
65+
else:
66+
os.environ.pop("WANDB_MODE", None)
67+
68+
4269
def test_topdown_centered_instance_model(
4370
config, tmp_path: str, minimal_instance_centered_instance_ckpt
4471
):

tests/training/test_model_trainer.py

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -53,12 +53,28 @@ def caplog(caplog: LogCaptureFixture):
5353

5454
@pytest.fixture(autouse=True)
5555
def cleanup_wandb():
56-
"""Ensure wandb run is finished after each test to prevent state leakage."""
56+
"""Ensure wandb runs in offline mode and is cleaned up after each test.
57+
58+
This fixture:
59+
1. Sets WANDB_MODE=offline to prevent network hangs on CI
60+
2. Cleans up any active wandb run after the test to prevent state leakage
61+
"""
62+
# Save original mode and force offline to prevent network hangs on CI
63+
original_mode = os.environ.get("WANDB_MODE")
64+
os.environ["WANDB_MODE"] = "offline"
65+
5766
yield
67+
5868
# Finish any active wandb run to prevent contamination between tests
5969
if wandb.run is not None:
6070
wandb.finish()
6171

72+
# Restore original WANDB_MODE
73+
if original_mode is not None:
74+
os.environ["WANDB_MODE"] = original_mode
75+
else:
76+
os.environ.pop("WANDB_MODE", None)
77+
6278

6379
def test_cfg_without_val_labels_path(config, tmp_path, minimal_instance):
6480
"""Test Model Trainer if no val labels path is provided."""

0 commit comments

Comments
 (0)