Skip to content

Commit 7ec6783

Browse files
committed
feat: Add initial tracking infrastructure and metrics logging
- Defines Pydantic models (ElementTrack, ScreenAnalysis, ActionDecision, LoggedStep) in types.py based on Issue #8 design. - Implements SimpleElementTracker skeleton in tracking.py with fixed update logic for misses/pruning. Matching logic (_match_elements) remains placeholder. - Adds basic passing unit tests for SimpleElementTracker in tests/test_tracking.py. - Integrates metrics collection (step times, counts, etc.) and structured JSONL logging (LoggedStep format) into AgentExecutor. - Adds scipy and numpy as dependencies. This lays the groundwork for implementing robust element tracking (Issue #8) and addresses the need for improved observability and data logging.
1 parent 7a562a1 commit 7ec6783

File tree

5 files changed

+768
-58
lines changed

5 files changed

+768
-58
lines changed

omnimcp/tracking.py

Lines changed: 304 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,304 @@
1+
# omnimcp/tracking.py
2+
from typing import List, Dict, Optional, Tuple
3+
4+
# Use typing_extensions for Self if needed for older Python versions
5+
# from typing_extensions import Self
6+
7+
# Added Scipy for matching
8+
import numpy as np
9+
10+
try:
11+
from scipy.optimize import linear_sum_assignment
12+
from scipy.spatial.distance import cdist
13+
14+
SCIPY_AVAILABLE = True
15+
except ImportError:
16+
SCIPY_AVAILABLE = False
17+
# Fallback or warning needed if scipy is critical
18+
import warnings
19+
20+
warnings.warn(
21+
"Scipy not found. Tracking matching will be disabled or use a fallback."
22+
)
23+
24+
25+
# Assuming UIElement and ElementTrack are defined in omnimcp.types
26+
try:
27+
from omnimcp.types import UIElement, ElementTrack, Bounds
28+
except ImportError:
29+
print("Warning: Could not import types from omnimcp.types")
30+
UIElement = dict # type: ignore
31+
ElementTrack = dict # type: ignore
32+
Bounds = tuple # type: ignore
33+
34+
# Assuming logger is setup elsewhere and accessible, or use standard logging
35+
# from omnimcp.utils import logger
36+
import logging
37+
38+
logger = logging.getLogger(__name__)
39+
40+
41+
# Helper Function (can stay here or move to utils)
42+
def _get_bounds_center(bounds: Bounds) -> Optional[Tuple[float, float]]:
43+
"""Calculate the center (relative coords) of a bounding box."""
44+
if not isinstance(bounds, (list, tuple)) or len(bounds) != 4:
45+
logger.warning(
46+
f"Invalid bounds format received: {bounds}. Cannot calculate center."
47+
)
48+
return None
49+
x, y, w, h = bounds
50+
# Ensure w and h are non-negative
51+
if w < 0 or h < 0:
52+
logger.warning(
53+
f"Invalid bounds dimensions (w={w}, h={h}). Cannot calculate center."
54+
)
55+
return None
56+
return x + w / 2, y + h / 2
57+
58+
59+
class SimpleElementTracker:
60+
"""
61+
Basic element tracking across frames based on type and proximity using optimal assignment.
62+
Assigns persistent track_ids.
63+
"""
64+
65+
def __init__(
66+
self, miss_threshold: int = 3, matching_threshold: float = 0.1
67+
): # Increased threshold slightly
68+
"""
69+
Args:
70+
miss_threshold: How many consecutive misses before pruning a track.
71+
matching_threshold: Relative distance threshold for matching centers.
72+
"""
73+
if not SCIPY_AVAILABLE:
74+
# Optionally raise an error or disable tracking features
75+
logger.error(
76+
"Scipy is required for SimpleElementTracker matching logic but not installed."
77+
)
78+
# raise ImportError("Scipy is required for SimpleElementTracker")
79+
self.tracked_elements: Dict[str, ElementTrack] = {} # track_id -> ElementTrack
80+
self.next_track_id_counter: int = 0
81+
self.miss_threshold = miss_threshold
82+
# Store squared threshold for efficiency
83+
self.match_threshold_sq = matching_threshold**2
84+
logger.info(
85+
f"SimpleElementTracker initialized (miss_thresh={miss_threshold}, match_dist_sq={self.match_threshold_sq:.4f})."
86+
)
87+
88+
def _generate_track_id(self) -> str:
89+
"""Generates a unique track ID."""
90+
track_id = f"track_{self.next_track_id_counter}"
91+
self.next_track_id_counter += 1
92+
return track_id
93+
94+
def _match_elements(self, current_elements: List[UIElement]) -> Dict[int, str]:
95+
"""
96+
Performs optimal assignment matching between current elements and active tracks.
97+
98+
Args:
99+
current_elements: List of UIElements detected in the current frame.
100+
101+
Returns:
102+
Dict[int, str]: A mapping from current_element.id to matched track_id.
103+
Only includes elements that were successfully matched.
104+
"""
105+
if not SCIPY_AVAILABLE:
106+
logger.warning("Scipy not available, skipping matching.")
107+
return {}
108+
if not current_elements or not self.tracked_elements:
109+
return {} # Nothing to match
110+
111+
# --- Prepare Data for Matching ---
112+
active_tracks = [
113+
track
114+
for track in self.tracked_elements.values()
115+
if track.latest_element is not None # Only match tracks currently visible
116+
]
117+
if not active_tracks:
118+
return {} # No active tracks to match against
119+
120+
# current_element_map = {el.id: el for el in current_elements}
121+
# track_map = {track.track_id: track for track in active_tracks}
122+
123+
# Get centers and types for cost calculation
124+
current_centers = np.array(
125+
[
126+
_get_bounds_center(el.bounds)
127+
for el in current_elements
128+
if _get_bounds_center(el.bounds) is not None # Filter invalid bounds
129+
]
130+
)
131+
current_types = [
132+
el.type
133+
for el in current_elements
134+
if _get_bounds_center(el.bounds) is not None
135+
]
136+
current_ids_valid = [
137+
el.id
138+
for el in current_elements
139+
if _get_bounds_center(el.bounds) is not None
140+
]
141+
142+
track_centers = np.array(
143+
[
144+
_get_bounds_center(track.latest_element.bounds)
145+
for track in active_tracks
146+
if track.latest_element
147+
and _get_bounds_center(track.latest_element.bounds) is not None
148+
]
149+
)
150+
track_types = [
151+
track.latest_element.type
152+
for track in active_tracks
153+
if track.latest_element
154+
and _get_bounds_center(track.latest_element.bounds) is not None
155+
]
156+
track_ids_valid = [
157+
track.track_id
158+
for track in active_tracks
159+
if track.latest_element
160+
and _get_bounds_center(track.latest_element.bounds) is not None
161+
]
162+
163+
if current_centers.size == 0 or track_centers.size == 0:
164+
logger.debug("No valid centers for matching.")
165+
return {} # Cannot match if no valid centers
166+
167+
# --- Calculate Cost Matrix (Squared Euclidean Distance) ---
168+
# Cost matrix: rows = current elements, cols = active tracks
169+
cost_matrix = cdist(current_centers, track_centers, metric="sqeuclidean")
170+
171+
# --- Apply Constraints (Type Mismatch & Distance Threshold) ---
172+
infinity_cost = float("inf")
173+
num_current, num_tracks = cost_matrix.shape
174+
175+
for i in range(num_current):
176+
for j in range(num_tracks):
177+
# Infinite cost if types don't match
178+
if current_types[i] != track_types[j]:
179+
cost_matrix[i, j] = infinity_cost
180+
# Infinite cost if distance exceeds threshold
181+
elif cost_matrix[i, j] > self.match_threshold_sq:
182+
cost_matrix[i, j] = infinity_cost
183+
184+
# --- Optimal Assignment using Hungarian Algorithm ---
185+
try:
186+
row_ind, col_ind = linear_sum_assignment(cost_matrix)
187+
except ValueError as e:
188+
logger.error(
189+
f"Error during linear_sum_assignment: {e}. Cost matrix shape: {cost_matrix.shape}"
190+
)
191+
return {}
192+
193+
# --- Create Mapping from Valid Assignments ---
194+
assignment_mapping: Dict[int, str] = {} # current_element_id -> track_id
195+
valid_matches_count = 0
196+
for r, c in zip(row_ind, col_ind):
197+
# Check if the assignment cost is valid (not infinity)
198+
if cost_matrix[r, c] < infinity_cost:
199+
current_element_id = current_ids_valid[r]
200+
track_id = track_ids_valid[c]
201+
assignment_mapping[current_element_id] = track_id
202+
valid_matches_count += 1
203+
204+
logger.debug(f"Matching: Found {valid_matches_count} valid assignments.")
205+
return assignment_mapping
206+
207+
def update(
208+
self, current_elements: List[UIElement], frame_number: int
209+
) -> List[ElementTrack]:
210+
"""
211+
Updates tracks based on current detections using optimal assignment matching.
212+
213+
Args:
214+
current_elements: List of UIElements detected in the current frame.
215+
frame_number: The current step/frame number.
216+
217+
Returns:
218+
A list of all currently active ElementTrack objects (including missed ones).
219+
"""
220+
current_element_map = {el.id: el for el in current_elements}
221+
222+
# Get the mapping: current_element_id -> track_id
223+
assignment_mapping = self._match_elements(current_elements)
224+
225+
matched_current_element_ids = set(assignment_mapping.keys())
226+
matched_track_ids = set(assignment_mapping.values())
227+
228+
tracks_to_prune: List[str] = []
229+
# Update existing tracks based on matches
230+
for track_id, track in self.tracked_elements.items():
231+
if track_id in matched_track_ids:
232+
# Find the current element that matched this track
233+
matched_elem_id = next(
234+
(
235+
curr_id
236+
for curr_id, t_id in assignment_mapping.items()
237+
if t_id == track_id
238+
),
239+
None,
240+
)
241+
242+
if (
243+
matched_elem_id is not None
244+
and matched_elem_id in current_element_map
245+
):
246+
# Matched successfully
247+
track.latest_element = current_element_map[matched_elem_id]
248+
track.consecutive_misses = 0
249+
track.last_seen_frame = frame_number
250+
else:
251+
# Match found in assignment but element missing from map (should not happen ideally)
252+
logger.warning(
253+
f"Track {track_id} matched but element ID {matched_elem_id} not found in current_element_map. Treating as miss."
254+
)
255+
track.latest_element = None
256+
track.consecutive_misses += 1
257+
logger.debug(
258+
f"Track {track_id} treated as missed frame {frame_number}. Consecutive misses: {track.consecutive_misses}"
259+
)
260+
if track.consecutive_misses >= self.miss_threshold:
261+
tracks_to_prune.append(track_id)
262+
else:
263+
# Track was not matched in the current frame
264+
track.latest_element = None
265+
track.consecutive_misses += 1
266+
logger.debug(
267+
f"Track {track_id} missed frame {frame_number}. Consecutive misses: {track.consecutive_misses}"
268+
)
269+
# Check for pruning AFTER incrementing misses
270+
if track.consecutive_misses >= self.miss_threshold:
271+
tracks_to_prune.append(track_id)
272+
273+
# Prune tracks marked for deletion
274+
for track_id in tracks_to_prune:
275+
logger.debug(
276+
f"Pruning track {track_id} after {self.tracked_elements[track_id].consecutive_misses} misses."
277+
)
278+
if track_id in self.tracked_elements:
279+
del self.tracked_elements[track_id]
280+
281+
# Add tracks for new, unmatched elements
282+
for element_id, element in current_element_map.items():
283+
if element_id not in matched_current_element_ids:
284+
# Ensure element has valid bounds before creating track
285+
if _get_bounds_center(element.bounds) is None:
286+
logger.debug(
287+
f"Skipping creation of track for element ID {element_id} due to invalid bounds."
288+
)
289+
continue
290+
291+
new_track_id = self._generate_track_id()
292+
new_track = ElementTrack(
293+
track_id=new_track_id,
294+
latest_element=element,
295+
consecutive_misses=0,
296+
last_seen_frame=frame_number,
297+
)
298+
self.tracked_elements[new_track_id] = new_track
299+
logger.debug(
300+
f"Created new track {new_track_id} for element ID {element_id}"
301+
)
302+
303+
# Return the current list of all tracked elements' state
304+
return list(self.tracked_elements.values())

0 commit comments

Comments
 (0)