Skip to content

Commit 088b696

Browse files
gitttt-1234 and claude authored
Add Skia-based augmentation backend for faster data pipeline (#431)
## Summary

- **Replaces Kornia entirely** with Skia-based augmentation + pure PyTorch operations
- Provides **~1.58x faster** CPU augmentation throughput
- Reduces GPU transfer bandwidth by **4x** (images stay uint8 until GPU normalization)
- **Removes kornia as a dependency** (one fewer external dependency)

## Changes

### New Files

- **`sleap_nn/data/skia_augmentation.py`**: Skia implementations of intensity/geometric augmentations and crop_and_resize

### Modified Files

- **`sleap_nn/data/augmentation.py`**: Simplified wrapper delegating to Skia
- **`sleap_nn/data/custom_datasets.py`**: Use Skia crop_and_resize
- **`sleap_nn/data/instance_cropping.py`**: Use Skia crop_and_resize
- **`sleap_nn/inference/peak_finding.py`**: Pure PyTorch morphological_dilation (replaces kornia)
- **`sleap_nn/training/lightning_modules.py`**: Add `normalize_on_gpu()` to all training/validation steps
- **`pyproject.toml`**: Remove kornia, add skia-python>=87.0

## Architecture

```
Data Pipeline (CPU): Images stay uint8 → Skia augmentation → uint8 output
Training/Inference (GPU): uint8 tensor → normalize_on_gpu() → float32 [0,1] → model
```

This deferred normalization reduces PCIe bandwidth by 4x.

## Benchmark Results (from investigation)

| Dataset | Skia | Kornia | Speedup |
|---------|------|--------|---------|
| SingleInstance | 417.5/s | 359.8/s | **1.16x** |
| Centroid | 455.5/s | 365.3/s | **1.25x** |
| CenteredInstance | 998.0/s | 364.4/s | **2.74x** |
| BottomUp | 297.9/s | 253.6/s | **1.18x** |
| **Average** | | | **1.58x** |

## Test Results

- [x] All data tests pass (52/52)
- [x] All training tests pass (9/9)
- [x] Peak finding tests pass (9/9)
- [ ] End-to-end training verification

🤖 Generated with [Claude Code](https://claude.com/claude-code)

---------

Co-authored-by: Claude Opus 4.5 <[email protected]>
1 parent 0ed5d38 commit 088b696

File tree

8 files changed

+527
-248
lines changed

8 files changed

+527
-248
lines changed

.github/workflows/ci.yml

Lines changed: 7 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -82,6 +82,11 @@ jobs:
8282
if: matrix.os != 'self-hosted-gpu'
8383
run: uv sync --extra torch-cpu
8484

85+
- name: Install graphics dependencies (Ubuntu)
86+
if: matrix.os == 'ubuntu'
87+
run: |
88+
sudo apt-get update && sudo apt-get install -y libglapi-mesa libegl-mesa0 libegl1 libopengl0 libgl1 libglx-mesa0
89+
8590
- name: Print environment info
8691
run: |
8792
echo "=== UV Environment ==="
@@ -103,7 +108,7 @@ jobs:
103108
print('CUDA is not available')
104109
" || echo "CUDA check failed"
105110
echo "=== Import Test ==="
106-
uv run --frozen --extra torch-cpu python -c "import torch; import lightning; import kornia; print('All imports successful')" || echo "Import test failed"
111+
uv run --frozen --extra torch-cpu python -c "import torch; import lightning; import skia; print('All imports successful')" || echo "Import test failed"
107112
108113
- name: Check MPS backend (macOS only)
109114
if: runner.os == 'macOS'
@@ -126,7 +131,7 @@ jobs:
126131
- name: Run pytest
127132
run: |
128133
echo "=== Final environment check before tests ==="
129-
uv run --frozen --extra torch-cpu python -c "import numpy, torch, lightning, kornia; print(f'All packages available: numpy={numpy.__version__}, torch={torch.__version__}')"
134+
uv run --frozen --extra torch-cpu python -c "import numpy, torch, lightning, skia; print(f'All packages available: numpy={numpy.__version__}, torch={torch.__version__}')"
130135
echo "=== Running pytest ==="
131136
uv run --frozen --extra torch-cpu pytest --cov=sleap_nn --cov-report=xml --durations=-1 tests/
132137

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -32,7 +32,7 @@ dependencies = [
3232
"sleap-io>=0.6.2,<0.7.0",
3333
"numpy",
3434
"lightning",
35-
"kornia",
35+
"skia-python>=87.0",
3636
"jsonpickle",
3737
"scipy",
3838
"attrs",

sleap_nn/data/augmentation.py

Lines changed: 50 additions & 241 deletions
Original file line number | Diff line number | Diff line change
@@ -1,12 +1,15 @@
1-
"""This module implements data pipeline blocks for augmentation operations."""
1+
"""This module implements data pipeline blocks for augmentation operations.
22
3-
from typing import Any, Dict, Optional, Tuple, Union
4-
import kornia as K
3+
Uses Skia (skia-python) for ~1.5x faster augmentation compared to Kornia.
4+
"""
5+
6+
from typing import Optional, Tuple
57
import torch
6-
from kornia.augmentation._2d.intensity.base import IntensityAugmentationBase2D
7-
from kornia.augmentation.container import AugmentationSequential
8-
from kornia.augmentation.utils.param_validation import _range_bound
9-
from kornia.core import Tensor
8+
9+
from sleap_nn.data.skia_augmentation import (
10+
apply_intensity_augmentation_skia,
11+
apply_geometric_augmentation_skia,
12+
)
1013

1114

1215
def apply_intensity_augmentation(
@@ -24,8 +27,8 @@ def apply_intensity_augmentation(
2427
brightness_min: Optional[float] = 1.0,
2528
brightness_max: Optional[float] = 1.0,
2629
brightness_p: float = 0.0,
27-
) -> Tuple[torch.Tensor]:
28-
"""Apply kornia intensity augmentation on image and instances.
30+
) -> Tuple[torch.Tensor, torch.Tensor]:
31+
"""Apply intensity augmentation on image and instances.
2932
3033
Args:
3134
image: Input image. Shape: (n_samples, C, H, W)
@@ -46,66 +49,23 @@ def apply_intensity_augmentation(
4649
Returns:
4750
Returns tuple: (image, instances) with augmentation applied.
4851
"""
49-
aug_stack = []
50-
if uniform_noise_p > 0:
51-
aug_stack.append(
52-
RandomUniformNoise(
53-
noise=(uniform_noise_min, uniform_noise_max),
54-
p=uniform_noise_p,
55-
keepdim=True,
56-
same_on_batch=True,
57-
)
58-
)
59-
if gaussian_noise_p > 0:
60-
aug_stack.append(
61-
K.augmentation.RandomGaussianNoise(
62-
mean=gaussian_noise_mean,
63-
std=gaussian_noise_std,
64-
p=gaussian_noise_p,
65-
keepdim=True,
66-
same_on_batch=True,
67-
)
68-
)
69-
if contrast_p > 0:
70-
aug_stack.append(
71-
K.augmentation.RandomContrast(
72-
contrast=(contrast_min, contrast_max),
73-
p=contrast_p,
74-
keepdim=True,
75-
same_on_batch=True,
76-
)
77-
)
78-
if brightness_p > 0:
79-
aug_stack.append(
80-
K.augmentation.RandomBrightness(
81-
brightness=(brightness_min, brightness_max),
82-
p=brightness_p,
83-
keepdim=True,
84-
same_on_batch=True,
85-
)
86-
)
87-
88-
augmenter = AugmentationSequential(
89-
*aug_stack,
90-
data_keys=["input", "keypoints"],
91-
keepdim=True,
92-
same_on_batch=True,
52+
return apply_intensity_augmentation_skia(
53+
image=image,
54+
instances=instances,
55+
uniform_noise_min=uniform_noise_min,
56+
uniform_noise_max=uniform_noise_max,
57+
uniform_noise_p=uniform_noise_p,
58+
gaussian_noise_mean=gaussian_noise_mean,
59+
gaussian_noise_std=gaussian_noise_std,
60+
gaussian_noise_p=gaussian_noise_p,
61+
contrast_min=contrast_min,
62+
contrast_max=contrast_max,
63+
contrast_p=contrast_p,
64+
brightness_min=brightness_min,
65+
brightness_max=brightness_max,
66+
brightness_p=brightness_p,
9367
)
9468

95-
inst_shape = instances.shape
96-
# Before (full image): (n_samples, C, H, W), (n_samples, n_instances, n_nodes, 2)
97-
# or
98-
# Before (cropped image): (B=1, C, crop_H, crop_W), (n_samples, n_nodes, 2)
99-
instances = instances.reshape(inst_shape[0], -1, 2)
100-
# (n_samples, C, H, W), (n_samples, n_instances * n_nodes, 2) OR (n_samples, n_nodes, 2)
101-
102-
aug_image, aug_instances = augmenter(image, instances)
103-
104-
# After (full image): (n_samples, C, H, W), (n_samples, n_instances, n_nodes, 2)
105-
# or
106-
# After (cropped image): (n_samples, C, crop_H, crop_W), (n_samples, n_nodes, 2)
107-
return aug_image, aug_instances.reshape(*inst_shape)
108-
10969

11070
def apply_geometric_augmentation(
11171
image: torch.Tensor,
@@ -128,8 +88,8 @@ def apply_geometric_augmentation(
12888
mixup_lambda_min: Optional[float] = 0.01,
12989
mixup_lambda_max: Optional[float] = 0.05,
13090
mixup_p: float = 0.0,
131-
) -> Tuple[torch.Tensor]:
132-
"""Apply kornia geometric augmentation on image and instances.
91+
) -> Tuple[torch.Tensor, torch.Tensor]:
92+
"""Apply geometric augmentation on image and instances.
13393
13494
Args:
13595
image: Input image. Shape: (n_samples, C, H, W)
@@ -160,176 +120,25 @@ def apply_geometric_augmentation(
160120
Returns:
161121
Returns tuple: (image, instances) with augmentation applied.
162122
"""
163-
aug_stack = []
164-
165-
# Check if any individual probability is set
166-
use_independent = (
167-
rotation_p is not None or scale_p is not None or translate_p is not None
123+
return apply_geometric_augmentation_skia(
124+
image=image,
125+
instances=instances,
126+
rotation_min=rotation_min,
127+
rotation_max=rotation_max,
128+
rotation_p=rotation_p,
129+
scale_min=scale_min,
130+
scale_max=scale_max,
131+
scale_p=scale_p,
132+
translate_width=translate_width,
133+
translate_height=translate_height,
134+
translate_p=translate_p,
135+
affine_p=affine_p,
136+
erase_scale_min=erase_scale_min,
137+
erase_scale_max=erase_scale_max,
138+
erase_ratio_min=erase_ratio_min,
139+
erase_ratio_max=erase_ratio_max,
140+
erase_p=erase_p,
141+
mixup_lambda_min=mixup_lambda_min,
142+
mixup_lambda_max=mixup_lambda_max,
143+
mixup_p=mixup_p,
168144
)
169-
170-
if use_independent:
171-
# New behavior: Apply augmentations independently with separate probabilities
172-
if rotation_p is not None and rotation_p > 0:
173-
aug_stack.append(
174-
K.augmentation.RandomRotation(
175-
degrees=(rotation_min, rotation_max),
176-
p=rotation_p,
177-
keepdim=True,
178-
same_on_batch=True,
179-
)
180-
)
181-
182-
if scale_p is not None and scale_p > 0:
183-
aug_stack.append(
184-
K.augmentation.RandomAffine(
185-
degrees=0, # No rotation
186-
translate=None, # No translation
187-
scale=(scale_min, scale_max),
188-
p=scale_p,
189-
keepdim=True,
190-
same_on_batch=True,
191-
)
192-
)
193-
194-
if translate_p is not None and translate_p > 0:
195-
aug_stack.append(
196-
K.augmentation.RandomAffine(
197-
degrees=0, # No rotation
198-
translate=(translate_width, translate_height),
199-
scale=None, # No scaling
200-
p=translate_p,
201-
keepdim=True,
202-
same_on_batch=True,
203-
)
204-
)
205-
elif affine_p > 0:
206-
# Legacy behavior: Bundled affine transformation
207-
aug_stack.append(
208-
K.augmentation.RandomAffine(
209-
degrees=(rotation_min, rotation_max),
210-
translate=(translate_width, translate_height),
211-
scale=(scale_min, scale_max),
212-
p=affine_p,
213-
keepdim=True,
214-
same_on_batch=True,
215-
)
216-
)
217-
218-
if erase_p > 0:
219-
aug_stack.append(
220-
K.augmentation.RandomErasing(
221-
scale=(erase_scale_min, erase_scale_max),
222-
ratio=(erase_ratio_min, erase_ratio_max),
223-
p=erase_p,
224-
keepdim=True,
225-
same_on_batch=True,
226-
)
227-
)
228-
if mixup_p > 0:
229-
aug_stack.append(
230-
K.augmentation.RandomMixUpV2(
231-
lambda_val=(mixup_lambda_min, mixup_lambda_max),
232-
p=mixup_p,
233-
keepdim=True,
234-
same_on_batch=True,
235-
)
236-
)
237-
238-
augmenter = AugmentationSequential(
239-
*aug_stack,
240-
data_keys=["input", "keypoints"],
241-
keepdim=True,
242-
same_on_batch=True,
243-
)
244-
245-
inst_shape = instances.shape
246-
# Before (full image): (n_samples, C, H, W), (n_samples, n_instances, n_nodes, 2)
247-
# or
248-
# Before (cropped image): (B=1, C, crop_H, crop_W), (n_samples, n_nodes, 2)
249-
instances = instances.reshape(inst_shape[0], -1, 2)
250-
# (n_samples, C, H, W), (n_samples, n_instances * n_nodes, 2) OR (n_samples, n_nodes, 2)
251-
252-
aug_image, aug_instances = augmenter(image, instances)
253-
254-
# After (full image): (n_samples, C, H, W), (n_samples, n_instances, n_nodes, 2)
255-
# or
256-
# After (cropped image): (n_samples, C, crop_H, crop_W), (n_samples, n_nodes, 2)
257-
return aug_image, aug_instances.reshape(*inst_shape)
258-
259-
260-
class RandomUniformNoise(IntensityAugmentationBase2D):
261-
"""Data transformer for applying random uniform noise to input images.
262-
263-
This is a custom Kornia augmentation inheriting from `IntensityAugmentationBase2D`.
264-
Uniform noise within (min_val, max_val) is applied to the entire input image.
265-
266-
Note: Inverse transform is not implemented and re-applying the same transformation
267-
in the example below does not work when included in an AugmentationSequential class.
268-
269-
Args:
270-
noise: 2-tuple (min_val, max_val); 0.0 <= min_val <= max_val <= 1.0.
271-
p: probability for applying an augmentation. This param controls the augmentation probabilities
272-
element-wise for a batch.
273-
p_batch: probability for applying an augmentation to a batch. This param controls the augmentation
274-
probabilities batch-wise.
275-
same_on_batch: apply the same transformation across the batch.
276-
keepdim: whether to keep the output shape the same as input `True` or broadcast it
277-
to the batch form `False`.
278-
279-
Examples:
280-
>>> rng = torch.manual_seed(0)
281-
>>> img = torch.rand(1, 1, 2, 2)
282-
>>> RandomUniformNoise(min_val=0., max_val=0.1, p=1.)(img)
283-
tensor([[[[0.9607, 0.5865],
284-
[0.2705, 0.5920]]]])
285-
286-
To apply the exact augmentation again, you may take the advantage of the previous parameter state:
287-
>>> input = torch.rand(1, 3, 32, 32)
288-
>>> aug = RandomUniformNoise(min_val=0., max_val=0.1, p=1.)
289-
>>> (aug(input) == aug(input, params=aug._params)).all()
290-
tensor(True)
291-
292-
Ref: `kornia.augmentation._2d.intensity.gaussian_noise
293-
<https://kornia.readthedocs.io/en/latest/_modules/kornia/augmentation/_2d/intensity/gaussian_noise.html#RandomGaussianNoise>`_.
294-
"""
295-
296-
def __init__(
297-
self,
298-
noise: Tuple[float, float],
299-
p: float = 0.5,
300-
p_batch: float = 1.0,
301-
clip_output: bool = True,
302-
same_on_batch: bool = False,
303-
keepdim: bool = False,
304-
) -> None:
305-
"""Initialize the class."""
306-
super().__init__(
307-
p=p, p_batch=p_batch, same_on_batch=same_on_batch, keepdim=keepdim
308-
)
309-
self.flags = {
310-
"uniform_noise": _range_bound(noise, "uniform_noise", bounds=(0.0, 1.0))
311-
}
312-
self.clip_output = clip_output
313-
314-
def apply_transform(
315-
self,
316-
input: Tensor,
317-
params: Dict[str, Tensor],
318-
flags: Dict[str, Any],
319-
transform: Optional[Tensor] = None,
320-
) -> Tensor:
321-
"""Compute the uniform noise, add, and clamp output."""
322-
if "uniform_noise" in params:
323-
uniform_noise = params["uniform_noise"]
324-
else:
325-
uniform_noise = (
326-
torch.FloatTensor(input.shape)
327-
.uniform_(flags["uniform_noise"][0], flags["uniform_noise"][1])
328-
.to(input.device)
329-
)
330-
self._params["uniform_noise"] = uniform_noise
331-
if self.clip_output:
332-
return torch.clamp(
333-
input + uniform_noise, 0.0, 1.0
334-
) # RandomGaussianNoise doesn't clamp.
335-
return input + uniform_noise

sleap_nn/data/custom_datasets.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -1,6 +1,6 @@
11
"""Custom `torch.utils.data.Dataset`s for different model types."""
22

3-
from kornia.geometry.transform import crop_and_resize
3+
from sleap_nn.data.skia_augmentation import crop_and_resize_skia as crop_and_resize
44

55
# from concurrent.futures import ThreadPoolExecutor # TODO: implement parallel processing
66
# import concurrent.futures

sleap_nn/data/instance_cropping.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -5,7 +5,7 @@
55
import numpy as np
66
import sleap_io as sio
77
import torch
8-
from kornia.geometry.transform import crop_and_resize
8+
from sleap_nn.data.skia_augmentation import crop_and_resize_skia as crop_and_resize
99

1010

1111
def compute_augmentation_padding(

0 commit comments

Comments
 (0)