Skip to content
Merged
Show file tree
Hide file tree
Changes from 37 commits
Commits
Show all changes
46 commits
Select commit Hold shift + click to select a range
d1720a2
Add DetectionsRecognitionsSync
HonzaCuhel Mar 24, 2025
e015fba
Update typings to support py3.8
HonzaCuhel Mar 24, 2025
fce76e0
Update DetectionsRecognitionsSync & tests
HonzaCuhel Mar 26, 2025
5ad4b2e
Fix input, output
HonzaCuhel Mar 26, 2025
e9c06d4
Fix tests.
kkeroo Mar 27, 2025
608c402
Pre commit fix.
kkeroo Mar 27, 2025
e8ba945
Update DetectionsRecognitions
HonzaCuhel Apr 3, 2025
ca9012b
Merge branch 'main' into feat/add-detections-recognitions-sync
HonzaCuhel Apr 3, 2025
5fe4342
Rename property nn_data -> recognitions_data
HonzaCuhel Apr 4, 2025
65a7a41
Format code
HonzaCuhel Apr 4, 2025
2de2923
Import update.
kkeroo Apr 4, 2025
d4db905
_node appendix.
kkeroo Apr 4, 2025
9e7a4e5
Rename DetectionsRecognitionsSync -> TwoStageSync & refactor
HonzaCuhel Apr 5, 2025
6bf6c23
remove the private property docs
dominik737 Apr 9, 2025
9d7a57b
rename to GatherData
dominik737 Apr 9, 2025
4bc023a
file rename to gather data
dominik737 Apr 9, 2025
3f52be2
move fps validation to set camera fps
dominik737 Apr 9, 2025
f00a4ef
camera fps changed to optional
dominik737 Apr 9, 2025
bd078c0
rename detected recogntions to gathered data
dominik737 Apr 9, 2025
4acb1c8
rename gathered data file
dominik737 Apr 9, 2025
29a5ea9
generic gather data
dominik737 Apr 9, 2025
bf70393
rename collected data to gathered data
dominik737 Apr 9, 2025
8560b78
gather data variables renamed accordingly to more general use-case
dominik737 Apr 9, 2025
fa34049
gather data changeable wait for count function
dominik737 Apr 9, 2025
0674909
test build method
dominik737 Apr 10, 2025
fb5e4ca
rename test file
dominik737 Apr 10, 2025
932895a
queue get all method
dominik737 Apr 10, 2025
a735dcd
test img detections
dominik737 Apr 10, 2025
3edfdba
fix the queue wrong order from lifo to fifo
dominik737 Apr 11, 2025
c6338f7
test img detections
dominik737 Apr 11, 2025
a39643b
test set wait count fn
dominik737 Apr 11, 2025
f8d886e
test run without build
dominik737 Apr 11, 2025
8409931
gather data set timestamp based on reference data
dominik737 Apr 11, 2025
b963728
verify gather datata timestamp correctly set
dominik737 Apr 11, 2025
e530518
test clear old data
dominik737 Apr 11, 2025
65f2be4
remove unused image detections extended
dominik737 Apr 11, 2025
f0c4506
refactoring test gather data
dominik737 Apr 11, 2025
eb52e05
Merge branch 'main' into feat/gather-data
dominik737 Apr 11, 2025
1828d80
old references renamed
dominik737 Apr 14, 2025
527e665
wait count fn settable in build method
dominik737 Apr 14, 2025
2483cd7
fix python 3.8 default arg issue
dominik737 Apr 14, 2025
9e64838
lambdas instead of get all in the test
dominik737 Apr 14, 2025
6f0509d
add node suffix to the test file
dominik737 Apr 14, 2025
cb7ffa3
build method docstrings
dominik737 Apr 14, 2025
fa70ae3
gather data docstring update
dominik737 Apr 14, 2025
270a49e
update messsage docstrings
dominik737 Apr 15, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions depthai_nodes/message/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from .classification import Classifications
from .clusters import Cluster, Clusters
from .gathered_data import GatheredData
from .img_detections import (
ImgDetectionExtended,
ImgDetectionsExtended,
Expand All @@ -24,4 +25,5 @@
"Cluster",
"Prediction",
"Predictions",
"GatheredData",
]
67 changes: 67 additions & 0 deletions depthai_nodes/message/gathered_data.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
from typing import Generic, List, TypeVar

import depthai as dai

TReference = TypeVar("TReference", bound=dai.Buffer)
TGathered = TypeVar("TGathered")


class GatheredData(dai.Buffer, Generic[TReference, TGathered]):
"""A class for gathered number of data and the reference data on which the data was
gathered.
Attributes
----------
reference_data: TReference
Data that is used to determine how many of TGathered to gather.
collected: List[TGathered]
List of collected data.
"""

