Skip to content

Commit 6ed617b

Browse files
authored
Fix target instance count parameter (#358)
Currently, enabling `post_connect_single_breaks` forces users to set `tracking_target_instance_count`, even when `max_instances` is already provided. In this PR we fix this by ensuring when post_connect_single_breaks=True and tracking_target_instance_count is None, we infer the target from max_instances; if tracking_target_instance_count is explicitly set, it still takes precedence.
1 parent ae31849 commit 6ed617b

File tree

5 files changed

+85
-23
lines changed

5 files changed

+85
-23
lines changed

pyproject.toml

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -54,19 +54,19 @@ readme = {file = ["README.md"], content-type="text/markdown"}
5454
[project.optional-dependencies]
5555
torch = [
5656
"torch",
57-
"torchvision",
57+
"torchvision<0.24.0",
5858
]
5959
torch-cpu = [
6060
"torch",
61-
"torchvision",
61+
"torchvision<0.24.0",
6262
]
6363
torch-cuda118 = [
6464
"torch",
65-
"torchvision",
65+
"torchvision<0.24.0",
6666
]
6767
torch-cuda128 = [
6868
"torch",
69-
"torchvision",
69+
"torchvision<0.24.0",
7070
]
7171
dev = [
7272
"pytest",

sleap_nn/cli.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -389,7 +389,7 @@ def train(config_name, config_dir, overrides):
389389
@click.option(
390390
"--tracking_target_instance_count",
391391
type=int,
392-
default=0,
392+
default=None,
393393
help="Target number of instances to track per frame. (default: 0)",
394394
)
395395
@click.option(

sleap_nn/predict.py

Lines changed: 17 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -98,7 +98,7 @@ def run_inference(
9898
of_window_size: int = 21,
9999
of_max_levels: int = 3,
100100
post_connect_single_breaks: bool = False,
101-
tracking_target_instance_count: int = 0,
101+
tracking_target_instance_count: Optional[int] = None,
102102
tracking_pre_cull_to_target: int = 0,
103103
tracking_pre_cull_iou_threshold: float = 0,
104104
tracking_clean_instance_count: int = 0,
@@ -236,7 +236,7 @@ def run_inference(
236236
Default: 3. (only if `use_flow` is True).
237237
post_connect_single_breaks: If True and `max_tracks` is not None with local queues candidate method,
238238
connects track breaks when exactly one track is lost and exactly one new track is spawned in the frame.
239-
tracking_target_instance_count: Target number of instances to track per frame. (default: 0)
239+
tracking_target_instance_count: Target number of instances to track per frame. (default: None)
240240
tracking_pre_cull_to_target: If non-zero and target_instance_count is also non-zero, then cull instances over target count per frame *before* tracking. (default: 0)
241241
tracking_pre_cull_iou_threshold: If non-zero and pre_cull_to_target also set, then use IOU threshold to remove overlapping instances over count *before* tracking. (default: 0)
242242
tracking_clean_instance_count: Target number of instances to clean *after* tracking. (default: 0)
@@ -300,6 +300,14 @@ def run_inference(
300300

301301
logger.info(f"Running tracking on {len(lf_frames)} frames...")
302302

303+
if post_connect_single_breaks or tracking_pre_cull_to_target:
304+
if tracking_target_instance_count is None and max_instances is None:
305+
message = "Both tracking_target_instance_count and max_instances is set to 0. To connect single breaks or pre-cull to target, at least one of them should be set to an integer."
306+
logger.error(message)
307+
raise ValueError(message)
308+
elif tracking_target_instance_count is None:
309+
tracking_target_instance_count = max_instances
310+
303311
tracked_frames = run_tracker(
304312
untracked_frames=lf_frames,
305313
window_size=tracking_window_size,
@@ -377,6 +385,13 @@ def run_inference(
377385
and not isinstance(predictor, BottomUpMultiClassPredictor)
378386
and not isinstance(predictor, TopDownMultiClassPredictor)
379387
):
388+
if post_connect_single_breaks or tracking_pre_cull_to_target:
389+
if tracking_target_instance_count is None and max_instances is None:
390+
message = "Both tracking_target_instance_count and max_instances is set to 0. To connect single breaks or pre-cull to target, at least one of them should be set to an integer."
391+
logger.error(message)
392+
raise ValueError(message)
393+
elif tracking_target_instance_count is None:
394+
tracking_target_instance_count = max_instances
380395
predictor.tracker = Tracker.from_config(
381396
candidates_method=candidates_method,
382397
min_match_points=min_match_points,
@@ -454,11 +469,6 @@ def run_inference(
454469
lfs, tracking_clean_instance_count
455470
)
456471
elif post_connect_single_breaks:
457-
if not tracking_target_instance_count:
458-
message = "tracking_target_instance_count is 0. To connect single breaks, tracking_target_instance_count should be set to an integer."
459-
logger.error(message)
460-
raise ValueError(message)
461-
462472
start_final_pass_time = time()
463473
start_fp_timestamp = str(datetime.now())
464474
logger.info(

sleap_nn/tracking/tracker.py

Lines changed: 16 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -76,7 +76,7 @@ class Tracker:
7676
robust_best_instance: float = 1.0
7777
use_flow: bool = False
7878
is_local_queue: bool = False
79-
tracking_target_instance_count: int = 0
79+
tracking_target_instance_count: Optional[int] = None
8080
tracking_pre_cull_to_target: int = 0
8181
tracking_pre_cull_iou_threshold: float = 0
8282
_scoring_functions: Dict[str, Any] = {
@@ -119,7 +119,7 @@ def from_config(
119119
of_img_scale: float = 1.0,
120120
of_window_size: int = 21,
121121
of_max_levels: int = 3,
122-
tracking_target_instance_count: int = 0,
122+
tracking_target_instance_count: Optional[int] = None,
123123
tracking_pre_cull_to_target: int = 0,
124124
tracking_pre_cull_iou_threshold: float = 0,
125125
):
@@ -162,7 +162,7 @@ def from_config(
162162
of_max_levels: Number of pyramid scale levels to consider. This is different
163163
from the scale parameter, which determines the initial image scaling.
164164
Default: 3. (only if `use_flow` is True)
165-
tracking_target_instance_count: Target number of instances to track per frame. (default: 0)
165+
tracking_target_instance_count: Target number of instances to track per frame. (default: None)
166166
tracking_pre_cull_to_target: If non-zero and target_instance_count is also non-zero, then cull instances over target count per frame *before* tracking. (default: 0)
167167
tracking_pre_cull_iou_threshold: If non-zero and pre_cull_to_target also set, then use IOU threshold to remove overlapping instances over count *before* tracking. (default: 0)
168168
@@ -237,7 +237,11 @@ def track(
237237
Returns:
238238
List of `sio.PredictedInstance` objects, each having an assigned track.
239239
"""
240-
if self.tracking_target_instance_count and self.tracking_pre_cull_to_target:
240+
if (
241+
self.tracking_target_instance_count is not None
242+
and self.tracking_target_instance_count
243+
and self.tracking_pre_cull_to_target
244+
):
241245
untracked_instances = cull_frame_instances(
242246
untracked_instances,
243247
self.tracking_target_instance_count,
@@ -492,7 +496,7 @@ class FlowShiftTracker(Tracker):
492496
of_max_levels: Number of pyramid scale levels to consider. This is different
493497
from the scale parameter, which determines the initial image scaling.
494498
Default: 3
495-
tracking_target_instance_count: Target number of instances to track per frame. (default: 0)
499+
tracking_target_instance_count: Target number of instances to track per frame. (default: None)
496500
tracking_pre_cull_to_target: If non-zero and target_instance_count is also non-zero, then cull instances over target count per frame *before* tracking. (default: 0)
497501
tracking_pre_cull_iou_threshold: If non-zero and pre_cull_to_target also set, then use IOU threshold to remove overlapping instances over count *before* tracking. (default: 0)
498502
@@ -729,7 +733,7 @@ def run_tracker(
729733
of_window_size: int = 21,
730734
of_max_levels: int = 3,
731735
post_connect_single_breaks: bool = False,
732-
tracking_target_instance_count: int = 0,
736+
tracking_target_instance_count: Optional[int] = None,
733737
tracking_pre_cull_to_target: int = 0,
734738
tracking_pre_cull_iou_threshold: float = 0,
735739
tracking_clean_instance_count: int = 0,
@@ -777,7 +781,7 @@ def run_tracker(
777781
Default: 3. (only if `use_flow` is True).
778782
post_connect_single_breaks: If True and `max_tracks` is not None with local queues candidate method,
779783
connects track breaks when exactly one track is lost and exactly one new track is spawned in the frame.
780-
tracking_target_instance_count: Target number of instances to track per frame. (default: 0)
784+
tracking_target_instance_count: Target number of instances to track per frame. (default: None)
781785
tracking_pre_cull_to_target: If non-zero and target_instance_count is also non-zero, then cull instances over target count per frame *before* tracking. (default: 0)
782786
tracking_pre_cull_iou_threshold: If non-zero and pre_cull_to_target also set, then use IOU threshold to remove overlapping instances over count *before* tracking. (default: 0)
783787
tracking_clean_instance_count: Target number of instances to clean *after* tracking. (default: 0)
@@ -840,8 +844,11 @@ def run_tracker(
840844
)
841845

842846
if post_connect_single_breaks:
843-
if not tracking_target_instance_count:
844-
message = "tracking_target_instance_count is 0. To connect single breaks, tracking_target_instance_count should be set to an integer."
847+
if (
848+
tracking_target_instance_count is None
849+
or tracking_target_instance_count == 0
850+
):
851+
message = "tracking_target_instance_count is None or 0. To connect single breaks, tracking_target_instance_count should be set to an integer."
845852
logger.error(message)
846853
raise ValueError(message)
847854
start_final_pass_time = time()

tests/test_predict.py

Lines changed: 47 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -260,6 +260,31 @@ def test_topdown_predictor(
260260

261261
assert len(pred_labels.tracks) <= 2 # should be less than max tracks
262262

263+
for lf in pred_labels:
264+
for instance in lf.instances:
265+
assert instance.track is not None
266+
267+
# test with tracking and tracking_target_instance_count is not provided
268+
pred_labels = run_inference(
269+
model_paths=[
270+
minimal_instance_centroid_ckpt,
271+
minimal_instance_centered_instance_ckpt,
272+
],
273+
data_path=centered_instance_video.as_posix(),
274+
make_labels=True,
275+
output_path=tmp_path,
276+
max_instances=2,
277+
post_connect_single_breaks=True,
278+
max_tracks=None,
279+
device="cpu",
280+
peak_threshold=0.1,
281+
frames=[x for x in range(5)],
282+
tracking=True,
283+
integral_refinement=None,
284+
)
285+
286+
assert len(pred_labels.tracks) <= 2 # should be less than max tracks
287+
263288
for lf in pred_labels:
264289
for instance in lf.instances:
265290
assert instance.track is not None
@@ -276,15 +301,14 @@ def test_topdown_predictor(
276301
output_path=tmp_path,
277302
max_instances=None,
278303
post_connect_single_breaks=True,
279-
tracking_target_instance_count=0,
280304
max_tracks=None,
281305
device="cpu",
282306
peak_threshold=0.1,
283307
frames=[x for x in range(20)],
284308
tracking=True,
285309
integral_refinement=None,
286310
)
287-
assert "tracking_target_instance_count is 0" in caplog.text
311+
assert "tracking_target_instance_count and max_instances" in caplog.text
288312

289313

290314
def test_multiclass_topdown_predictor(
@@ -1050,6 +1074,27 @@ def test_tracking_only_pipeline(
10501074
integral_refinement=None,
10511075
)
10521076

1077+
# neither max_instances nor tracking_target_instance_count is provided
1078+
with pytest.raises(ValueError):
1079+
labels = run_inference(
1080+
data_path=centered_instance_video.as_posix(),
1081+
tracking=True,
1082+
integral_refinement=None,
1083+
post_connect_single_breaks=True,
1084+
)
1085+
1086+
# racking_target_instance_count is provided
1087+
labels = run_inference(
1088+
data_path=minimal_instance.as_posix(),
1089+
tracking=True,
1090+
integral_refinement=None,
1091+
post_connect_single_breaks=True,
1092+
max_instances=2,
1093+
output_path=tmp_path,
1094+
)
1095+
for lf in labels:
1096+
assert len(lf.instances) <= 2
1097+
10531098

10541099
def test_legacy_topdown_predictor(
10551100
minimal_instance,

0 commit comments

Comments
 (0)