
Commit 4dbc7cd

Update bottomup inference and add note on num_workers (#361)
This PR updates the docs by adding a note on using `num_workers` on Mac and Windows. It also includes a minor fix that speeds up bottom-up model inference.
1 parent: b495560

File tree

2 files changed: +9, −0 lines

docs/config.md (2 additions, 0 deletions)

```diff
@@ -18,13 +18,15 @@ The config file has three main sections:
     - **`train_labels_path`**: Path(s) to your training label files.
     - **`val_labels_path`**: Path(s) to your validation label files.
     - **`augmentation_config`**: Controls data augmentation settings for training.
+    - **`data_pipeline_fw`**: Method to load data during training. Options: [`torch_dataset`, `torch_dataset_cache_img_memory`, `torch_dataset_cache_img_disk`].
 
 - **`model_config`**
     - **`head_configs`**: Defines the output heads (e.g., for confidence maps, part affinity fields, etc.).
 
 - **`trainer_config`**
     - **`ckpt_dir`**: Directory where checkpoints and logs will be saved.
     - **`run_name`**: Name for this training run (used for organizing outputs and logging). The checkpoints for a specific run would be saved in `<ckpt_dir>/<run_name>` folder.
+    - **`train_data_loader.num_workers`** and **`val_data_loader.num_workers`**: Number of workers for dataloading. (For mac and windows, set this to > 0, ONLY if caching is used for `data_config.data_pipeline_fw`.)
 
 ### Sample configuration format
 
```
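To make the new options concrete, here is a minimal sketch of how they might be set together in a training config. The nesting follows the key paths named in the diff (`data_config.data_pipeline_fw`, `trainer_config.train_data_loader.num_workers`); the file paths, run name, and worker counts are illustrative, not taken from the repo:

```yaml
data_config:
  train_labels_path: train.pkg.slp   # illustrative path
  val_labels_path: val.pkg.slp       # illustrative path
  # Caching pipelines load images once (into memory or onto disk), which is
  # what makes multi-worker loading safe on Mac/Windows per the note above.
  data_pipeline_fw: torch_dataset_cache_img_memory

trainer_config:
  ckpt_dir: checkpoints         # illustrative
  run_name: bottomup_run_1      # illustrative
  train_data_loader:
    num_workers: 4   # on Mac/Windows, keep at 0 unless a caching pipeline is used
  val_data_loader:
    num_workers: 4
```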

sleap_nn/inference/paf_grouping.py (7 additions, 0 deletions)

```diff
@@ -545,6 +545,11 @@ def match_candidates_sample(
 
     See also: match_candidates_batch
     """
+    # Move tensors to CPU once to avoid repeated device<->host synchronizations
+    edge_inds_sample = edge_inds_sample.detach().cpu()
+    edge_peak_inds_sample = edge_peak_inds_sample.detach().cpu()
+    line_scores_sample = line_scores_sample.detach().cpu()
+
     match_edge_inds = []
     match_src_peak_inds = []
     match_dst_peak_inds = []
@@ -572,6 +577,8 @@ def match_candidates_sample(
                 edge_peak_inds_k[:, 1] == dst_ind
             )
             if mask.any():
+                # `line_scores_k` is already on CPU; `.item()` does not trigger
+                # a device synchronization and matches the original behaviour.
                 cost_matrix[i, j] = -line_scores_k[
                     mask
                 ].item()  # Flip sign for maximization.
```
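The speed-up rests on a general PyTorch pattern: calling `.item()` on a CUDA tensor inside a Python loop forces one device-to-host synchronization per call, whereas moving the tensor to the CPU once pays that cost a single time. A self-contained sketch of the two variants (function and variable names are illustrative, not the repo's code):

```python
import torch


def sum_scores_slow(line_scores: torch.Tensor) -> float:
    # Anti-pattern: if `line_scores` lives on the GPU, each `.item()` call
    # below triggers its own device-to-host synchronization.
    total = 0.0
    for k in range(len(line_scores)):
        total += line_scores[k].item()
    return total


def sum_scores_fast(line_scores: torch.Tensor) -> float:
    # Fix mirrored from the commit: detach and copy to the CPU once, so every
    # subsequent `.item()` is a cheap host-side read.
    line_scores = line_scores.detach().cpu()
    total = 0.0
    for k in range(len(line_scores)):
        total += line_scores[k].item()
    return total
```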