def __init__(self, reference_data: TReference, gathered: List[TGathered]) -> None:
"""Initializes the DetectedRecognitions object."""
super().__init__()
self.reference_data = reference_data
self.gathered = gathered

@property
def reference_data(self) -> TReference:
"""Returns the reference data.
@return: Reference data.
@rtype: TReference
"""
return self._reference_data

@reference_data.setter
def reference_data(self, value: TReference):
"""Sets the reference data.
@param value: Reference data.
@type value: TReference
"""
self.setSequenceNum(value.getSequenceNum())
self.setTimestamp(value.getTimestamp())
self.setTimestampDevice(value.getTimestampDevice())
self._reference_data = value

@property
def gathered(self) -> List[TGathered]:
"""Returns the collected data.
@return: List of collected data.
@rtype: List[TGathered]
"""
return self._gathered

@gathered.setter
def gathered(self, value: List[TGathered]):
"""Sets the gathered data.
@param value: List of gathered data.
@type value: List[TGathered]
@raise TypeError: If value is not a list.
"""
if not isinstance(value, list):
raise TypeError("gathered_data must be a list.")
self._gathered = value
2 changes: 2 additions & 0 deletions depthai_nodes/node/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from .apply_colormap import ApplyColormap
from .depth_merger import DepthMerger
from .gather_data import GatherData
from .host_spatials_calc import HostSpatialsCalc
from .img_detections_bridge import ImgDetectionsBridge
from .img_detections_filter import ImgDetectionsFilter
Expand Down Expand Up @@ -60,6 +61,7 @@
"BaseParser",
"DetectionParser",
"EmbeddingsParser",
"GatherData",
"ImgFrameOverlay",
"ImgDetectionsBridge",
"ImgDetectionsFilter",
Expand Down
209 changes: 209 additions & 0 deletions depthai_nodes/node/gather_data.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,209 @@
import time
from queue import PriorityQueue
from typing import (
Callable,
Dict,
Generic,
List,
Optional,
Protocol,
TypeVar,
runtime_checkable,
)

import depthai as dai

from depthai_nodes import GatheredData


@runtime_checkable
class HasDetections(Protocol):
@property
def detections(self) -> List:
...


TReference = TypeVar("TReference", bound=dai.Buffer)
TGathered = TypeVar("TGathered", bound=dai.Buffer)


class GatherData(dai.node.ThreadedHostNode, Generic[TReference, TGathered]):
FPS_TOLERANCE_DIVISOR = 2.0
INPUT_CHECKS_PER_FPS = 100
"""A class for gathering data. By default gathers n of dai.NNData where n is number
of dai.ImgDetection objects in dai.ImgDetections.
Attributes
----------
FPS_TOLERANCE_DIVISOR: float
Divisor for the FPS tolerance.
INPUT_CHECKS_PER_FPS: int
Number of input checks per FPS.
input_data: dai.Node.Input
Input to be gathered.
input_reference: dai.Node.Input
Input to determine how many gathered items to wait for.
output: dai.Node.Output
Output for gathered data.
"""

def __init__(self) -> None:
"""Initializes the GatherData node."""
self._camera_fps: Optional[int] = None
self._unmatched_data: List[TGathered] = []
self._data_by_reference_ts: Dict[float, List[TGathered]] = {}
self._reference_data: Dict[float, TReference] = {}
self._ready_timestamps = PriorityQueue()
self._wait_count_fn = self._default_wait_count_fn

self.input_data = self.createInput()
self.input_reference = self.createInput()
self.out = self.createOutput()

def _default_wait_count_fn(self, reference: TReference) -> int:
assert isinstance(reference, HasDetections)
return len(reference.detections)

def build(self, camera_fps: int) -> "GatherData[TReference, TGathered]":
self.set_camera_fps(camera_fps)
return self

def set_camera_fps(self, fps: int) -> None:
if fps <= 0:
raise ValueError(f"Camera FPS must be positive, got {fps}")
self._camera_fps = fps

def run(self) -> None:
if not self._camera_fps:
raise ValueError("Camera FPS not set. Call build() before run().")

while self.isRunning():
try:
input_data: TGathered = self.input_data.tryGet()
input_reference: TReference = self.input_reference.tryGet()
except dai.MessageQueue.QueueException:
break
if input_data:
self._add_data(input_data)
self._send_ready_data()
if input_reference:
self._add_reference(input_reference)
self._send_ready_data()

time.sleep(1 / self.INPUT_CHECKS_PER_FPS / self._camera_fps)

