Skip to content

Commit b81476b

Browse files
gitttt-1234claude authored and committed
Append video name to output path when video_index is specified
When running inference with video_index parameter to predict on a specific video from a multi-video .slp file, the output path now includes the video name to prevent overwrites. This allows users to run predictions on multiple videos from the same .slp file without manually specifying output paths. Changes: - Modified run_inference() to append video filename stem to output path when video_index is provided and output_path is None - Format: <labels_file>.<video_name>.predictions.slp - Falls back to video_{index} if filename is not a simple string - Added test_video_index_output_path() to verify behavior Example: - Before: labels.slp + video_index=0 → labels.predictions.slp - After: labels.slp + video_index=0 (video1.mp4) → labels.video1.predictions.slp 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude <[email protected]>
1 parent 558e0a9 commit b81476b

File tree

2 files changed

+47
-3
lines changed

2 files changed

+47
-3
lines changed

sleap_nn/predict.py

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -504,9 +504,17 @@ def run_inference(
504504

505505
if make_labels:
506506
if output_path is None:
507-
output_path = Path(
508-
data_path if data_path is not None else "results"
509-
).with_suffix(".predictions.slp")
507+
base_path = Path(data_path if data_path is not None else "results")
508+
509+
# If video_index is specified, append video name to output path
510+
if video_index is not None and len(output.videos) > video_index:
511+
video = output.videos[video_index]
512+
# Get video filename and sanitize it for use in path
513+
video_name = Path(video.filename).stem if isinstance(video.filename, str) else f"video_{video_index}"
514+
# Insert video name before .predictions.slp extension
515+
output_path = base_path.parent / f"{base_path.stem}.{video_name}.predictions.slp"
516+
else:
517+
output_path = base_path.with_suffix(".predictions.slp")
510518
output.save(Path(output_path).as_posix(), restore_original_videos=False)
511519
finish_timestamp = str(datetime.now())
512520
logger.info(f"Predictions output path: {output_path}")

tests/test_predict.py

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1423,3 +1423,39 @@ def test_topdown_predictor_with_tracking_cleaning(
14231423
device="cpu" if torch.backends.mps.is_available() else "auto",
14241424
)
14251425
assert len(labels) < 10
1426+
1427+
1428+
def test_video_index_output_path(
    minimal_instance,
    minimal_instance_centered_instance_ckpt,
    minimal_instance_centroid_ckpt,
    tmp_path,
):
    """Test that `video_index` appends the video name to the default output path.

    When `output_path` is None and `video_index` is given, `run_inference`
    should write to `<labels_file>.<video_name>.predictions.slp` so that
    predictions on different videos of the same .slp file do not overwrite
    each other.
    """
    # Run inference with video_index and no explicit output_path.
    labels = run_inference(
        model_paths=[
            minimal_instance_centroid_ckpt,
            minimal_instance_centered_instance_ckpt,
        ],
        data_path=minimal_instance.as_posix(),
        video_index=0,
        make_labels=True,
        device="cpu",
        peak_threshold=0.0,
        integral_refinement=None,
    )

    # Derive the exact filename the new behavior should produce, rather than
    # globbing loosely: a pattern like "*minimal_instance*.predictions.slp"
    # would also match a plain "minimal_instance.predictions.slp" left behind
    # by another test, making the check flaky.
    data_path = Path(minimal_instance)
    video_name = Path(labels.videos[0].filename).stem
    expected_name = f"{data_path.stem}.{video_name}.predictions.slp"
    output_files = list(data_path.parent.glob(expected_name))
    assert output_files, (
        f"No output file named '{expected_name}' found in {data_path.parent}"
    )

    # Clean up so repeated runs of the suite start from a clean directory.
    for f in output_files:
        f.unlink()

0 commit comments

Comments
 (0)