|
| 1 | +# omnimcp/tracking.py |
| 2 | +from typing import List, Dict, Optional, Tuple |
| 3 | + |
| 4 | +# Use typing_extensions for Self if needed for older Python versions |
| 5 | +# from typing_extensions import Self |
| 6 | + |
| 7 | +# Added Scipy for matching |
| 8 | +import numpy as np |
| 9 | + |
| 10 | +try: |
| 11 | + from scipy.optimize import linear_sum_assignment |
| 12 | + from scipy.spatial.distance import cdist |
| 13 | + |
| 14 | + SCIPY_AVAILABLE = True |
| 15 | +except ImportError: |
| 16 | + SCIPY_AVAILABLE = False |
| 17 | + # Fallback or warning needed if scipy is critical |
| 18 | + import warnings |
| 19 | + |
| 20 | + warnings.warn( |
| 21 | + "Scipy not found. Tracking matching will be disabled or use a fallback." |
| 22 | + ) |
| 23 | + |
| 24 | + |
| 25 | +# Assuming UIElement and ElementTrack are defined in omnimcp.types |
| 26 | +try: |
| 27 | + from omnimcp.types import UIElement, ElementTrack, Bounds |
| 28 | +except ImportError: |
| 29 | + print("Warning: Could not import types from omnimcp.types") |
| 30 | + UIElement = dict # type: ignore |
| 31 | + ElementTrack = dict # type: ignore |
| 32 | + Bounds = tuple # type: ignore |
| 33 | + |
| 34 | +# Assuming logger is setup elsewhere and accessible, or use standard logging |
| 35 | +# from omnimcp.utils import logger |
| 36 | +import logging |
| 37 | + |
| 38 | +logger = logging.getLogger(__name__) |
| 39 | + |
| 40 | + |
| 41 | +# Helper Function (can stay here or move to utils) |
def _get_bounds_center(bounds: Bounds) -> Optional[Tuple[float, float]]:
    """Return the center point of an ``(x, y, w, h)`` bounding box.

    Args:
        bounds: A 4-item list or tuple laid out as ``(x, y, w, h)``.

    Returns:
        ``(x + w / 2, y + h / 2)``, or ``None`` (after logging a warning)
        when ``bounds`` is not a 4-item list/tuple or when the width or
        height is negative.
    """
    is_four_item_sequence = isinstance(bounds, (list, tuple)) and len(bounds) == 4
    if not is_four_item_sequence:
        logger.warning(
            f"Invalid bounds format received: {bounds}. Cannot calculate center."
        )
        return None

    x, y, w, h = bounds

    # A box with a negative extent has no meaningful center.
    if w < 0 or h < 0:
        logger.warning(
            f"Invalid bounds dimensions (w={w}, h={h}). Cannot calculate center."
        )
        return None

    center_x = x + w / 2
    center_y = y + h / 2
    return center_x, center_y
