
Commit b1a2026

Fix multiprocessing bug with num_workers>0 (#359)
Previously we added caching to support num_workers > 0, since native `Labels` objects (HDF5 handles via h5py) aren't picklable under the spawn start method used on macOS/Windows. However, the `Labels` object was still attached to the dataset and was sent to worker processes, breaking multiprocessing. This PR removes `sio.Labels` from the dataset state when caching is enabled (it is kept only when caching is disabled), so workers no longer receive a non-picklable handle. With caching on, users can now safely set num_workers > 0 and get faster training on macOS/Windows without HDF5 pickling errors.
1 parent 6ed617b commit b1a2026

15 files changed: +351 -162 lines
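
The actual change lives in sleap_nn/data/custom_datasets.py, whose diff is not rendered below. As a minimal, self-contained sketch of the pattern described in the commit message (class and attribute names such as CachedPoseDataset, samples, and use_cache are illustrative assumptions, not the real sleap-nn code): when caching is enabled, samples are materialized up front and the HDF5-backed labels handle is dropped from the dataset state, so the dataset pickles cleanly into spawn-based DataLoader workers.

# Illustrative sketch only; see sleap_nn/data/custom_datasets.py for the real classes.
import torch
from torch.utils.data import DataLoader, Dataset


class CachedPoseDataset(Dataset):
    """Toy dataset showing why dropping the labels handle matters."""

    def __init__(self, labels, use_cache: bool = True):
        self.use_cache = use_cache
        if use_cache:
            # Read everything needed while the (non-picklable) handle is still open...
            self.samples = [torch.as_tensor(x, dtype=torch.float32) for x in labels]
            # ...then drop the handle so spawned workers never receive it.
            self.labels = None
        else:
            # Without caching, the handle stays attached; num_workers must stay at 0.
            self.labels = labels
            self.samples = None

    def __len__(self):
        return len(self.samples) if self.use_cache else len(self.labels)

    def __getitem__(self, idx):
        if self.use_cache:
            return self.samples[idx]
        return torch.as_tensor(self.labels[idx], dtype=torch.float32)


if __name__ == "__main__":
    # With caching on, the dataset holds no open file handle, so spawn-based
    # workers (the default start method on macOS/Windows) can unpickle it.
    ds = CachedPoseDataset(labels=[[0.0, 1.0], [2.0, 3.0]], use_cache=True)
    loader = DataLoader(ds, batch_size=2, num_workers=2)
    for batch in loader:
        print(batch.shape)
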

sleap_nn/data/custom_datasets.py

Lines changed: 158 additions & 122 deletions
Large diffs are not rendered by default.

sleap_nn/data/providers.py

Lines changed: 13 additions & 8 deletions
@@ -1,6 +1,6 @@
 """This module implements pipeline blocks for reading input data such as labels."""

-from typing import Any, Dict, Optional, Tuple
+from typing import Any, Dict, List, Optional, Tuple

 import numpy as np
 import sleap_io as sio
@@ -36,15 +36,19 @@ def get_max_height_width(labels: sio.Labels) -> Tuple[int, int]:


 def process_lf(
-    lf: sio.LabeledFrame,
+    instances_list: List[sio.Instance],
+    img: np.ndarray,
+    frame_idx: int,
     video_idx: int,
     max_instances: int,
     user_instances_only: bool = True,
 ) -> Dict[str, Any]:
     """Get sample dict from `sio.LabeledFrame`.

     Args:
-        lf: Input `sio.LabeledFrame`.
+        instances_list: List of `sio.Instance` objects.
+        img: Input image.
+        frame_idx: Frame index of the given lf.
         video_idx: Video index of the given lf.
         max_instances: Maximum number of instances that could occur in a single LabeledFrame.
         user_instances_only: True if filter labels only to user instances else False.
@@ -57,13 +61,14 @@ def process_lf(
     """
     # Filter to user instances
     if user_instances_only:
-        if lf.user_instances is not None and len(lf.user_instances) > 0:
-            lf.instances = lf.user_instances
+        user_instances = [inst for inst in instances_list if type(inst) is sio.Instance]
+        if len(user_instances) > 0:
+            instances_list = user_instances

-    image = np.transpose(lf.image, (2, 0, 1))  # HWC -> CHW
+    image = np.transpose(img, (2, 0, 1))  # HWC -> CHW

     instances = []
-    for inst in lf:
+    for inst in instances_list:
         if not inst.is_empty:
             instances.append(inst.numpy())
     instances = np.stack(instances, axis=0)
@@ -92,7 +97,7 @@ def process_lf(
         "image": torch.from_numpy(image.copy()),
         "instances": instances,
         "video_idx": torch.tensor(video_idx, dtype=torch.int32),
-        "frame_idx": torch.tensor(lf.frame_idx, dtype=torch.int32),
+        "frame_idx": torch.tensor(frame_idx, dtype=torch.int32),
         "orig_size": torch.Tensor([img_height, img_width]).unsqueeze(0),
         "num_instances": num_instances,
     }

tests/data/test_augmentation.py

Lines changed: 14 additions & 2 deletions
@@ -13,7 +13,13 @@ def test_apply_intensity_augmentation(minimal_instance):
     """Test `apply_intensity_augmentation` function."""
     labels = sio.load_slp(minimal_instance)
     lf = labels[0]
-    ex = process_lf(lf, 0, 2)
+    ex = process_lf(
+        instances_list=lf.instances,
+        img=lf.image,
+        frame_idx=lf.frame_idx,
+        video_idx=0,
+        max_instances=2,
+    )
     ex["image"] = apply_normalization(ex["image"])

     img, pts = apply_intensity_augmentation(
@@ -36,7 +42,13 @@ def test_apply_geometric_augmentation(minimal_instance):
     """Test `apply_geometric_augmentation` function."""
     labels = sio.load_slp(minimal_instance)
     lf = labels[0]
-    ex = process_lf(lf, 0, 2)
+    ex = process_lf(
+        instances_list=lf.instances,
+        img=lf.image,
+        frame_idx=lf.frame_idx,
+        video_idx=0,
+        max_instances=2,
+    )
     ex["image"] = apply_normalization(ex["image"])

     img, pts = apply_geometric_augmentation(

tests/data/test_confmaps.py

Lines changed: 14 additions & 2 deletions
@@ -15,7 +15,13 @@ def test_generate_confmaps(minimal_instance):
     """Test `generate_confmaps` function."""
     labels = sio.load_slp(minimal_instance)
     lf = labels[0]
-    ex = process_lf(lf, 0, 2)
+    ex = process_lf(
+        instances_list=lf.instances,
+        img=lf.image,
+        frame_idx=lf.frame_idx,
+        video_idx=0,
+        max_instances=2,
+    )

     confmaps = generate_confmaps(
         ex["instances"][:, 0].unsqueeze(dim=1), img_hw=(384, 384)
@@ -27,7 +33,13 @@ def test_generate_multiconfmaps(minimal_instance):
     """Test `generate_multiconfmaps` function."""
     labels = sio.load_slp(minimal_instance)
     lf = labels[0]
-    ex = process_lf(lf, 0, 2)
+    ex = process_lf(
+        instances_list=lf.instances,
+        img=lf.image,
+        frame_idx=lf.frame_idx,
+        video_idx=0,
+        max_instances=2,
+    )

     confmaps = generate_multiconfmaps(
         ex["instances"], img_hw=(384, 384), num_instances=ex["num_instances"]

tests/data/test_custom_datasets.py

Lines changed: 18 additions & 12 deletions
@@ -88,7 +88,7 @@ def test_bottomup_dataset(minimal_instance, tmp_path):
         cache_img="memory",
         apply_aug=base_bottom_config.use_augmentations_train,
     )
-    dataset._fill_cache()
+    dataset._fill_cache([sio.load_slp(minimal_instance)])

     gt_sample_keys = [
         "image",
@@ -201,7 +201,7 @@ def test_bottomup_dataset(minimal_instance, tmp_path):
         cache_img="disk",
         cache_img_path=f"{tmp_path}/cache_imgs",
     )
-    dataset._fill_cache()
+    dataset._fill_cache([sio.load_slp(minimal_instance)])

     gt_sample_keys = [
         "image",
@@ -314,7 +314,7 @@ def test_bottomup_multiclass_dataset(minimal_instance, tmp_path):
         cache_img="memory",
         apply_aug=base_bottom_config.use_augmentations_train,
     )
-    dataset._fill_cache()
+    dataset._fill_cache([tracked_labels])

     sample = next(iter(dataset))
     assert len(sample.keys()) == len(gt_sample_keys)
@@ -400,7 +400,7 @@ def test_bottomup_multiclass_dataset(minimal_instance, tmp_path):
         cache_img="disk",
         cache_img_path=f"{tmp_path}/cache_imgs",
     )
-    dataset._fill_cache()
+    dataset._fill_cache([tracked_labels])

     sample = next(iter(dataset))
     assert len(sample.keys()) == len(gt_sample_keys)
@@ -446,7 +446,13 @@ def test_centered_instance_dataset(minimal_instance, tmp_path):
         cache_img="disk",
         cache_img_path=f"{tmp_path}/cache_imgs",
     )
-    dataset._fill_cache()
+    dataset._fill_cache(
+        [
+            sio.load_slp(minimal_instance),
+            sio.load_slp(minimal_instance),
+            sio.load_slp(minimal_instance),
+        ]
+    )

     gt_sample_keys = [
         "centroid",
@@ -481,7 +487,7 @@ def test_centered_instance_dataset(minimal_instance, tmp_path):
         cache_img="memory",
         apply_aug=base_topdown_data_config.use_augmentations_train,
     )
-    dataset._fill_cache()
+    dataset._fill_cache([sio.load_slp(minimal_instance)])

     gt_sample_keys = [
         "centroid",
@@ -711,7 +717,7 @@ def test_centered_multiclass_dataset(minimal_instance, tmp_path):
         cache_img="disk",
         cache_img_path=f"{tmp_path}/cache_imgs",
     )
-    dataset._fill_cache()
+    dataset._fill_cache([tracked_labels, tracked_labels, tracked_labels])

     gt_sample_keys = [
         "centroid",
@@ -749,7 +755,7 @@ def test_centered_multiclass_dataset(minimal_instance, tmp_path):
         cache_img="memory",
         apply_aug=base_topdown_data_config.use_augmentations_train,
     )
-    dataset._fill_cache()
+    dataset._fill_cache([tracked_labels])

     sample = next(iter(dataset))
     assert len(sample.keys()) == len(gt_sample_keys)
@@ -923,7 +929,7 @@ def test_centroid_dataset(minimal_instance, tmp_path):
         cache_img="disk",
         cache_img_path=f"{tmp_path}/cache_imgs",
     )
-    dataset._fill_cache()
+    dataset._fill_cache([sio.load_slp(minimal_instance)])

     gt_sample_keys = [
         "image",
@@ -957,7 +963,7 @@ def test_centroid_dataset(minimal_instance, tmp_path):
         apply_aug=base_centroid_data_config.use_augmentations_train,
         labels=[sio.load_slp(minimal_instance)],
     )
-    dataset._fill_cache()
+    dataset._fill_cache([sio.load_slp(minimal_instance)])

     gt_sample_keys = [
         "image",
@@ -1094,7 +1100,7 @@ def test_single_instance_dataset(minimal_instance, tmp_path):
         cache_img="disk",
         cache_img_path=f"{tmp_path}/cache_imgs",
     )
-    dataset._fill_cache()
+    dataset._fill_cache([labels, labels, labels])
     sample = next(iter(dataset))
     assert len(dataset) == 3

@@ -1127,7 +1133,7 @@ def test_single_instance_dataset(minimal_instance, tmp_path):
         cache_img="memory",
         apply_aug=base_singleinstance_data_config.use_augmentations_train,
     )
-    dataset._fill_cache()
+    dataset._fill_cache([labels])

     sample = next(iter(dataset))
     assert len(dataset) == 1

tests/data/test_edge_maps.py

Lines changed: 7 additions & 1 deletion
@@ -196,7 +196,13 @@ def test_generate_pafs(minimal_instance):
     """Test `generate_pafs` function."""
     labels = sio.load_slp(minimal_instance)
     lf = labels[0]
-    ex = process_lf(lf, 0, 2)
+    ex = process_lf(
+        instances_list=lf.instances,
+        img=lf.image,
+        frame_idx=lf.frame_idx,
+        video_idx=0,
+        max_instances=2,
+    )

     pafs = generate_pafs(
         ex["instances"],

tests/data/test_instance_centroids.py

Lines changed: 7 additions & 1 deletion
@@ -10,7 +10,13 @@ def test_generate_centroids(minimal_instance):
     """Test `generate_centroids` function."""
     labels = sio.load_slp(minimal_instance)
     lf = labels[0]
-    ex = process_lf(lf, 0, 2)
+    ex = process_lf(
+        instances_list=lf.instances,
+        img=lf.image,
+        frame_idx=lf.frame_idx,
+        video_idx=0,
+        max_instances=2,
+    )

     centroids = generate_centroids(ex["instances"], 1).int()
     gt = torch.Tensor([[[152, 158], [278, 203]]]).int()

tests/data/test_instance_cropping.py

Lines changed: 7 additions & 1 deletion
@@ -44,7 +44,13 @@ def test_generate_crops(minimal_instance):
     """Test `generate_crops` function."""
     labels = sio.load_slp(minimal_instance)
     lf = labels[0]
-    ex = process_lf(lf, 0, 2)
+    ex = process_lf(
+        instances_list=lf.instances,
+        img=lf.image,
+        frame_idx=lf.frame_idx,
+        video_idx=0,
+        max_instances=2,
+    )
     ex["image"] = apply_normalization(ex["image"])

     centroids = generate_centroids(ex["instances"], 0)

tests/data/test_providers.py

Lines changed: 7 additions & 1 deletion
@@ -250,7 +250,13 @@ def test_labelsreader_provider(minimal_instance):
 def test_process_lf(minimal_instance):
     labels = sio.load_slp(minimal_instance)
     lf = labels[0]
-    ex = process_lf(lf, 0, 4)
+    ex = process_lf(
+        instances_list=lf.instances,
+        img=lf.image,
+        frame_idx=lf.frame_idx,
+        video_idx=0,
+        max_instances=4,
+    )

     assert ex["image"].shape == torch.Size([1, 1, 384, 384])
     assert ex["instances"].shape == torch.Size([1, 4, 2, 2])

tests/data/test_resizing.py

Lines changed: 21 additions & 3 deletions
@@ -30,7 +30,13 @@ def test_apply_resizer(minimal_instance):
     """Test `apply_resizer` function."""
     labels = sio.load_slp(minimal_instance)
     lf = labels[0]
-    ex = process_lf(lf, 0, 2)
+    ex = process_lf(
+        instances_list=lf.instances,
+        img=lf.image,
+        frame_idx=lf.frame_idx,
+        video_idx=0,
+        max_instances=2,
+    )

     image, instances = apply_resizer(ex["image"], ex["instances"], scale=2.0)
     assert image.shape == torch.Size([1, 1, 768, 768])
@@ -41,7 -47,13 @@ def test_apply_pad_to_stride(minimal_instance):
     """Test `apply_pad_to_stride` function."""
     labels = sio.load_slp(minimal_instance)
     lf = labels[0]
-    ex = process_lf(lf, 0, 2)
+    ex = process_lf(
+        instances_list=lf.instances,
+        img=lf.image,
+        frame_idx=lf.frame_idx,
+        video_idx=0,
+        max_instances=2,
+    )

     image = apply_pad_to_stride(ex["image"], max_stride=2)
     assert image.shape == torch.Size([1, 1, 384, 384])
@@ -54,7 +66,13 @@ def test_apply_sizematcher(caplog, minimal_instance):
     """Test `apply_sizematcher` function."""
     labels = sio.load_slp(minimal_instance)
     lf = labels[0]
-    ex = process_lf(lf, 0, 2)
+    ex = process_lf(
+        instances_list=lf.instances,
+        img=lf.image,
+        frame_idx=lf.frame_idx,
+        video_idx=0,
+        max_instances=2,
+    )

     image, _ = apply_sizematcher(ex["image"], 500, 500)
     assert image.shape == torch.Size([1, 1, 500, 500])
