Skip to content

Commit 0a67e8f

Browse files
authored
Merge branch 'main' into divya/slumbr_base
2 parents 554734b + be15732 commit 0a67e8f

File tree

8 files changed

+115
-46
lines changed

8 files changed

+115
-46
lines changed

.claude/commands/coverage.md

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -2,13 +2,17 @@ Run tests with coverage.
22

33
Command to run:
44
```
5-
pytest -q --maxfail=1 --cov=sleap_nn --cov-branch tests/ && rm -f .coverage.* && python -m coverage annotate
5+
uv run pytest -q --maxfail=1 --cov --cov-branch && rm .coverage.* && uv run coverage annotate
66
```
77

8-
The result will be the terminal output and the line-by-line coverage will be in files sitting next to each module with the file naming `{module_name.py},cover`.
8+
This generates a coverage annotation file next to each module with the name `{module_name.py},cover`, as well as a simple summary.
99

10-
If you are working on a PR, figure out which files were changed and look for coverage specifically in those. If you don't know which files to look for coverage in, use this:
10+
To get the final actionable summary, run this script:
1111

1212
```
13-
git diff --name-only $(git merge-base origin/main HEAD) | jq -R . | jq -s .
14-
```
13+
uv run python scripts/cov_summary.py --only-pr-diff-lines
14+
```
15+
16+
This will output one module per line with line number ranges for missing coverage. Importantly, it will filter it by diffs in the PR.
17+
18+
Use this summary together with the corresponding `,cover` file to describe each miss to inform subsequent test development.

.claude/commands/pr-description.md

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,14 @@ Update PR description.
22

33
Use the `gh` CLI to fetch the current PR description, then update it with a comprehensive description of the changes made in this PR.
44

5+
6+
Command to fetch PR info:
7+
8+
```
9+
gh pr view PR_NUMBER --json
10+
number,title,body,url,state,closingIssuesReferences
11+
```
12+
513
If there is an associated issue (linked in the PR metadata or mentioned in the PR description), then use the `gh` CLI to fetch that too to contextualize the work done in the PR.
614

715
Include a summary, example usage (for enhancements), API changes, and other notes for future consideration (including reasoning behind design decisions).

.gitignore

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -173,4 +173,7 @@ wandb/
173173
.serena/
174174

175175
# macOS
176-
.DS_Store
176+
.DS_Store
177+
178+
# Development scratch folder
179+
scratch/

pyproject.toml

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ classifiers = [
2222
"Programming Language :: Python :: 3.13",
2323
]
2424
dependencies = [
25-
"sleap-io>=0.2.0",
25+
"sleap-io>=0.5.7",
2626
"numpy",
2727
"lightning",
2828
"kornia",
@@ -54,19 +54,19 @@ readme = {file = ["README.md"], content-type="text/markdown"}
5454
[project.optional-dependencies]
5555
torch = [
5656
"torch",
57-
"torchvision<0.24.0",
57+
"torchvision>=0.20.0,<0.24.0",
5858
]
5959
torch-cpu = [
6060
"torch",
61-
"torchvision<0.24.0",
61+
"torchvision>=0.20.0,<0.24.0",
6262
]
6363
torch-cuda118 = [
6464
"torch",
65-
"torchvision<0.24.0",
65+
"torchvision>=0.20.0,<0.24.0",
6666
]
6767
torch-cuda128 = [
6868
"torch",
69-
"torchvision<0.24.0",
69+
"torchvision>=0.20.0,<0.24.0",
7070
]
7171
dev = [
7272
"pytest",

sleap_nn/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -48,4 +48,4 @@ def _safe_print(msg):
4848
format="{time:YYYY-MM-DD HH:mm:ss} | {level} | {name}:{function}:{line} | {message}",
4949
)
5050

51-
__version__ = "0.0.3"
51+
__version__ = "0.0.4"

sleap_nn/data/custom_datasets.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -799,7 +799,9 @@ def __getitem__(self, index) -> Dict:
799799

800800
instances = []
801801
for inst in instances_list:
802-
instances.append(inst.numpy())
802+
instances.append(
803+
inst.numpy()
804+
) # no need to filter empty instances; handled while creating instance_idx_list
803805
instances = np.stack(instances, axis=0)
804806

805807
# Add singleton time dimension for single frames.
@@ -1045,8 +1047,9 @@ def __getitem__(self, index) -> Dict:
10451047

10461048
instances = []
10471049
for inst in instances_list:
1048-
if not inst.is_empty:
1049-
instances.append(inst.numpy())
1050+
instances.append(
1051+
inst.numpy()
1052+
) # no need to filter empty instance (handled while creating instance_idx)
10501053
instances = np.stack(instances, axis=0)
10511054

10521055
# Add singleton time dimension for single frames.

sleap_nn/data/instance_cropping.py

Lines changed: 10 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -44,15 +44,16 @@ def find_instance_crop_size(
4444
max_length = 0.0
4545
for lf in labels:
4646
for inst in lf.instances:
47-
pts = inst.numpy()
48-
pts *= input_scaling
49-
diff_x = np.nanmax(pts[:, 0]) - np.nanmin(pts[:, 0])
50-
diff_x = 0 if np.isnan(diff_x) else diff_x
51-
max_length = np.maximum(max_length, diff_x)
52-
diff_y = np.nanmax(pts[:, 1]) - np.nanmin(pts[:, 1])
53-
diff_y = 0 if np.isnan(diff_y) else diff_y
54-
max_length = np.maximum(max_length, diff_y)
55-
max_length = np.maximum(max_length, min_crop_size_no_pad)
47+
if not inst.is_empty: # only if at least one point is not nan
48+
pts = inst.numpy()
49+
pts *= input_scaling
50+
diff_x = np.nanmax(pts[:, 0]) - np.nanmin(pts[:, 0])
51+
diff_x = 0 if np.isnan(diff_x) else diff_x
52+
max_length = np.maximum(max_length, diff_x)
53+
diff_y = np.nanmax(pts[:, 1]) - np.nanmin(pts[:, 1])
54+
diff_y = 0 if np.isnan(diff_y) else diff_y
55+
max_length = np.maximum(max_length, diff_y)
56+
max_length = np.maximum(max_length, min_crop_size_no_pad)
5657

5758
max_length += float(padding)
5859
crop_size = math.ceil(max_length / float(maximum_stride)) * maximum_stride

sleap_nn/tracking/tracker.py

Lines changed: 72 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,14 @@
99
from datetime import datetime
1010
from loguru import logger
1111
import functools
12+
import rich
13+
from rich.progress import (
14+
Progress,
15+
BarColumn,
16+
TimeElapsedColumn,
17+
TimeRemainingColumn,
18+
MofNCompleteColumn,
19+
)
1220

1321
import sleap_io as sio
1422
from sleap_nn.evaluation import compute_oks
@@ -716,6 +724,17 @@ def connect_single_breaks(
716724
return lfs
717725

718726

727+
class RateColumn(rich.progress.ProgressColumn):
728+
"""Renders the progress rate."""
729+
730+
def render(self, task: "Task") -> rich.progress.Text:
731+
"""Show progress rate."""
732+
speed = task.speed
733+
if speed is None:
734+
return rich.progress.Text("?", style="progress.data.speed")
735+
return rich.progress.Text(f"{speed:.1f} frames/s", style="progress.data.speed")
736+
737+
719738
def run_tracker(
720739
untracked_frames: List[sio.LabeledFrame],
721740
window_size: int = 5,
@@ -810,35 +829,66 @@ def run_tracker(
810829
tracking_pre_cull_to_target=tracking_pre_cull_to_target,
811830
tracking_pre_cull_iou_threshold=tracking_pre_cull_iou_threshold,
812831
)
813-
tracked_lfs = []
814-
for lf in untracked_frames:
815-
# prefer user instances over predicted instance
816-
instances = []
817-
if lf.has_user_instances:
818-
instances_to_track = lf.user_instances
819-
if lf.has_predicted_instances:
820-
instances = lf.predicted_instances
821-
else:
822-
instances_to_track = lf.predicted_instances
823832

824-
instances.extend(
825-
tracker.track(
826-
untracked_instances=instances_to_track,
827-
frame_idx=lf.frame_idx,
828-
image=lf.image,
829-
)
830-
)
831-
tracked_lfs.append(
832-
sio.LabeledFrame(
833-
video=lf.video, frame_idx=lf.frame_idx, instances=instances
834-
)
835-
)
833+
try:
834+
with Progress(
835+
"{task.description}",
836+
BarColumn(),
837+
"[progress.percentage]{task.percentage:>3.0f}%",
838+
MofNCompleteColumn(),
839+
"ETA:",
840+
TimeRemainingColumn(),
841+
"Elapsed:",
842+
TimeElapsedColumn(),
843+
RateColumn(),
844+
auto_refresh=False,
845+
refresh_per_second=4,
846+
speed_estimate_period=5,
847+
) as progress:
848+
task = progress.add_task("Tracking...", total=len(untracked_frames))
849+
last_report = time()
850+
851+
tracked_lfs = []
852+
for lf in untracked_frames:
853+
# prefer user instances over predicted instance
854+
instances = []
855+
if lf.has_user_instances:
856+
instances_to_track = lf.user_instances
857+
if lf.has_predicted_instances:
858+
instances = lf.predicted_instances
859+
else:
860+
instances_to_track = lf.predicted_instances
861+
862+
instances.extend(
863+
tracker.track(
864+
untracked_instances=instances_to_track,
865+
frame_idx=lf.frame_idx,
866+
image=lf.image,
867+
)
868+
)
869+
tracked_lfs.append(
870+
sio.LabeledFrame(
871+
video=lf.video, frame_idx=lf.frame_idx, instances=instances
872+
)
873+
)
874+
875+
progress.update(task, advance=1)
876+
877+
if time() - last_report > 0.25:
878+
progress.refresh()
879+
last_report = time()
880+
881+
except KeyboardInterrupt:
882+
logger.info("Tracking interrupted by user")
883+
raise KeyboardInterrupt
836884

837885
if tracking_clean_instance_count > 0:
886+
logger.info("Post-processing: Culling instances...")
838887
tracked_lfs = cull_instances(
839888
tracked_lfs, tracking_clean_instance_count, tracking_clean_iou_threshold
840889
)
841890
if not post_connect_single_breaks:
891+
logger.info("Post-processing: Connecting single breaks...")
842892
tracked_lfs = connect_single_breaks(
843893
tracked_lfs, tracking_clean_instance_count
844894
)

0 commit comments

Comments
 (0)