Skip to content

Commit c584c01

Browse files
gitttt-1234claude
andcommitted
Add progress bar to tracker
- Add rich progress bar to run_tracker function with RateColumn showing frames/sec - Follows same implementation pattern as predictors.py for consistency - Add logging statements for post-processing steps (culling and connecting single breaks) - Add scratch/ folder to .gitignore for development notes - All tests pass locally 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude <[email protected]>
1 parent ae31849 commit c584c01

File tree

4 files changed

+93
-28
lines changed

4 files changed

+93
-28
lines changed

.claude/commands/coverage.md

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -2,13 +2,17 @@ Run tests with coverage.
22

33
Command to run:
44
```
5-
pytest -q --maxfail=1 --cov=sleap_nn --cov-branch tests/ && rm -f .coverage.* && python -m coverage annotate
5+
uv run pytest -q --maxfail=1 --cov --cov-branch && rm .coverage.* && uv run coverage annotate
66
```
77

8-
The result will be the terminal output and the line-by-line coverage will be in files sitting next to each module with the file naming `{module_name.py},cover`.
8+
This generates a coverage annotation file next to each module with the name `{module_name.py},cover`, as well as a simple summary.
99

10-
If you are working on a PR, figure out which files were changed and look for coverage specifically in those. If you don't know which files to look for coverage in, use this:
10+
To get the final actionable summary, run this script:
1111

1212
```
13-
git diff --name-only $(git merge-base origin/main HEAD) | jq -R . | jq -s .
14-
```
13+
uv run python scripts/cov_summary.py --only-pr-diff-lines
14+
```
15+
16+
This will output one module per line with line number ranges for missing coverage. Importantly, it will filter it by diffs in the PR.
17+
18+
Use this summary together with the corresponding `,cover` file to describe each miss to inform subsequent test development.

.claude/commands/pr-description.md

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,14 @@ Update PR description.
22

33
Use the `gh` CLI to fetch the current PR description, then update it with a comprehensive description of the changes made in this PR.
44

5+
6+
Command to fetch PR info:
7+
8+
```
9+
gh pr view PR_NUMBER --json
10+
number,title,body,url,state,closingIssuesReferences
11+
```
12+
513
If there is an associated issue (linked in the PR metadata or mentioned in the PR description), then use the `gh` CLI to fetch that too to contextualize the work done in the PR.
614

715
Include a summary, example usage (for enhancements), API changes, and other notes for future consideration (including reasoning behind design decisions).

.gitignore

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -173,4 +173,7 @@ wandb/
173173
.serena/
174174

175175
# macOS
176-
.DS_Store
176+
.DS_Store
177+
178+
# Development scratch folder
179+
scratch/

sleap_nn/tracking/tracker.py

Lines changed: 72 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,14 @@
99
from datetime import datetime
1010
from loguru import logger
1111
import functools
12+
import rich
13+
from rich.progress import (
14+
Progress,
15+
BarColumn,
16+
TimeElapsedColumn,
17+
TimeRemainingColumn,
18+
MofNCompleteColumn,
19+
)
1220

1321
import sleap_io as sio
1422
from sleap_nn.evaluation import compute_oks
@@ -712,6 +720,17 @@ def connect_single_breaks(
712720
return lfs
713721

714722

723+
class RateColumn(rich.progress.ProgressColumn):
724+
"""Renders the progress rate."""
725+
726+
def render(self, task: "Task") -> rich.progress.Text:
727+
"""Show progress rate."""
728+
speed = task.speed
729+
if speed is None:
730+
return rich.progress.Text("?", style="progress.data.speed")
731+
return rich.progress.Text(f"{speed:.1f} frames/s", style="progress.data.speed")
732+
733+
715734
def run_tracker(
716735
untracked_frames: List[sio.LabeledFrame],
717736
window_size: int = 5,
@@ -806,35 +825,66 @@ def run_tracker(
806825
tracking_pre_cull_to_target=tracking_pre_cull_to_target,
807826
tracking_pre_cull_iou_threshold=tracking_pre_cull_iou_threshold,
808827
)
809-
tracked_lfs = []
810-
for lf in untracked_frames:
811-
# prefer user instances over predicted instance
812-
instances = []
813-
if lf.has_user_instances:
814-
instances_to_track = lf.user_instances
815-
if lf.has_predicted_instances:
816-
instances = lf.predicted_instances
817-
else:
818-
instances_to_track = lf.predicted_instances
819828

820-
instances.extend(
821-
tracker.track(
822-
untracked_instances=instances_to_track,
823-
frame_idx=lf.frame_idx,
824-
image=lf.image,
825-
)
826-
)
827-
tracked_lfs.append(
828-
sio.LabeledFrame(
829-
video=lf.video, frame_idx=lf.frame_idx, instances=instances
830-
)
831-
)
829+
try:
830+
with Progress(
831+
"{task.description}",
832+
BarColumn(),
833+
"[progress.percentage]{task.percentage:>3.0f}%",
834+
MofNCompleteColumn(),
835+
"ETA:",
836+
TimeRemainingColumn(),
837+
"Elapsed:",
838+
TimeElapsedColumn(),
839+
RateColumn(),
840+
auto_refresh=False,
841+
refresh_per_second=4,
842+
speed_estimate_period=5,
843+
) as progress:
844+
task = progress.add_task("Tracking...", total=len(untracked_frames))
845+
last_report = time()
846+
847+
tracked_lfs = []
848+
for lf in untracked_frames:
849+
# prefer user instances over predicted instance
850+
instances = []
851+
if lf.has_user_instances:
852+
instances_to_track = lf.user_instances
853+
if lf.has_predicted_instances:
854+
instances = lf.predicted_instances
855+
else:
856+
instances_to_track = lf.predicted_instances
857+
858+
instances.extend(
859+
tracker.track(
860+
untracked_instances=instances_to_track,
861+
frame_idx=lf.frame_idx,
862+
image=lf.image,
863+
)
864+
)
865+
tracked_lfs.append(
866+
sio.LabeledFrame(
867+
video=lf.video, frame_idx=lf.frame_idx, instances=instances
868+
)
869+
)
870+
871+
progress.update(task, advance=1)
872+
873+
if time() - last_report > 0.25:
874+
progress.refresh()
875+
last_report = time()
876+
877+
except KeyboardInterrupt:
878+
logger.info("Tracking interrupted by user")
879+
raise KeyboardInterrupt
832880

833881
if tracking_clean_instance_count > 0:
882+
logger.info("Post-processing: Culling instances...")
834883
tracked_lfs = cull_instances(
835884
tracked_lfs, tracking_clean_instance_count, tracking_clean_iou_threshold
836885
)
837886
if not post_connect_single_breaks:
887+
logger.info("Post-processing: Connecting single breaks...")
838888
tracked_lfs = connect_single_breaks(
839889
tracked_lfs, tracking_clean_instance_count
840890
)

0 commit comments

Comments
 (0)