Skip to content

Commit a76f6ec

Browse files
authored
Add track cleaning args (#349)
This PR adds track cleaning arguments (pre culling instances, and post tracking clean-up arguments).
1 parent b0179af commit a76f6ec

File tree

8 files changed

+681
-36
lines changed

8 files changed

+681
-36
lines changed

docs/inference.md

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -107,6 +107,8 @@ sleap-nn track \
107107
| `--video_dataset` | The dataset for HDF5 videos | `None` |
108108
| `--video_input_format` | The input_format for HDF5 videos | `channels_last` |
109109
| `--frames` | List of frames indices. If `None`, all frames in the video are used | All frames |
110+
| `--no_empty_frames` | If `True`, removes frames with no predicted instances from the output labels | `False` |
111+
110112

111113
#### Performance
112114

@@ -177,6 +179,14 @@ When using the `sleap-nn track` CLI command with both `--model_paths` and `--tra
177179
| `--of_max_levels` | Number of pyramid scale levels to consider. This is different from the scale parameter, which determines the initial image scaling (only if `use_flow` is True) | `3` |
178180
| `--post_connect_single_breaks` | If True and `max_tracks` is not None with local queues candidate method, connects track breaks when exactly one track is lost and exactly one new track is spawned in the frame | `False` |
179181

182+
!!! warning "Tracking cleaning and pre-cull parameters"
183+
184+
The parameters `--tracking_pre_cull_to_target`, `--tracking_target_instance_count`, `tracking_pre_cull_iou_threshold`, `tracking_clean_iou_threshold` and `--tracking_clean_instance_count` are provided for backwards compatibility with legacy SLEAP workflows and **may be deprecated in future releases**.
185+
186+
- To restrict the number of instances per frame, use the `--max_instances` parameter, which selects the top instances with the highest prediction scores.
187+
188+
We recommend using `--max_instances` for controlling the number of predicted instances per frame in new projects.
189+
180190
#### Fixed Window Tracking
181191

182192
This method maintains a fixed-size queue with the last N frames and uses all instances from those frames as candidates for matching.

sleap_nn/cli.py

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -209,6 +209,12 @@ def train(config_name, config_dir, overrides):
209209
default=False,
210210
help="Only run inference on unlabeled suggested frames when running on labels dataset. This is useful for generating predictions for initialization during labeling.",
211211
)
212+
@click.option(
213+
"--no_empty_frames",
214+
is_flag=True,
215+
default=False,
216+
help=("Clear empty frames that did not have predictions before saving to output."),
217+
)
212218
@click.option(
213219
"--video_index",
214220
type=int,
@@ -380,6 +386,42 @@ def train(config_name, config_dir, overrides):
380386
default=False,
381387
help="If True and `max_tracks` is not None with local queues candidate method, connects track breaks when exactly one track is lost and exactly one new track is spawned in the frame.",
382388
)
389+
@click.option(
390+
"--tracking_target_instance_count",
391+
type=int,
392+
default=0,
393+
help="Target number of instances to track per frame. (default: 0)",
394+
)
395+
@click.option(
396+
"--tracking_pre_cull_to_target",
397+
type=int,
398+
default=0,
399+
help=(
400+
"If non-zero and target_instance_count is also non-zero, then cull instances "
401+
"over target count per frame *before* tracking. (default: 0)"
402+
),
403+
)
404+
@click.option(
405+
"--tracking_pre_cull_iou_threshold",
406+
type=float,
407+
default=0,
408+
help=(
409+
"If non-zero and pre_cull_to_target also set, then use IOU threshold to remove "
410+
"overlapping instances over count *before* tracking. (default: 0)"
411+
),
412+
)
413+
@click.option(
414+
"--tracking_clean_instance_count",
415+
type=int,
416+
default=0,
417+
help="Target number of instances to clean *after* tracking. (default: 0)",
418+
)
419+
@click.option(
420+
"--tracking_clean_iou_threshold",
421+
type=float,
422+
default=0,
423+
help="IOU to use when culling instances *after* tracking. (default: 0)",
424+
)
383425
def track(**kwargs):
384426
"""Run Inference and Tracking workflow."""
385427
# Convert model_paths from tuple to list

sleap_nn/predict.py

Lines changed: 61 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,12 @@
99
BottomUpMultiClassPredictor,
1010
TopDownMultiClassPredictor,
1111
)
12-
from sleap_nn.tracking.tracker import Tracker, run_tracker, connect_single_breaks
12+
from sleap_nn.tracking.tracker import (
13+
Tracker,
14+
run_tracker,
15+
connect_single_breaks,
16+
cull_instances,
17+
)
1318
from omegaconf import OmegaConf
1419
import sleap_io as sio
1520
from pathlib import Path
@@ -53,6 +58,7 @@ def run_inference(
5358
anchor_part: Optional[str] = None,
5459
only_labeled_frames: bool = False,
5560
only_suggested_frames: bool = False,
61+
no_empty_frames: bool = False,
5662
batch_size: int = 4,
5763
queue_maxsize: int = 8,
5864
video_index: Optional[int] = None,
@@ -92,6 +98,11 @@ def run_inference(
9298
of_window_size: int = 21,
9399
of_max_levels: int = 3,
94100
post_connect_single_breaks: bool = False,
101+
tracking_target_instance_count: int = 0,
102+
tracking_pre_cull_to_target: int = 0,
103+
tracking_pre_cull_iou_threshold: float = 0,
104+
tracking_clean_instance_count: int = 0,
105+
tracking_clean_iou_threshold: float = 0,
95106
):
96107
"""Entry point to run inference on trained SLEAP-NN models.
97108
@@ -125,6 +136,7 @@ def run_inference(
125136
provided, the anchor part in the `training_config.yaml` is used. Default: `None`.
126137
only_labeled_frames: (bool) `True` if inference should be run only on user-labeled frames. Default: `False`.
127138
only_suggested_frames: (bool) `True` if inference should be run only on unlabeled suggested frames. Default: `False`.
139+
no_empty_frames: (bool) `True` if empty frames that did not have predictions should be cleared before saving to output. Default: `False`.
128140
batch_size: (int) Number of samples per batch. Default: 4.
129141
queue_maxsize: (int) Maximum size of the frame buffer queue. Default: 8.
130142
video_index: (int) Integer index of video in .slp file to predict on. To be used with
@@ -224,6 +236,11 @@ def run_inference(
224236
Default: 3. (only if `use_flow` is True).
225237
post_connect_single_breaks: If True and `max_tracks` is not None with local queues candidate method,
226238
connects track breaks when exactly one track is lost and exactly one new track is spawned in the frame.
239+
tracking_target_instance_count: Target number of instances to track per frame. (default: 0)
240+
tracking_pre_cull_to_target: If non-zero and target_instance_count is also non-zero, then cull instances over target count per frame *before* tracking. (default: 0)
241+
tracking_pre_cull_iou_threshold: If non-zero and pre_cull_to_target also set, then use IOU threshold to remove overlapping instances over count *before* tracking. (default: 0)
242+
tracking_clean_instance_count: Target number of instances to clean *after* tracking. (default: 0)
243+
tracking_clean_iou_threshold: IOU to use when culling instances *after* tracking. (default: 0)
227244
228245
Returns:
229246
Returns `sio.Labels` object if `make_labels` is True. Else this function returns
@@ -300,6 +317,11 @@ def run_inference(
300317
of_window_size=of_window_size,
301318
of_max_levels=of_max_levels,
302319
post_connect_single_breaks=post_connect_single_breaks,
320+
tracking_target_instance_count=tracking_target_instance_count,
321+
tracking_pre_cull_to_target=tracking_pre_cull_to_target,
322+
tracking_pre_cull_iou_threshold=tracking_pre_cull_iou_threshold,
323+
tracking_clean_instance_count=tracking_clean_instance_count,
324+
tracking_clean_iou_threshold=tracking_clean_iou_threshold,
303325
)
304326

305327
finish_timestamp = str(datetime.now())
@@ -370,6 +392,9 @@ def run_inference(
370392
of_img_scale=of_img_scale,
371393
of_window_size=of_window_size,
372394
of_max_levels=of_max_levels,
395+
tracking_target_instance_count=tracking_target_instance_count,
396+
tracking_pre_cull_to_target=tracking_pre_cull_to_target,
397+
tracking_pre_cull_iou_threshold=tracking_pre_cull_iou_threshold,
373398
)
374399

375400
if isinstance(predictor, BottomUpPredictor):
@@ -418,29 +443,38 @@ def run_inference(
418443
make_labels=make_labels,
419444
)
420445

421-
if tracking and post_connect_single_breaks:
422-
if max_tracks is None:
423-
max_tracks = max_instances
424-
425-
if max_tracks is None:
426-
message = "Max_tracks (and max instances) is None. To connect single breaks, max_tracks should be set to an integer."
427-
logger.error(message)
428-
raise ValueError(message)
429-
430-
start_final_pass_time = time()
431-
start_fp_timestamp = str(datetime.now())
432-
logger.info(
433-
f"Started final-pass (connecting single breaks) at: {start_fp_timestamp}"
434-
)
435-
corrected_lfs = connect_single_breaks(
436-
lfs=[x for x in output], max_instances=max_tracks
437-
)
438-
finish_fp_timestamp = str(datetime.now())
439-
total_fp_elapsed = time() - start_final_pass_time
440-
logger.info(
441-
f"Finished final-pass (connecting single breaks) at: {finish_fp_timestamp}"
442-
)
443-
logger.info(f"Total runtime: {total_fp_elapsed} secs")
446+
if tracking:
447+
lfs = [x for x in output]
448+
if tracking_clean_instance_count > 0:
449+
lfs = cull_instances(
450+
lfs, tracking_clean_instance_count, tracking_clean_iou_threshold
451+
)
452+
if not post_connect_single_breaks:
453+
corrected_lfs = connect_single_breaks(
454+
lfs, tracking_clean_instance_count
455+
)
456+
elif post_connect_single_breaks:
457+
if not tracking_target_instance_count:
458+
message = "tracking_target_instance_count is 0. To connect single breaks, tracking_target_instance_count should be set to an integer."
459+
logger.error(message)
460+
raise ValueError(message)
461+
462+
start_final_pass_time = time()
463+
start_fp_timestamp = str(datetime.now())
464+
logger.info(
465+
f"Started final-pass (connecting single breaks) at: {start_fp_timestamp}"
466+
)
467+
corrected_lfs = connect_single_breaks(
468+
lfs, max_instances=tracking_target_instance_count
469+
)
470+
finish_fp_timestamp = str(datetime.now())
471+
total_fp_elapsed = time() - start_final_pass_time
472+
logger.info(
473+
f"Finished final-pass (connecting single breaks) at: {finish_fp_timestamp}"
474+
)
475+
logger.info(f"Total runtime: {total_fp_elapsed} secs")
476+
else:
477+
corrected_lfs = lfs
444478

445479
output = sio.Labels(
446480
labeled_frames=corrected_lfs,
@@ -455,6 +489,9 @@ def run_inference(
455489
f"Total runtime: {total_elapsed} secs"
456490
) # TODO: add number of predicted frames
457491

492+
if no_empty_frames:
493+
output.clean(frames=True, skeletons=False)
494+
458495
if make_labels:
459496
if output_path is None:
460497
output_path = Path(

sleap_nn/tracking/tracker.py

Lines changed: 53 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,8 @@
2828
compute_euclidean_distance,
2929
compute_iou,
3030
compute_cosine_sim,
31+
cull_instances,
32+
cull_frame_instances,
3133
)
3234

3335

@@ -74,6 +76,9 @@ class Tracker:
7476
robust_best_instance: float = 1.0
7577
use_flow: bool = False
7678
is_local_queue: bool = False
79+
tracking_target_instance_count: int = 0
80+
tracking_pre_cull_to_target: int = 0
81+
tracking_pre_cull_iou_threshold: float = 0
7782
_scoring_functions: Dict[str, Any] = {
7883
"oks": compute_oks,
7984
"iou": compute_iou,
@@ -114,6 +119,9 @@ def from_config(
114119
of_img_scale: float = 1.0,
115120
of_window_size: int = 21,
116121
of_max_levels: int = 3,
122+
tracking_target_instance_count: int = 0,
123+
tracking_pre_cull_to_target: int = 0,
124+
tracking_pre_cull_iou_threshold: float = 0,
117125
):
118126
"""Create `Tracker` from config.
119127
@@ -154,6 +162,9 @@ def from_config(
154162
of_max_levels: Number of pyramid scale levels to consider. This is different
155163
from the scale parameter, which determines the initial image scaling.
156164
Default: 3. (only if `use_flow` is True)
165+
tracking_target_instance_count: Target number of instances to track per frame. (default: 0)
166+
tracking_pre_cull_to_target: If non-zero and target_instance_count is also non-zero, then cull instances over target count per frame *before* tracking. (default: 0)
167+
tracking_pre_cull_iou_threshold: If non-zero and pre_cull_to_target also set, then use IOU threshold to remove overlapping instances over count *before* tracking. (default: 0)
157168
158169
"""
159170
if candidates_method == "fixed_window":
@@ -189,6 +200,9 @@ def from_config(
189200
of_window_size=of_window_size,
190201
of_max_levels=of_max_levels,
191202
is_local_queue=is_local_queue,
203+
tracking_target_instance_count=tracking_target_instance_count,
204+
tracking_pre_cull_to_target=tracking_pre_cull_to_target,
205+
tracking_pre_cull_iou_threshold=tracking_pre_cull_iou_threshold,
192206
)
193207

194208
tracker = cls(
@@ -201,6 +215,9 @@ def from_config(
201215
track_matching_method=track_matching_method,
202216
use_flow=use_flow,
203217
is_local_queue=is_local_queue,
218+
tracking_target_instance_count=tracking_target_instance_count,
219+
tracking_pre_cull_to_target=tracking_pre_cull_to_target,
220+
tracking_pre_cull_iou_threshold=tracking_pre_cull_iou_threshold,
204221
)
205222
return tracker
206223

@@ -220,6 +237,12 @@ def track(
220237
Returns:
221238
List of `sio.PredictedInstance` objects, each having an assigned track.
222239
"""
240+
if self.tracking_target_instance_count and self.tracking_pre_cull_to_target:
241+
untracked_instances = cull_frame_instances(
242+
untracked_instances,
243+
self.tracking_target_instance_count,
244+
self.tracking_pre_cull_iou_threshold,
245+
)
223246
# get features for the untracked instances.
224247
current_instances = self.get_features(untracked_instances, frame_idx, image)
225248

@@ -469,6 +492,9 @@ class FlowShiftTracker(Tracker):
469492
of_max_levels: Number of pyramid scale levels to consider. This is different
470493
from the scale parameter, which determines the initial image scaling.
471494
Default: 3
495+
tracking_target_instance_count: Target number of instances to track per frame. (default: 0)
496+
tracking_pre_cull_to_target: If non-zero and target_instance_count is also non-zero, then cull instances over target count per frame *before* tracking. (default: 0)
497+
tracking_pre_cull_iou_threshold: If non-zero and pre_cull_to_target also set, then use IOU threshold to remove overlapping instances over count *before* tracking. (default: 0)
472498
473499
"""
474500

@@ -703,6 +729,11 @@ def run_tracker(
703729
of_window_size: int = 21,
704730
of_max_levels: int = 3,
705731
post_connect_single_breaks: bool = False,
732+
tracking_target_instance_count: int = 0,
733+
tracking_pre_cull_to_target: int = 0,
734+
tracking_pre_cull_iou_threshold: float = 0,
735+
tracking_clean_instance_count: int = 0,
736+
tracking_clean_iou_threshold: float = 0,
706737
) -> List[sio.LabeledFrame]:
707738
"""Run tracking on a given set of frames.
708739
@@ -746,6 +777,11 @@ def run_tracker(
746777
Default: 3. (only if `use_flow` is True).
747778
post_connect_single_breaks: If True and `max_tracks` is not None with local queues candidate method,
748779
connects track breaks when exactly one track is lost and exactly one new track is spawned in the frame.
780+
tracking_target_instance_count: Target number of instances to track per frame. (default: 0)
781+
tracking_pre_cull_to_target: If non-zero and target_instance_count is also non-zero, then cull instances over target count per frame *before* tracking. (default: 0)
782+
tracking_pre_cull_iou_threshold: If non-zero and pre_cull_to_target also set, then use IOU threshold to remove overlapping instances over count *before* tracking. (default: 0)
783+
tracking_clean_instance_count: Target number of instances to clean *after* tracking. (default: 0)
784+
tracking_clean_iou_threshold: IOU to use when culling instances *after* tracking. (default: 0)
749785
750786
Returns:
751787
`sio.Labels` object with tracked instances.
@@ -766,6 +802,9 @@ def run_tracker(
766802
of_img_scale=of_img_scale,
767803
of_window_size=of_window_size,
768804
of_max_levels=of_max_levels,
805+
tracking_target_instance_count=tracking_target_instance_count,
806+
tracking_pre_cull_to_target=tracking_pre_cull_to_target,
807+
tracking_pre_cull_iou_threshold=tracking_pre_cull_iou_threshold,
769808
)
770809
tracked_lfs = []
771810
for lf in untracked_frames:
@@ -791,17 +830,28 @@ def run_tracker(
791830
)
792831
)
793832

833+
if tracking_clean_instance_count > 0:
834+
tracked_lfs = cull_instances(
835+
tracked_lfs, tracking_clean_instance_count, tracking_clean_iou_threshold
836+
)
837+
if not post_connect_single_breaks:
838+
tracked_lfs = connect_single_breaks(
839+
tracked_lfs, tracking_clean_instance_count
840+
)
841+
794842
if post_connect_single_breaks:
795-
if max_tracks is None:
796-
message = "Max_tracks is None. To connect single breaks, max_tracks should be set to an integer."
843+
if not tracking_target_instance_count:
844+
message = "tracking_target_instance_count is 0. To connect single breaks, tracking_target_instance_count should be set to an integer."
797845
logger.error(message)
798846
raise ValueError(message)
799847
start_final_pass_time = time()
800848
start_fp_timestamp = str(datetime.now())
801849
logger.info(
802850
f"Started final-pass (connecting single breaks) at: {start_fp_timestamp}"
803851
)
804-
tracked_lfs = connect_single_breaks(tracked_lfs, max_instances=max_tracks)
852+
tracked_lfs = connect_single_breaks(
853+
tracked_lfs, max_instances=tracking_target_instance_count
854+
)
805855
finish_fp_timestamp = str(datetime.now())
806856
total_fp_elapsed = time() - start_final_pass_time
807857
logger.info(

0 commit comments

Comments
 (0)