| 57 | + |
| 58 | + |
class SimpleElementTracker:
    """
    Basic element tracking across frames based on type and proximity using optimal assignment.
    Assigns persistent track_ids.

    Each call to :meth:`update` matches the elements detected in the current
    frame against the tracks that were visible in the previous frame. A pair
    may only be matched when the element types are equal and the squared
    distance between their bounding-box centers does not exceed
    ``match_threshold_sq``.

    NOTE(review): a track that misses a frame has ``latest_element`` set to
    ``None`` and is then excluded from matching, so it cannot be re-acquired;
    it only accumulates misses until it is pruned. Confirm this is intended.
    """

    def __init__(
        self, miss_threshold: int = 3, matching_threshold: float = 0.1
    ):
        """
        Args:
            miss_threshold: How many consecutive misses before pruning a track.
            matching_threshold: Relative distance threshold for matching centers.
        """
        if not SCIPY_AVAILABLE:
            # Construction still succeeds; _match_elements then returns no
            # matches, so every element is treated as new on every frame.
            logger.error(
                "Scipy is required for SimpleElementTracker matching logic but not installed."
            )
        self.tracked_elements: Dict[str, ElementTrack] = {}  # track_id -> ElementTrack
        self.next_track_id_counter: int = 0
        self.miss_threshold = miss_threshold
        # Costs are squared distances, so compare against the squared threshold.
        self.match_threshold_sq = matching_threshold**2
        logger.info(
            f"SimpleElementTracker initialized (miss_thresh={miss_threshold}, match_dist_sq={self.match_threshold_sq:.4f})."
        )

    def _generate_track_id(self) -> str:
        """Return the next unused track ID (``track_0``, ``track_1``, ...)."""
        track_id = f"track_{self.next_track_id_counter}"
        self.next_track_id_counter += 1
        return track_id

    def _match_elements(self, current_elements: List[UIElement]) -> Dict[int, str]:
        """
        Performs optimal assignment matching between current elements and active tracks.

        Args:
            current_elements: List of UIElements detected in the current frame.

        Returns:
            Dict[int, str]: A mapping from current_element.id to matched track_id.
            Only includes elements that were successfully matched.
        """
        if not SCIPY_AVAILABLE:
            logger.warning("Scipy not available, skipping matching.")
            return {}
        if not current_elements or not self.tracked_elements:
            return {}  # Nothing to match

        # Only tracks that were visible in the previous frame take part.
        active_tracks = [
            track
            for track in self.tracked_elements.values()
            if track.latest_element is not None
        ]
        if not active_tracks:
            return {}  # No active tracks to match against

        # --- Prepare Data for Matching ---
        # Compute each center exactly once so the id/type/center lists stay
        # aligned and an invalid box is only reported a single time.
        current_ids_valid = []
        current_types = []
        current_points = []
        for element in current_elements:
            center = _get_bounds_center(element.bounds)
            if center is None:
                continue
            current_ids_valid.append(element.id)
            current_types.append(element.type)
            current_points.append(center)

        track_ids_valid = []
        track_types = []
        track_points = []
        for track in active_tracks:
            center = _get_bounds_center(track.latest_element.bounds)
            if center is None:
                continue
            track_ids_valid.append(track.track_id)
            track_types.append(track.latest_element.type)
            track_points.append(center)

        if not current_points or not track_points:
            logger.debug("No valid centers for matching.")
            return {}  # Cannot match if no valid centers

        # --- Calculate Cost Matrix (Squared Euclidean Distance) ---
        # rows = current elements, cols = active tracks
        cost_matrix = cdist(
            np.asarray(current_points, dtype=float),
            np.asarray(track_points, dtype=float),
            metric="sqeuclidean",
        )

        # --- Apply Constraints (Type Mismatch & Distance Threshold) ---
        type_mismatch = np.array(
            [
                [element_type != track_type for track_type in track_types]
                for element_type in current_types
            ],
            dtype=bool,
        )
        # `<=` is False for NaN, so NaN distances are treated as forbidden.
        allowed = ~type_mismatch & (cost_matrix <= self.match_threshold_sq)

        # Forbidden pairs get a large *finite* cost rather than infinity:
        # linear_sum_assignment raises ValueError ("cost matrix is
        # infeasible") when the smaller side cannot be fully assigned at
        # finite cost, which would throw away every valid match whenever a
        # single new element or stale track is present. The sentinel exceeds
        # the total cost of any set of allowed pairs, so the solver never
        # gives up an allowed pair in exchange for a forbidden one.
        max_assignments = min(cost_matrix.shape)
        forbidden_cost = max_assignments * max(self.match_threshold_sq, 0.0) + 1.0
        cost_matrix = np.where(allowed, cost_matrix, forbidden_cost)

        # --- Optimal Assignment using Hungarian Algorithm ---
        try:
            row_ind, col_ind = linear_sum_assignment(cost_matrix)
        except ValueError as e:
            logger.error(
                f"Error during linear_sum_assignment: {e}. Cost matrix shape: {cost_matrix.shape}"
            )
            return {}

        # --- Create Mapping from Valid Assignments ---
        # Drop the pairs the solver was forced to make at the sentinel cost.
        assignment_mapping: Dict[int, str] = {
            current_ids_valid[r]: track_ids_valid[c]
            for r, c in zip(row_ind, col_ind)
            if allowed[r, c]
        }

        logger.debug(
            f"Matching: Found {len(assignment_mapping)} valid assignments."
        )
        return assignment_mapping

    def _record_miss(
        self, track_id: str, track: ElementTrack, frame_number: int, verb: str
    ) -> bool:
        """Mark ``track`` as not seen this frame; return True if it should be pruned."""
        track.latest_element = None
        track.consecutive_misses += 1
        logger.debug(
            f"Track {track_id} {verb} frame {frame_number}. Consecutive misses: {track.consecutive_misses}"
        )
        # Checked AFTER incrementing, so the track is pruned on the frame in
        # which it reaches the threshold.
        return track.consecutive_misses >= self.miss_threshold

    def update(
        self, current_elements: List[UIElement], frame_number: int
    ) -> List[ElementTrack]:
        """
        Updates tracks based on current detections using optimal assignment matching.

        Matched tracks receive the new element and have their miss counter
        reset; unmatched tracks accumulate a miss and are pruned once they
        reach ``miss_threshold``; unmatched elements with valid bounds start
        new tracks.

        Args:
            current_elements: List of UIElements detected in the current frame.
            frame_number: The current step/frame number.

        Returns:
            A list of all currently active ElementTrack objects (including missed ones).
        """
        current_element_map = {el.id: el for el in current_elements}

        # current_element_id -> track_id
        assignment_mapping = self._match_elements(current_elements)
        matched_current_element_ids = set(assignment_mapping)

        # Reverse view (track_id -> current_element_id), built once instead
        # of scanning the assignment for every track.
        element_id_by_track = {
            track_id: element_id
            for element_id, track_id in assignment_mapping.items()
        }

        tracks_to_prune: List[str] = []
        for track_id, track in self.tracked_elements.items():
            if track_id not in element_id_by_track:
                # Track was not matched in the current frame
                if self._record_miss(track_id, track, frame_number, "missed"):
                    tracks_to_prune.append(track_id)
                continue

            matched_elem_id = element_id_by_track[track_id]
            if matched_elem_id is not None and matched_elem_id in current_element_map:
                # Matched successfully
                track.latest_element = current_element_map[matched_elem_id]
                track.consecutive_misses = 0
                track.last_seen_frame = frame_number
                continue

            # Match found in assignment but element missing from map (should not happen ideally)
            logger.warning(
                f"Track {track_id} matched but element ID {matched_elem_id} not found in current_element_map. Treating as miss."
            )
            if self._record_miss(track_id, track, frame_number, "treated as missed"):
                tracks_to_prune.append(track_id)

        # Prune after the loop so the dict is not mutated while iterating it.
        for track_id in tracks_to_prune:
            pruned = self.tracked_elements.pop(track_id, None)
            if pruned is not None:
                logger.debug(
                    f"Pruning track {track_id} after {pruned.consecutive_misses} misses."
                )

        # Add tracks for new, unmatched elements
        for element_id, element in current_element_map.items():
            if element_id in matched_current_element_ids:
                continue

            # An element without a usable center could never be matched
            # later, so do not start a track for it.
            if _get_bounds_center(element.bounds) is None:
                logger.debug(
                    f"Skipping creation of track for element ID {element_id} due to invalid bounds."
                )
                continue

            new_track_id = self._generate_track_id()
            self.tracked_elements[new_track_id] = ElementTrack(
                track_id=new_track_id,
                latest_element=element,
                consecutive_misses=0,
                last_seen_frame=frame_number,
            )
            logger.debug(
                f"Created new track {new_track_id} for element ID {element_id}"
            )

        # Return the current list of all tracked elements' state
        return list(self.tracked_elements.values())
0 commit comments