Skip to content

Commit ee03f9e

Browse files
authored
Merge pull request #195 from talmolab/divya/close-videos-before-multiprocessing
Close videos before creating data loaders
2 parents 554aa03 + a1ba81a commit ee03f9e

File tree

5 files changed

+15
-4
lines changed

5 files changed

+15
-4
lines changed

environment.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@ dependencies:
99
- python=3.9
1010
- pytorch-cuda=11.8
1111
- numpy
12-
- sleap-io
12+
- sleap-io>=0.2.0
1313
- pydantic
1414
- lightning
1515
- cudnn

environment_cpu.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ channels:
77
dependencies:
88
- python=3.9
99
- numpy
10-
- sleap-io
10+
- sleap-io>=0.2.0
1111
- pytorch
1212
- pydantic
1313
- lightning

environment_mac.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ channels:
77
dependencies:
88
- python=3.9
99
- numpy
10-
- sleap-io
10+
- sleap-io>=0.2.0
1111
- pydantic
1212
- lightning
1313
- pytorch

sleap_nn/data/custom_datasets.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -161,7 +161,8 @@ def _fill_cache(self):
161161
self.cache[lf_idx] = img
162162

163163
for video in self.labels.videos:
164-
video.close()
164+
if video.is_open:
165+
video.close()
165166

166167
def _get_video_idx(self, lf):
167168
"""Return indsample of `lf.video` in `labels.videos`."""

sleap_nn/training/model_trainer.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -475,6 +475,16 @@ def _create_data_loaders_torch_dataset(self):
475475
else True
476476
)
477477

478+
# If using caching, close the videos to prevent `h5py objects can't be pickled error` when num_workers > 0.
479+
if "cache_img" in self.data_pipeline_fw:
480+
for video in self.train_labels.videos:
481+
if video.is_open:
482+
video.close()
483+
484+
for video in self.val_labels.videos:
485+
if video.is_open:
486+
video.close()
487+
478488
# train
479489
self.train_data_loader = DataLoader(
480490
dataset=self.train_dataset,

0 commit comments

Comments
 (0)