Skip to content

Commit 9f9e4ed

Browse files
dominik737HonzaCuhelkkeroo
authored
Feat: Gather Data (#203)
* Add DetectionsRecognitionsSync * Update typings to support py3.8 * Update DetectionsRecognitionsSync & tests * Fix input, output * Fix tests. * Pre commit fix. * Update DetectionsRecognitions * Rename property nn_data -> recognitions_data * Format code * Import update. * _node appendix. * Rename DetectionsRecognitionsSync -> TwoStageSync & refactor * remove the private property docs * rename to GatherData * file rename to gather data * move fps validation to set camera fps * camera fps changed to optional * rename detected recogntions to gathered data * rename gathered data file * generic gather data * rename collected data to gathered data * gather data variables renamed accordingly to more general use-case * gather data changeable wait for count function * test build method * rename test file * queue get all method * test img detections * fix the queue wrong order from lifo to fifo * test img detections * test set wait count fn * test run without build * gather data set timestamp based on reference data * verify gather datata timestamp correctly set * test clear old data * remove unused image detections extended * refactoring test gather data * old references renamed * wait count fn settable in build method * fix python 3.8 default arg issue * lambdas instead of get all in the test * add node suffix to the test file * build method docstrings * gather data docstring update * update messsage docstrings --------- Co-authored-by: HonzaCuhel <[email protected]> Co-authored-by: kkeroo <[email protected]> Co-authored-by: Jan Čuhel <[email protected]>
1 parent c58ed4f commit 9f9e4ed

File tree

8 files changed

+618
-6
lines changed

8 files changed

+618
-6
lines changed

depthai_nodes/message/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
from .classification import Classifications
22
from .clusters import Cluster, Clusters
3+
from .gathered_data import GatheredData
34
from .img_detections import (
45
ImgDetectionExtended,
56
ImgDetectionsExtended,
@@ -24,4 +25,5 @@
2425
"Cluster",
2526
"Prediction",
2627
"Predictions",
28+
"GatheredData",
2729
]
Lines changed: 66 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,66 @@
1+
from typing import Generic, List, TypeVar
2+
3+
import depthai as dai
4+
5+
TReference = TypeVar("TReference", bound=dai.Buffer)
6+
TGathered = TypeVar("TGathered")
7+
8+
9+
class GatheredData(dai.Buffer, Generic[TReference, TGathered]):
10+
"""Contains N messages and reference data that the messages were matched with.
11+
12+
Attributes
13+
----------
14+
reference_data: TReference
15+
Data that is used to determine how many of TGathered to gather.
16+
collected: List[TGathered]
17+
List of collected data.
18+
"""
19+
20+
def __init__(self, reference_data: TReference, gathered: List[TGathered]) -> None:
21+
"""Initializes the GatheredData object."""
22+
super().__init__()
23+
self.reference_data = reference_data
24+
self.gathered = gathered
25+
26+
@property
27+
def reference_data(self) -> TReference:
28+
"""Returns the reference data.
29+
30+
@return: Reference data.
31+
@rtype: TReference
32+
"""
33+
return self._reference_data
34+
35+
@reference_data.setter
36+
def reference_data(self, value: TReference):
37+
"""Sets the reference data.
38+
39+
@param value: Reference data.
40+
@type value: TReference
41+
"""
42+
self.setSequenceNum(value.getSequenceNum())
43+
self.setTimestamp(value.getTimestamp())
44+
self.setTimestampDevice(value.getTimestampDevice())
45+
self._reference_data = value
46+
47+
@property
48+
def gathered(self) -> List[TGathered]:
49+
"""Returns the collected data.
50+
51+
@return: List of collected data.
52+
@rtype: List[TGathered]
53+
"""
54+
return self._gathered
55+
56+
@gathered.setter
57+
def gathered(self, value: List[TGathered]):
58+
"""Sets the gathered data.
59+
60+
@param value: List of gathered data.
61+
@type value: List[TGathered]
62+
@raise TypeError: If value is not a list.
63+
"""
64+
if not isinstance(value, list):
65+
raise TypeError("gathered_data must be a list.")
66+
self._gathered = value

depthai_nodes/node/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
from .apply_colormap import ApplyColormap
22
from .base_host_node import BaseHostNode
33
from .depth_merger import DepthMerger
4+
from .gather_data import GatherData
45
from .host_parsing_neural_network import HostParsingNeuralNetwork
56
from .host_spatials_calc import HostSpatialsCalc
67
from .img_detections_bridge import ImgDetectionsBridge
@@ -63,6 +64,7 @@
6364
"BaseParser",
6465
"DetectionParser",
6566
"EmbeddingsParser",
67+
"GatherData",
6668
"ImgFrameOverlay",
6769
"ImgDetectionsBridge",
6870
"ImgDetectionsFilter",

depthai_nodes/node/gather_data.py

Lines changed: 248 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,248 @@
1+
import time
2+
from queue import PriorityQueue
3+
from typing import (
4+
Callable,
5+
Dict,
6+
Generic,
7+
List,
8+
Optional,
9+
Protocol,
10+
TypeVar,
11+
runtime_checkable,
12+
)
13+
14+
import depthai as dai
15+
16+
from depthai_nodes import GatheredData
17+
18+
19+
@runtime_checkable
20+
class HasDetections(Protocol):
21+
@property
22+
def detections(self) -> List:
23+
...
24+
25+
26+
TReference = TypeVar("TReference", bound=dai.Buffer)
27+
TGathered = TypeVar("TGathered", bound=dai.Buffer)
28+
29+
30+
class GatherData(dai.node.ThreadedHostNode, Generic[TReference, TGathered]):
31+
FPS_TOLERANCE_DIVISOR = 2.0
32+
INPUT_CHECKS_PER_FPS = 100
33+
"""A class for gathering data. Gathers n messages based on reference_data. To
34+
determine n, wait_count_fn function is used. The default wait_count_fn function is
35+
waiting for len(TReference.detection). This means the node works out-of-the-box with
36+
dai.ImgDetections and ImgDetectionsExtended.
37+
38+
Attributes
39+
----------
40+
FPS_TOLERANCE_DIVISOR: float
41+
Divisor for the FPS tolerance.
42+
INPUT_CHECKS_PER_FPS: int
43+
Number of input checks per FPS.
44+
input_data: dai.Node.Input
45+
Input to be gathered.
46+
input_reference: dai.Node.Input
47+
Input to determine how many gathered items to wait for.
48+
output: dai.Node.Output
49+
Output for gathered data.
50+
"""
51+
52+
def __init__(self) -> None:
53+
"""Initializes the GatherData node."""
54+
self._camera_fps: Optional[int] = None
55+
self._unmatched_data: List[TGathered] = []
56+
self._data_by_reference_ts: Dict[float, List[TGathered]] = {}
57+
self._reference_data: Dict[float, TReference] = {}
58+
self._ready_timestamps = PriorityQueue()
59+
self._wait_count_fn = self._default_wait_count_fn
60+
61+
self.input_data = self.createInput()
62+
self.input_reference = self.createInput()
63+
self.out = self.createOutput()
64+
65+
@staticmethod
66+
def _default_wait_count_fn(reference: TReference) -> int:
67+
assert isinstance(reference, HasDetections)
68+
return len(reference.detections)
69+
70+
def build(
71+
self,
72+
camera_fps: int,
73+
wait_count_fn: Optional[Callable[[TReference], int]] = None,
74+
) -> "GatherData[TReference, TGathered]":
75+
"""Builds and configures the GatherData node with the specified parameters.
76+
77+
Parameters
78+
----------
79+
camera_fps : int
80+
The frames per second (FPS) setting for the camera. This affects the rate
81+
at which data is gathered.
82+
83+
wait_count_fn : Optional[Callable[[TReference], int]], default=None
84+
A function that takes a reference and returns the number of frames to wait
85+
before gathering data. This allows customizing the waiting behavior based on the reference data.
86+
If None, the default wait count function will be used. The default function matches based on length of TReference.detections array.
87+
88+
Returns
89+
-------
90+
GatherData[TReference, TGathered]
91+
The configured GatherData node instance.
92+
93+
Examples
94+
--------
95+
>>> gather_node = GatherData()
96+
>>> # Build with default wait count function
97+
>>> gather_node.build(camera_fps=30)
98+
>>>
99+
>>> # Build with custom wait count function that always waits for 2 messages
100+
>>> def custom_wait(ref):
101+
>>> return 2
102+
>>> gather_node.build(camera_fps=60, wait_count_fn=custom_wait)
103+
"""
104+
self.set_camera_fps(camera_fps)
105+
if wait_count_fn is None:
106+
wait_count_fn = self._default_wait_count_fn
107+
self.set_wait_count_fn(wait_count_fn)
108+
return self
109+
110+
def set_camera_fps(self, fps: int) -> None:
111+
if fps <= 0:
112+
raise ValueError(f"Camera FPS must be positive, got {fps}")
113+
self._camera_fps = fps
114+
115+
def run(self) -> None:
116+
if not self._camera_fps:
117+
raise ValueError("Camera FPS not set. Call build() before run().")
118+
119+
while self.isRunning():
120+
try:
121+
input_data: TGathered = self.input_data.tryGet()
122+
input_reference: TReference = self.input_reference.tryGet()
123+
except dai.MessageQueue.QueueException:
124+
break
125+
if input_data:
126+
self._add_data(input_data)
127+
self._send_ready_data()
128+
if input_reference:
129+
self._add_reference(input_reference)
130+
self._send_ready_data()
131+
132+
time.sleep(1 / self.INPUT_CHECKS_PER_FPS / self._camera_fps)
133+
134+
def _send_ready_data(self) -> None:
135+
ready_data = self._pop_ready_data()
136+
if ready_data:
137+
self._clear_old_data(ready_data)
138+
self.out.send(ready_data)
139+
140+
def _add_data(self, data: TGathered) -> None:
141+
data_ts = self._get_total_seconds_ts(data)
142+
best_matching_reference_ts = self._get_matching_reference_ts(data_ts)
143+
144+
if best_matching_reference_ts is not None:
145+
self._add_data_by_reference_ts(data, best_matching_reference_ts)
146+
self._update_ready_timestamps(best_matching_reference_ts)
147+
else:
148+
self._unmatched_data.append(data)
149+
150+
def _get_matching_reference_ts(self, data_ts: float) -> Optional[float]:
151+
for reference_ts in self._reference_data.keys():
152+
if self._timestamps_in_tolerance(reference_ts, data_ts):
153+
return reference_ts
154+
return None
155+
156+
def _add_reference(
157+
self,
158+
reference: TReference,
159+
) -> None:
160+
reference_ts = self._get_total_seconds_ts(reference)
161+
self._reference_data[reference_ts] = reference
162+
self._try_match_data(reference_ts)
163+
self._update_ready_timestamps(reference_ts)
164+
165+
def _try_match_data(self, reference_ts: float) -> None:
166+
matched_data: List[TGathered] = []
167+
for data in self._unmatched_data:
168+
data_ts = self._get_total_seconds_ts(data)
169+
if self._timestamps_in_tolerance(reference_ts, data_ts):
170+
self._add_data_by_reference_ts(data, reference_ts)
171+
matched_data.append(data)
172+
173+
for matched in matched_data:
174+
self._unmatched_data.remove(matched)
175+
176+
def _timestamps_in_tolerance(self, timestamp1: float, timestamp2: float) -> bool:
177+
difference = abs(timestamp1 - timestamp2)
178+
return difference < (1 / self._camera_fps / self.FPS_TOLERANCE_DIVISOR)
179+
180+
def _add_data_by_reference_ts(self, data: TGathered, reference_ts: float) -> None:
181+
if reference_ts in self._data_by_reference_ts:
182+
self._data_by_reference_ts[reference_ts].append(data)
183+
else:
184+
self._data_by_reference_ts[reference_ts] = [data]
185+
186+
def _update_ready_timestamps(self, timestamp: float) -> None:
187+
if not self._timestamp_ready(timestamp):
188+
return
189+
190+
self._ready_timestamps.put(timestamp)
191+
192+
def _timestamp_ready(self, timestamp: float) -> bool:
193+
reference = self._reference_data.get(timestamp)
194+
if not reference:
195+
return False
196+
197+
wait_for_count = self._get_wait_count(reference)
198+
if wait_for_count == 0:
199+
return True
200+
201+
recognitions = self._data_by_reference_ts.get(timestamp)
202+
if not recognitions:
203+
return False
204+
205+
return wait_for_count == len(recognitions)
206+
207+
def _get_wait_count(self, reference: TReference) -> int:
208+
return self._wait_count_fn(reference)
209+
210+
def _pop_ready_data(self) -> Optional[GatheredData]:
211+
if self._ready_timestamps.empty():
212+
return None
213+
214+
timestamp = self._ready_timestamps.get()
215+
return GatheredData(
216+
reference_data=self._reference_data.pop(timestamp),
217+
gathered=self._data_by_reference_ts.pop(timestamp, None) or [],
218+
)
219+
220+
def _clear_old_data(self, ready_data: GatheredData) -> None:
221+
current_timestamp = self._get_total_seconds_ts(ready_data)
222+
self._clear_unmatched_data(current_timestamp)
223+
self._clear_old_references(current_timestamp)
224+
225+
def _clear_unmatched_data(self, current_timestamp: float) -> None:
226+
unmatched_data_to_remove = []
227+
for unmatched_data in self._unmatched_data:
228+
if self._get_total_seconds_ts(unmatched_data) < current_timestamp:
229+
unmatched_data_to_remove.append(unmatched_data)
230+
231+
for unmatched_data in unmatched_data_to_remove:
232+
self._unmatched_data.remove(unmatched_data)
233+
234+
def _get_total_seconds_ts(self, buffer_like: dai.Buffer) -> float:
235+
return buffer_like.getTimestamp().total_seconds()
236+
237+
def _clear_old_references(self, current_timestamp: float) -> None:
238+
reference_keys_to_pop = []
239+
for reference_ts in self._reference_data.keys():
240+
if reference_ts < current_timestamp:
241+
reference_keys_to_pop.append(reference_ts)
242+
243+
for reference_ts in reference_keys_to_pop:
244+
self._reference_data.pop(reference_ts)
245+
self._data_by_reference_ts.pop(reference_ts, None)
246+
247+
def set_wait_count_fn(self, fn: Callable[[TReference], int]) -> None:
248+
self._wait_count_fn = fn

tests/stability_tests/README.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -92,7 +92,7 @@ You dont need to manually create the mock classes since they are implemented in
9292

9393
Check the already implemented tests for reference.
9494

95-
- `ThreadedHostNode` - tests for `DetectionsRecognitionsSync`
95+
- `ThreadedHostNode` - tests for `GatherData`
9696
- `HostNode` - tests for `TilesPatcher`
9797

9898
You can check if everything works by running the tests locally. To run the unit tests move to the `depthai_nodes/tests` directory and run the tests with `pytest`.

0 commit comments

Comments
 (0)