def _send_ready_data(self) -> None:
ready_data = self._pop_ready_data()
if ready_data:
self._clear_old_data(ready_data)
self.out.send(ready_data)

def _add_data(self, data: TGathered) -> None:
data_ts = self._get_total_seconds_ts(data)
best_matching_reference_ts = self._get_matching_reference_ts(data_ts)

if best_matching_reference_ts is not None:
self._add_data_by_reference_ts(data, best_matching_reference_ts)
self._update_ready_timestamps(best_matching_reference_ts)
else:
self._unmatched_data.append(data)

def _get_matching_reference_ts(self, data_ts: float) -> Optional[float]:
for reference_ts in self._reference_data.keys():
if self._timestamps_in_tolerance(reference_ts, data_ts):
return reference_ts
return None

def _add_reference(
self,
reference: TReference,
) -> None:
reference_ts = self._get_total_seconds_ts(reference)
self._reference_data[reference_ts] = reference
self._try_match_data(reference_ts)
self._update_ready_timestamps(reference_ts)

def _try_match_data(self, reference_ts: float) -> None:
matched_data: List[TGathered] = []
for data in self._unmatched_data:
data_ts = self._get_total_seconds_ts(data)
if self._timestamps_in_tolerance(reference_ts, data_ts):
self._add_data_by_reference_ts(data, reference_ts)
matched_data.append(data)

for matched in matched_data:
self._unmatched_data.remove(matched)

def _timestamps_in_tolerance(self, timestamp1: float, timestamp2: float) -> bool:
difference = abs(timestamp1 - timestamp2)
return difference < (1 / self._camera_fps / self.FPS_TOLERANCE_DIVISOR)

def _add_data_by_reference_ts(self, data: TGathered, reference_ts: float) -> None:
if reference_ts in self._data_by_reference_ts:
self._data_by_reference_ts[reference_ts].append(data)
else:
self._data_by_reference_ts[reference_ts] = [data]

def _update_ready_timestamps(self, timestamp: float) -> None:
if not self._timestamp_ready(timestamp):
return

self._ready_timestamps.put(timestamp)

def _timestamp_ready(self, timestamp: float) -> bool:
reference = self._reference_data.get(timestamp)
if not reference:
return False

wait_for_count = self._get_wait_count(reference)
if wait_for_count == 0:
return True

recognitions = self._data_by_reference_ts.get(timestamp)
if not recognitions:
return False

return wait_for_count == len(recognitions)

def _get_wait_count(self, reference: TReference) -> int:
return self._wait_count_fn(reference)

def _pop_ready_data(self) -> Optional[GatheredData]:
if self._ready_timestamps.empty():
return None

timestamp = self._ready_timestamps.get()
return GatheredData(
reference_data=self._reference_data.pop(timestamp),
gathered=self._data_by_reference_ts.pop(timestamp, None) or [],
)

def _clear_old_data(self, ready_data: GatheredData) -> None:
current_timestamp = self._get_total_seconds_ts(ready_data)
self._clear_unmatched_data(current_timestamp)
self._clear_old_references(current_timestamp)

def _clear_unmatched_data(self, current_timestamp: float) -> None:
unmatched_data_to_remove = []
for unmatched_data in self._unmatched_data:
if self._get_total_seconds_ts(unmatched_data) < current_timestamp:
unmatched_data_to_remove.append(unmatched_data)

for unmatched_data in unmatched_data_to_remove:
self._unmatched_data.remove(unmatched_data)

def _get_total_seconds_ts(self, buffer_like: dai.Buffer) -> float:
return buffer_like.getTimestamp().total_seconds()

def _clear_old_references(self, current_timestamp: float) -> None:
reference_keys_to_pop = []
for reference_ts in self._reference_data.keys():
if reference_ts < current_timestamp:
reference_keys_to_pop.append(reference_ts)

for reference_ts in reference_keys_to_pop:
self._reference_data.pop(reference_ts)
self._data_by_reference_ts.pop(reference_ts, None)

def set_wait_count_fn(self, fn: Callable[[TReference], int]) -> None:
self._wait_count_fn = fn
2 changes: 1 addition & 1 deletion tests/stability_tests/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,7 @@ You dont need to manually create the mock classes since they are implemented in

Check the already implemented tests for reference.

- `ThreadedHostNode` - tests for `DetectionsRecognitionsSync`
- `ThreadedHostNode` - tests for `TwoStageSync`
- `HostNode` - tests for `TilesPatcher`

You can check if everything works by running the tests locally. To run the unit tests move to the `depthai_nodes/tests` directory and run the tests with `pytest`.
Expand Down
Loading