|
9 | 9 | from datetime import datetime |
10 | 10 | from loguru import logger |
11 | 11 | import functools |
| 12 | +import rich |
| 13 | +from rich.progress import ( |
| 14 | + Progress, |
| 15 | + BarColumn, |
| 16 | + TimeElapsedColumn, |
| 17 | + TimeRemainingColumn, |
| 18 | + MofNCompleteColumn, |
| 19 | +) |
12 | 20 |
|
13 | 21 | import sleap_io as sio |
14 | 22 | from sleap_nn.evaluation import compute_oks |
@@ -716,6 +724,17 @@ def connect_single_breaks( |
716 | 724 | return lfs |
717 | 725 |
|
718 | 726 |
|
| 727 | +class RateColumn(rich.progress.ProgressColumn): |
| 728 | + """Renders the progress rate.""" |
| 729 | + |
| 730 | + def render(self, task: "Task") -> rich.progress.Text: |
| 731 | + """Show progress rate.""" |
| 732 | + speed = task.speed |
| 733 | + if speed is None: |
| 734 | + return rich.progress.Text("?", style="progress.data.speed") |
| 735 | + return rich.progress.Text(f"{speed:.1f} frames/s", style="progress.data.speed") |
| 736 | + |
| 737 | + |
719 | 738 | def run_tracker( |
720 | 739 | untracked_frames: List[sio.LabeledFrame], |
721 | 740 | window_size: int = 5, |
@@ -810,35 +829,66 @@ def run_tracker( |
810 | 829 | tracking_pre_cull_to_target=tracking_pre_cull_to_target, |
811 | 830 | tracking_pre_cull_iou_threshold=tracking_pre_cull_iou_threshold, |
812 | 831 | ) |
813 | | - tracked_lfs = [] |
814 | | - for lf in untracked_frames: |
815 | | - # prefer user instances over predicted instance |
816 | | - instances = [] |
817 | | - if lf.has_user_instances: |
818 | | - instances_to_track = lf.user_instances |
819 | | - if lf.has_predicted_instances: |
820 | | - instances = lf.predicted_instances |
821 | | - else: |
822 | | - instances_to_track = lf.predicted_instances |
823 | 832 |
|
824 | | - instances.extend( |
825 | | - tracker.track( |
826 | | - untracked_instances=instances_to_track, |
827 | | - frame_idx=lf.frame_idx, |
828 | | - image=lf.image, |
829 | | - ) |
830 | | - ) |
831 | | - tracked_lfs.append( |
832 | | - sio.LabeledFrame( |
833 | | - video=lf.video, frame_idx=lf.frame_idx, instances=instances |
834 | | - ) |
835 | | - ) |
| 833 | + try: |
| 834 | + with Progress( |
| 835 | + "{task.description}", |
| 836 | + BarColumn(), |
| 837 | + "[progress.percentage]{task.percentage:>3.0f}%", |
| 838 | + MofNCompleteColumn(), |
| 839 | + "ETA:", |
| 840 | + TimeRemainingColumn(), |
| 841 | + "Elapsed:", |
| 842 | + TimeElapsedColumn(), |
| 843 | + RateColumn(), |
| 844 | + auto_refresh=False, |
| 845 | + refresh_per_second=4, |
| 846 | + speed_estimate_period=5, |
| 847 | + ) as progress: |
| 848 | + task = progress.add_task("Tracking...", total=len(untracked_frames)) |
| 849 | + last_report = time() |
| 850 | + |
| 851 | + tracked_lfs = [] |
| 852 | + for lf in untracked_frames: |
| 853 | + # prefer user instances over predicted instance |
| 854 | + instances = [] |
| 855 | + if lf.has_user_instances: |
| 856 | + instances_to_track = lf.user_instances |
| 857 | + if lf.has_predicted_instances: |
| 858 | + instances = lf.predicted_instances |
| 859 | + else: |
| 860 | + instances_to_track = lf.predicted_instances |
| 861 | + |
| 862 | + instances.extend( |
| 863 | + tracker.track( |
| 864 | + untracked_instances=instances_to_track, |
| 865 | + frame_idx=lf.frame_idx, |
| 866 | + image=lf.image, |
| 867 | + ) |
| 868 | + ) |
| 869 | + tracked_lfs.append( |
| 870 | + sio.LabeledFrame( |
| 871 | + video=lf.video, frame_idx=lf.frame_idx, instances=instances |
| 872 | + ) |
| 873 | + ) |
| 874 | + |
| 875 | + progress.update(task, advance=1) |
| 876 | + |
| 877 | + if time() - last_report > 0.25: |
| 878 | + progress.refresh() |
| 879 | + last_report = time() |
| 880 | + |
| 881 | + except KeyboardInterrupt: |
| 882 | + logger.info("Tracking interrupted by user") |
| 883 | + raise KeyboardInterrupt |
836 | 884 |
|
837 | 885 | if tracking_clean_instance_count > 0: |
| 886 | + logger.info("Post-processing: Culling instances...") |
838 | 887 | tracked_lfs = cull_instances( |
839 | 888 | tracked_lfs, tracking_clean_instance_count, tracking_clean_iou_threshold |
840 | 889 | ) |
841 | 890 | if not post_connect_single_breaks: |
| 891 | + logger.info("Post-processing: Connecting single breaks...") |
842 | 892 | tracked_lfs = connect_single_breaks( |
843 | 893 | tracked_lfs, tracking_clean_instance_count |
844 | 894 | ) |
|
0 commit comments