Commit b495560

Minor fixes to inference workflow (#360)
This PR sets integral refinement to None by default when running on Apple MPS, where the underlying op is unsupported, replacing the previous CPU fallback that caused slowdowns; the pipeline now stays on MPS unless the user explicitly re-enables refinement. It also fixes a test issue where the output path was passed as a directory name instead of a file path, triggering a “frame error skipped” condition; the tests now pass a valid video path. Finally, it implements the fix suggested [here](#357) to gracefully handle unicode encoding errors in the logger. (Related issue: #357)
1 parent b1a2026 commit b495560

File tree

8 files changed: +106 −50 lines


sleap_nn/__init__.py

Lines changed: 17 additions & 1 deletion
```diff
@@ -1,6 +1,7 @@
 """Main module for sleap_nn package."""
 
 import os
+import sys
 from loguru import logger
 
 # Get RANK for distributed training
@@ -24,9 +25,24 @@ def _should_log(record):
 # Remove default handler and add custom one
 logger.remove()
 
+
+def _safe_print(msg):
+    """Print with fallback for encoding errors."""
+    try:
+        print(msg, end="")
+    except UnicodeEncodeError:
+        # Fallback: replace unencodable characters with '?'
+        print(
+            msg.encode(sys.stdout.encoding, errors="replace").decode(
+                sys.stdout.encoding
+            ),
+            end="",
+        )
+
+
 # Add logger with the custom filter
 logger.add(
-    lambda msg: print(msg, end=""),
+    _safe_print,
     level="DEBUG",
     filter=_should_log,
     format="{time:YYYY-MM-DD HH:mm:ss} | {level} | {name}:{function}:{line} | {message}",
```

sleap_nn/predict.py

Lines changed: 2 additions & 2 deletions
```diff
@@ -358,9 +358,9 @@ def run_inference(
     if integral_refinement is not None and device == "mps":  # TODO
         # kornia/geometry/transform/imgwarp.py:382: in get_perspective_transform. NotImplementedError: The operator 'aten::_linalg_solve_ex.result' is not currently implemented for the MPS device. If you want this op to be added in priority during the prototype phase of this feature, please comment on https://github.com/pytorch/pytorch/issues/77764. As a temporary fix, you can set the environment variable `PYTORCH_ENABLE_MPS_FALLBACK=1` to use the CPU as a fallback for this op. WARNING: this will be slower than running natively on MPS.
         logger.info(
-            "Integral refinement is not supported with MPS device. Using CPU."
+            "Integral refinement is not supported with MPS accelerator. Setting integral refinement to None."
         )
-        device = "cpu"  # not supported with mps
+        integral_refinement = None
 
     logger.info(f"Using device: {device}")
```
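To make the behavior change concrete, here is a minimal sketch of the new guard outside of `run_inference`; the `"integral"` default refinement mode is an assumption for illustration, not taken from the PR:

```python
import torch

# Pick the accelerator roughly the way inference might (illustrative).
device = "mps" if torch.backends.mps.is_available() else "cpu"
integral_refinement = "integral"  # assumed default refinement mode

if integral_refinement is not None and device == "mps":
    # aten::_linalg_solve_ex.result has no MPS kernel, so drop refinement
    # instead of demoting the whole pipeline to CPU as the old code did.
    integral_refinement = None

print(f"device={device}, integral_refinement={integral_refinement}")
```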

tests/test_cli.py

Lines changed: 4 additions & 0 deletions
```diff
@@ -7,6 +7,7 @@
 import subprocess
 from sleap_nn.predict import run_inference
 import sleap_io as sio
+import torch
 
 
 @pytest.fixture
@@ -164,6 +165,8 @@ def test_track_command(
         f"{tmp_path}/test.slp",
         "--frames",
         "0-99",
+        "--device",
+        "cpu",
     ]
     result = subprocess.run(cmd, check=True, capture_output=True, text=True)
     assert Path(f"{tmp_path}/test.slp").exists()
@@ -182,6 +185,7 @@ def test_eval_command(
         make_labels=True,
         max_instances=6,
         output_path=f"{tmp_path}/test.slp",
+        device="cpu" if torch.backends.mps.is_available() else "auto",
     )
     cmd = [
         "uv",
```

tests/test_evaluation.py

Lines changed: 1 addition & 0 deletions
```diff
@@ -504,6 +504,7 @@ def test_evaluator_main(
         make_labels=True,
         max_instances=6,
         output_path=f"{tmp_path}/test.slp",
+        device="cpu" if torch.backends.mps.is_available() else "auto",
     )
 
     import subprocess
```
