-
Notifications
You must be signed in to change notification settings - Fork 2
Feat: Gather Data #203
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
Merged
Feat: Gather Data #203
Changes from 40 commits
Commits
Show all changes
46 commits
Select commit
Hold shift + click to select a range
d1720a2
Add DetectionsRecognitionsSync
HonzaCuhel e015fba
Update typings to support py3.8
HonzaCuhel fce76e0
Update DetectionsRecognitionsSync & tests
HonzaCuhel 5ad4b2e
Fix input, output
HonzaCuhel e9c06d4
Fix tests.
kkeroo 608c402
Pre commit fix.
kkeroo e8ba945
Update DetectionsRecognitions
HonzaCuhel ca9012b
Merge branch 'main' into feat/add-detections-recognitions-sync
HonzaCuhel 5fe4342
Rename property nn_data -> recognitions_data
HonzaCuhel 65a7a41
Format code
HonzaCuhel 2de2923
Import update.
kkeroo d4db905
_node appendix.
kkeroo 9e7a4e5
Rename DetectionsRecognitionsSync -> TwoStageSync & refactor
HonzaCuhel 6bf6c23
remove the private property docs
dominik737 9d7a57b
rename to GatherData
dominik737 4bc023a
file rename to gather data
dominik737 3f52be2
move fps validation to set camera fps
dominik737 f00a4ef
camera fps changed to optional
dominik737 bd078c0
rename detected recogntions to gathered data
dominik737 4acb1c8
rename gathered data file
dominik737 29a5ea9
generic gather data
dominik737 bf70393
rename collected data to gathered data
dominik737 8560b78
gather data variables renamed accordingly to more general use-case
dominik737 fa34049
gather data changeable wait for count function
dominik737 0674909
test build method
dominik737 fb5e4ca
rename test file
dominik737 932895a
queue get all method
dominik737 a735dcd
test img detections
dominik737 3edfdba
fix the queue wrong order from lifo to fifo
dominik737 c6338f7
test img detections
dominik737 a39643b
test set wait count fn
dominik737 f8d886e
test run without build
dominik737 8409931
gather data set timestamp based on reference data
dominik737 b963728
verify gather datata timestamp correctly set
dominik737 e530518
test clear old data
dominik737 65f2be4
remove unused image detections extended
dominik737 f0c4506
refactoring test gather data
dominik737 eb52e05
Merge branch 'main' into feat/gather-data
dominik737 1828d80
old references renamed
dominik737 527e665
wait count fn settable in build method
dominik737 2483cd7
fix python 3.8 default arg issue
dominik737 9e64838
lambdas instead of get all in the test
dominik737 6f0509d
add node suffix to the test file
dominik737 cb7ffa3
build method docstrings
dominik737 fa70ae3
gather data docstring update
dominik737 270a49e
update messsage docstrings
dominik737 File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,67 @@ | ||
| from typing import Generic, List, TypeVar | ||
|
|
||
| import depthai as dai | ||
|
|
||
| TReference = TypeVar("TReference", bound=dai.Buffer) | ||
| TGathered = TypeVar("TGathered") | ||
|
|
||
|
|
||
| class GatheredData(dai.Buffer, Generic[TReference, TGathered]): | ||
| """A class for gathered number of data and the reference data on which the data was | ||
| gathered. | ||
| Attributes | ||
| ---------- | ||
| reference_data: TReference | ||
| Data that is used to determine how many of TGathered to gather. | ||
| collected: List[TGathered] | ||
| List of collected data. | ||
| """ | ||
|
|
||
| def __init__(self, reference_data: TReference, gathered: List[TGathered]) -> None: | ||
| """Initializes the GatheredData object.""" | ||
| super().__init__() | ||
| self.reference_data = reference_data | ||
| self.gathered = gathered | ||
|
|
||
| @property | ||
| def reference_data(self) -> TReference: | ||
| """Returns the reference data. | ||
| @return: Reference data. | ||
| @rtype: TReference | ||
| """ | ||
| return self._reference_data | ||
jkbmrz marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
|
||
| @reference_data.setter | ||
| def reference_data(self, value: TReference): | ||
| """Sets the reference data. | ||
| @param value: Reference data. | ||
| @type value: TReference | ||
| """ | ||
| self.setSequenceNum(value.getSequenceNum()) | ||
| self.setTimestamp(value.getTimestamp()) | ||
| self.setTimestampDevice(value.getTimestampDevice()) | ||
| self._reference_data = value | ||
|
|
||
| @property | ||
| def gathered(self) -> List[TGathered]: | ||
| """Returns the collected data. | ||
| @return: List of collected data. | ||
| @rtype: List[TGathered] | ||
| """ | ||
| return self._gathered | ||
jkbmrz marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
|
||
| @gathered.setter | ||
| def gathered(self, value: List[TGathered]): | ||
| """Sets the gathered data. | ||
| @param value: List of gathered data. | ||
| @type value: List[TGathered] | ||
| @raise TypeError: If value is not a list. | ||
| """ | ||
| if not isinstance(value, list): | ||
| raise TypeError("gathered_data must be a list.") | ||
| self._gathered = value | ||
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,215 @@ | ||
| import time | ||
| from queue import PriorityQueue | ||
| from typing import ( | ||
| Callable, | ||
| Dict, | ||
| Generic, | ||
| List, | ||
| Optional, | ||
| Protocol, | ||
| TypeVar, | ||
| runtime_checkable, | ||
| ) | ||
|
|
||
| import depthai as dai | ||
|
|
||
| from depthai_nodes import GatheredData | ||
|
|
||
|
|
||
| @runtime_checkable | ||
| class HasDetections(Protocol): | ||
| @property | ||
| def detections(self) -> List: | ||
| ... | ||
|
|
||
|
|
||
| TReference = TypeVar("TReference", bound=dai.Buffer) | ||
| TGathered = TypeVar("TGathered", bound=dai.Buffer) | ||
|
|
||
|
|
||
| class GatherData(dai.node.ThreadedHostNode, Generic[TReference, TGathered]): | ||
| FPS_TOLERANCE_DIVISOR = 2.0 | ||
| INPUT_CHECKS_PER_FPS = 100 | ||
| """A class for gathering data. By default gathers n of dai.NNData where n is number | ||
| of dai.ImgDetection objects in dai.ImgDetections. | ||
dominik737 marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
|
|
||
| Attributes | ||
| ---------- | ||
| FPS_TOLERANCE_DIVISOR: float | ||
| Divisor for the FPS tolerance. | ||
| INPUT_CHECKS_PER_FPS: int | ||
| Number of input checks per FPS. | ||
| input_data: dai.Node.Input | ||
| Input to be gathered. | ||
| input_reference: dai.Node.Input | ||
| Input to determine how many gathered items to wait for. | ||
| output: dai.Node.Output | ||
| Output for gathered data. | ||
| """ | ||
|
|
||
| def __init__(self) -> None: | ||
| """Initializes the GatherData node.""" | ||
| self._camera_fps: Optional[int] = None | ||
| self._unmatched_data: List[TGathered] = [] | ||
| self._data_by_reference_ts: Dict[float, List[TGathered]] = {} | ||
| self._reference_data: Dict[float, TReference] = {} | ||
| self._ready_timestamps = PriorityQueue() | ||
| self._wait_count_fn = self._default_wait_count_fn | ||
dominik737 marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
|
||
| self.input_data = self.createInput() | ||
| self.input_reference = self.createInput() | ||
| self.out = self.createOutput() | ||
|
|
||
| @staticmethod | ||
| def _default_wait_count_fn(reference: TReference) -> int: | ||
| assert isinstance(reference, HasDetections) | ||
| return len(reference.detections) | ||
|
|
||
| def build( | ||
| self, | ||
| camera_fps: int, | ||
| wait_count_fn: Callable[[TReference], int] = _default_wait_count_fn, | ||
| ) -> "GatherData[TReference, TGathered]": | ||
dominik737 marked this conversation as resolved.
Show resolved
Hide resolved
|
||
| self.set_camera_fps(camera_fps) | ||
| self.set_wait_count_fn(wait_count_fn) | ||
| return self | ||
|
|
||
| def set_camera_fps(self, fps: int) -> None: | ||
| if fps <= 0: | ||
| raise ValueError(f"Camera FPS must be positive, got {fps}") | ||
| self._camera_fps = fps | ||
|
|
||
| def run(self) -> None: | ||
| if not self._camera_fps: | ||
| raise ValueError("Camera FPS not set. Call build() before run().") | ||
|
|
||
| while self.isRunning(): | ||
| try: | ||
| input_data: TGathered = self.input_data.tryGet() | ||
| input_reference: TReference = self.input_reference.tryGet() | ||
| except dai.MessageQueue.QueueException: | ||
| break | ||
| if input_data: | ||
| self._add_data(input_data) | ||
| self._send_ready_data() | ||
| if input_reference: | ||
| self._add_reference(input_reference) | ||
| self._send_ready_data() | ||
|
|
||
| time.sleep(1 / self.INPUT_CHECKS_PER_FPS / self._camera_fps) | ||
|
|
||
| def _send_ready_data(self) -> None: | ||
| ready_data = self._pop_ready_data() | ||
| if ready_data: | ||
| self._clear_old_data(ready_data) | ||
| self.out.send(ready_data) | ||
|
|
||
| def _add_data(self, data: TGathered) -> None: | ||
| data_ts = self._get_total_seconds_ts(data) | ||
| best_matching_reference_ts = self._get_matching_reference_ts(data_ts) | ||
|
|
||
| if best_matching_reference_ts is not None: | ||
| self._add_data_by_reference_ts(data, best_matching_reference_ts) | ||
| self._update_ready_timestamps(best_matching_reference_ts) | ||
| else: | ||
| self._unmatched_data.append(data) | ||
|
|
||
| def _get_matching_reference_ts(self, data_ts: float) -> Optional[float]: | ||
| for reference_ts in self._reference_data.keys(): | ||
| if self._timestamps_in_tolerance(reference_ts, data_ts): | ||
| return reference_ts | ||
| return None | ||
|
|
||
| def _add_reference( | ||
| self, | ||
| reference: TReference, | ||
| ) -> None: | ||
| reference_ts = self._get_total_seconds_ts(reference) | ||
| self._reference_data[reference_ts] = reference | ||
| self._try_match_data(reference_ts) | ||
| self._update_ready_timestamps(reference_ts) | ||
|
|
||
| def _try_match_data(self, reference_ts: float) -> None: | ||
| matched_data: List[TGathered] = [] | ||
| for data in self._unmatched_data: | ||
| data_ts = self._get_total_seconds_ts(data) | ||
| if self._timestamps_in_tolerance(reference_ts, data_ts): | ||
| self._add_data_by_reference_ts(data, reference_ts) | ||
| matched_data.append(data) | ||
|
|
||
| for matched in matched_data: | ||
| self._unmatched_data.remove(matched) | ||
|
|
||
| def _timestamps_in_tolerance(self, timestamp1: float, timestamp2: float) -> bool: | ||
| difference = abs(timestamp1 - timestamp2) | ||
| return difference < (1 / self._camera_fps / self.FPS_TOLERANCE_DIVISOR) | ||
|
|
||
| def _add_data_by_reference_ts(self, data: TGathered, reference_ts: float) -> None: | ||
| if reference_ts in self._data_by_reference_ts: | ||
| self._data_by_reference_ts[reference_ts].append(data) | ||
| else: | ||
| self._data_by_reference_ts[reference_ts] = [data] | ||
|
|
||
| def _update_ready_timestamps(self, timestamp: float) -> None: | ||
| if not self._timestamp_ready(timestamp): | ||
| return | ||
|
|
||
| self._ready_timestamps.put(timestamp) | ||
|
|
||
| def _timestamp_ready(self, timestamp: float) -> bool: | ||
| reference = self._reference_data.get(timestamp) | ||
| if not reference: | ||
| return False | ||
|
|
||
| wait_for_count = self._get_wait_count(reference) | ||
| if wait_for_count == 0: | ||
| return True | ||
|
|
||
| recognitions = self._data_by_reference_ts.get(timestamp) | ||
| if not recognitions: | ||
| return False | ||
|
|
||
| return wait_for_count == len(recognitions) | ||
|
|
||
| def _get_wait_count(self, reference: TReference) -> int: | ||
| return self._wait_count_fn(reference) | ||
|
|
||
| def _pop_ready_data(self) -> Optional[GatheredData]: | ||
| if self._ready_timestamps.empty(): | ||
| return None | ||
|
|
||
| timestamp = self._ready_timestamps.get() | ||
| return GatheredData( | ||
| reference_data=self._reference_data.pop(timestamp), | ||
| gathered=self._data_by_reference_ts.pop(timestamp, None) or [], | ||
| ) | ||
|
|
||
| def _clear_old_data(self, ready_data: GatheredData) -> None: | ||
| current_timestamp = self._get_total_seconds_ts(ready_data) | ||
| self._clear_unmatched_data(current_timestamp) | ||
| self._clear_old_references(current_timestamp) | ||
|
|
||
| def _clear_unmatched_data(self, current_timestamp: float) -> None: | ||
| unmatched_data_to_remove = [] | ||
| for unmatched_data in self._unmatched_data: | ||
| if self._get_total_seconds_ts(unmatched_data) < current_timestamp: | ||
| unmatched_data_to_remove.append(unmatched_data) | ||
|
|
||
| for unmatched_data in unmatched_data_to_remove: | ||
| self._unmatched_data.remove(unmatched_data) | ||
|
|
||
| def _get_total_seconds_ts(self, buffer_like: dai.Buffer) -> float: | ||
| return buffer_like.getTimestamp().total_seconds() | ||
|
|
||
| def _clear_old_references(self, current_timestamp: float) -> None: | ||
| reference_keys_to_pop = [] | ||
| for reference_ts in self._reference_data.keys(): | ||
| if reference_ts < current_timestamp: | ||
| reference_keys_to_pop.append(reference_ts) | ||
|
|
||
| for reference_ts in reference_keys_to_pop: | ||
| self._reference_data.pop(reference_ts) | ||
| self._data_by_reference_ts.pop(reference_ts, None) | ||
|
|
||
| def set_wait_count_fn(self, fn: Callable[[TReference], int]) -> None: | ||
| self._wait_count_fn = fn | ||
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.