-
Notifications
You must be signed in to change notification settings - Fork 2
ExtendedNeuralNetwork node #249
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
7762cd9
37931a9
ba15575
86cde6e
9aca0f8
781bfb4
52a6389
bd773f0
baf57dc
d092bfd
c4aa2ea
c801a24
43fc081
1b7393c
ba36a48
8ceb2e7
480a6f9
a1e357c
5d76902
ad0a92f
217823a
c51c2e6
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,64 @@ | ||
| import depthai as dai | ||
|
|
||
| from depthai_nodes.node.base_host_node import BaseHostNode | ||
| from depthai_nodes.node.utils.message_remapping import remap_message | ||
|
|
||
|
|
||
class DetectionsMapper(BaseHostNode):
    """Remap NN detection messages to ImgFrame coordinates.

    An on-device Script node strips the image payload from each incoming
    frame and forwards only its ImgTransformation and timestamps, reducing
    the amount of data sent between device and host. On the host,
    ``process`` remaps the detection message from the NN output coordinate
    space into the coordinate space of the original input frame.
    """

    # Runs on the device inside a dai.node.Script. It must be valid
    # standalone Python, hence the flush-left indentation.
    SCRIPT_CONTENT = """
# Strip ImgFrame image data and send only ImgTransformation
# Reduces the amount of data being sent between host and device
try:
    while True:
        frame = node.inputs['preview'].get()
        transformation = frame.getTransformation()
        empty_frame = ImgFrame()
        empty_frame.setTransformation(transformation)
        empty_frame.setTimestamp(frame.getTimestamp())
        empty_frame.setTimestampDevice(frame.getTimestampDevice())
        node.outputs['transformation'].send(empty_frame)
except Exception as e:
    node.warn(str(e))
"""

    def __init__(self) -> None:
        """Create the node and its helper on-device Script node.

        @raise RuntimeError: If the pipeline's default device is RVC2.
        """
        super().__init__()
        self._pipeline = self.getParentPipeline()
        if self._pipeline.getDefaultDevice().getPlatform() == dai.Platform.RVC2:
            raise RuntimeError(
                "DetectionsMapper node is currently not supported on RVC2."
            )
        self._logger.debug("Creating Script node")
        self._script = self._pipeline.create(dai.node.Script)
        self._script.setScript(self.SCRIPT_CONTENT)
        # Fixed: log messages previously said "ImgDetectionsMapper", which
        # does not match the class name.
        self._logger.debug("DetectionsMapper initialized")

    def build(
        self, img_input: dai.Node.Output, nn_input: dai.Node.Output
    ) -> "DetectionsMapper":
        """Link the inputs and finish constructing the node.

        @param img_input: Output producing the ImgFrames whose coordinate
            space the detections should be remapped to.
        @type img_input: dai.Node.Output
        @param nn_input: Output producing the NN detection messages.
        @type nn_input: dai.Node.Output
        @return: The built DetectionsMapper.
        @rtype: DetectionsMapper
        """
        img_input.link(self._script.inputs["preview"])
        self._script.outputs["transformation"].setPossibleDatatypes(
            [(dai.DatatypeEnum.ImgFrame, True)]
        )
        self.link_args(self._script.outputs["transformation"], nn_input)
        self._logger.debug("DetectionsMapper built")
        return self

    def process(self, img: dai.ImgFrame, nn: dai.Buffer) -> None:
        """Remap one NN message into the input frame's coordinate space and send it.

        @param img: (Stripped) frame carrying the target ImgTransformation.
        @type img: dai.ImgFrame
        @param nn: NN output message; must expose ``getTransformation()``.
        @type nn: dai.Buffer
        @raise RuntimeError: If the NN message has no transformation.
        """
        try:
            nn_trans = nn.getTransformation()
        except Exception as e:
            raise RuntimeError(
                "Could not get transformation from received message."
            ) from e
        if nn_trans is None:
            raise RuntimeError("Received detection message without transformation")
        message = remap_message(nn_trans, img.getTransformation(), nn)
        # Preserve timing/sequence metadata so downstream syncing still works.
        message.setTimestamp(nn.getTimestamp())
        message.setTimestampDevice(nn.getTimestampDevice())
        message.setSequenceNum(nn.getSequenceNum())
        message.setTransformation(img.getTransformation())
        self.out.send(message)
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. General comment for entire PR: Use logger to log the stages of the class. Something like:
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Agree, check the parsers' implementation for reference. |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,322 @@ | ||
| from typing import List, Literal, Optional, Tuple, Union, overload | ||
|
|
||
| import depthai as dai | ||
| import numpy as np | ||
|
|
||
| from depthai_nodes.logging import get_logger | ||
| from depthai_nodes.node import ( | ||
| ImgDetectionsFilter, | ||
| ParsingNeuralNetwork, | ||
| TilesPatcher, | ||
| Tiling, | ||
| ) | ||
| from depthai_nodes.node.detections_mapper import DetectionsMapper | ||
|
|
||
|
|
||
class ExtendedNeuralNetwork(dai.node.ThreadedHostNode):
    """Node that wraps the ParsingNeuralNetwork node and adds following capabilities:

    1. Automatic input resizing to the neural network input size.
    2. Remapping of detection coordinates from neural network output to input frame coordinates.
    3. Neural network output filtering based on confidence threshold and labels.
       (Only supported for ImgDetectionsExtended and ImgDetections messages).
    4. Input tiling.

    Supports only single head models.

    Attributes
    ----------
    out : Node.Output
        Neural network output. Detections are remapped to the input frame coordinates.
    nn_passthrough : Node.Output
        Neural network frame passthrough.
    """

    # dai.ImgFrame.Type requested from the resize/tiling stage, per platform.
    # RVC2 is rejected in __init__ before this table is consulted, so it
    # deliberately has no entry here.
    IMG_FRAME_TYPES = {
        dai.Platform.RVC4: dai.ImgFrame.Type.BGR888i,
    }

    def __init__(self):
        """Initialize configuration defaults and validate the platform.

        @raise RuntimeError: If the pipeline's default device is RVC2.
        @raise ValueError: If no dai.ImgFrame.Type is defined for the platform.
        """
        super().__init__()

        self._pipeline = self.getParentPipeline()
        self._platform = self._pipeline.getDefaultDevice().getPlatform()

        # Reject the unsupported platform *before* the frame-type lookup so
        # RVC2 users get this clear RuntimeError instead of a lookup failure.
        if self._platform == dai.Platform.RVC2:
            raise RuntimeError(
                "ExtendedNeuralNetwork node is currently not supported on RVC2."
            )

        try:
            self._img_frame_type = self.IMG_FRAME_TYPES[self._platform]
        except KeyError as e:
            raise ValueError(
                f"No dai.ImgFrame.Type defined for platform {self._platform}."
            ) from e

        self._logger = get_logger(self.__class__.__name__)

        # Detection-filtering configuration; only applied when
        # enable_detection_filtering is set in build().
        self._confidence_threshold: Optional[float] = None
        self._labels_to_keep: Optional[List[int]] = None
        self._labels_to_reject: Optional[List[int]] = None
        self._max_detections: Optional[int] = None
        # Tiling configuration; only applied when enable_tiling is set in build().
        self._tiling_grid_size: Tuple[int, int] = (2, 2)
        self._tiling_overlap: float = 0.1
        self._tiling_global_detection: bool = False
        self._tiling_grid_matrix: Union[np.ndarray, List, None] = None
        self._tiling_iou_threshold: float = 0.2

        # Inner nodes; created lazily in build().
        self.nn: Optional[ParsingNeuralNetwork] = None
        self.tiling: Optional[Tiling] = None
        self.patcher: Optional[TilesPatcher] = None
        self.detections_filter: Optional[ImgDetectionsFilter] = None
        self.nn_resize: Optional[dai.node.ImageManip] = None
        self.img_detections_mapper: Optional[DetectionsMapper] = None
        self._out: Optional[dai.Node.Output] = None

    @overload
    def build(
        self,
        input: dai.Node.Output,
        nn_source: Union[dai.NNModelDescription, dai.NNArchive, str],
        input_resize_mode: dai.ImageManipConfig.ResizeMode,
        enable_tiling: Literal[False] = False,
        input_size: None = None,
        enable_detection_filtering: bool = False,
    ) -> "ExtendedNeuralNetwork":
        ...

    @overload
    def build(
        self,
        input: dai.Node.Output,
        nn_source: Union[dai.NNModelDescription, dai.NNArchive, str],
        input_resize_mode: dai.ImageManipConfig.ResizeMode,
        enable_tiling: Literal[True],
        input_size: Tuple[int, int],
        enable_detection_filtering: bool = False,
    ) -> "ExtendedNeuralNetwork":
        ...

    def build(
        self,
        input: dai.Node.Output,
        nn_source: Union[dai.NNModelDescription, dai.NNArchive, str],
        input_resize_mode: dai.ImageManipConfig.ResizeMode,
        enable_tiling: bool = False,
        input_size: Optional[Tuple[int, int]] = None,
        enable_detection_filtering: bool = False,
    ) -> "ExtendedNeuralNetwork":
        """Builds the underlying nodes.

        @param input: ImgFrame node's input. Frames are automatically resized to fit the neural network input size.
        @type input: Node.Input
        @param nn_source: NNModelDescription object containing the HubAI model descriptors, NNArchive object of the model, or HubAI model slug in form of <model_slug>:<model_version_slug> or <model_slug>:<model_version_slug>:<model_instance_hash>.
        @type nn_source: Union[dai.NNModelDescription, dai.NNArchive, str]
        @param input_resize_mode: Resize mode for the neural network input.
        @type input_resize_mode: dai.ImageManipConfig.ResizeMode
        @param enable_tiling: If True, enables tiling.
        @type enable_tiling: bool
        @param input_size: ImgFrame input size for tiling. Must be provided if tiling is enabled.
        @type input_size: Tuple[int, int]
        @param enable_detection_filtering: If True, enables detection filtering based on labels and confidence threshold
            (only supported for ImgDetectionsExtended and ImgDetections messages).
        @type enable_detection_filtering: bool
        @return: Returns the ExtendedNeuralNetwork object.
        @rtype: ExtendedNeuralNetwork
        @raise ValueError: If tiling is enabled and input_size is not provided.
        @raise ValueError: If NNArchive does not contain input size.
        """
        if input_size is not None and any(i <= 0 for i in input_size):
            raise ValueError("Input size must be positive")
        if enable_tiling:
            if input_size is None:
                raise ValueError("Input size must be provided for tiling")
            nn_out = self._createTilingPipeline(
                input,
                input_size,
                input_resize_mode,
                nn_source,
            )
        else:
            nn_out = self._createBasicPipeline(input, input_resize_mode, nn_source)
        if enable_detection_filtering:
            self.detections_filter = self._pipeline.create(ImgDetectionsFilter).build(
                nn_out,
                labels_to_keep=self._labels_to_keep,  # type: ignore
                labels_to_reject=self._labels_to_reject,  # type: ignore
                confidence_threshold=self._confidence_threshold,
                max_detections=self._max_detections,  # type: ignore
            )
            self._out = self.detections_filter.out
        else:
            self.detections_filter = None
            self._out = nn_out
        self._logger.debug("ExtendedNeuralNetwork built")
        return self

    def run(self):
        # All work is done by the wrapped nodes; this host node has no loop.
        pass

    def _createBasicPipeline(
        self,
        input: dai.Node.Output,
        input_resize_mode: dai.ImageManipConfig.ResizeMode,
        nn_source: Union[dai.NNModelDescription, dai.NNArchive, str],
    ):
        """Create inner nodes, when tiling is disabled."""
        self._logger.debug("Creating basic pipeline")
        self._logger.debug("Creating ImageManip node for resizing NN input")
        self.nn_resize = self._pipeline.create(dai.node.ImageManip)
        input.link(self.nn_resize.inputImage)
        self._logger.debug("Building ParsingNeuralNetwork")
        self.nn = self._pipeline.create(ParsingNeuralNetwork).build(
            self.nn_resize.out, nn_source
        )
        if self.nn._getModelHeadsLen() != 1:
            raise RuntimeError(
                f"ExtendedNeuralNetwork only supports single head models. The model has {self.nn._getModelHeadsLen()} heads."
            )
        nn_w = self.nn._nn_archive.getInputWidth()
        nn_h = self.nn._nn_archive.getInputHeight()
        if nn_w is None or nn_h is None:
            raise ValueError("NNArchive does not contain input size")
        self.nn_resize.initialConfig.setOutputSize(nn_w, nn_h, input_resize_mode)
        # 3 channels per pixel for the BGR frame types used above.
        self.nn_resize.setMaxOutputFrameSize(nn_w * nn_h * 3)
        self.nn_resize.initialConfig.setFrameType(self._img_frame_type)

        self._logger.debug("Building DetectionsMapper")
        self.img_detections_mapper = self._pipeline.create(DetectionsMapper).build(
            input, self.nn.out
        )
        return self.img_detections_mapper.out

    def _createTilingPipeline(
        self,
        input: dai.Node.Output,
        input_size: Tuple[int, int],
        input_resize_mode: dai.ImageManipConfig.ResizeMode,
        nn_source: Union[dai.NNModelDescription, dai.NNArchive, str],
    ):
        """Create inner nodes, when tiling is enabled."""
        self._logger.debug("Creating tiling pipeline")
        # Tiling is created first so the NN can link to its output; it is
        # configured below once the NN input size is known.
        self.tiling = self._pipeline.create(Tiling)
        self._logger.debug("Building ParsingNeuralNetwork")
        self.nn = self._pipeline.create(ParsingNeuralNetwork).build(
            self.tiling.out, nn_source
        )
        if self.nn._getModelHeadsLen() != 1:
            raise RuntimeError(
                f"ExtendedNeuralNetwork only supports single head models. The model has {self.nn._getModelHeadsLen()} heads."
            )
        nn_size = self.nn._nn_archive.getInputSize()
        if nn_size is None:
            raise ValueError("NNArchive does not contain input size")
        self._logger.debug("Building Tiling")
        self.tiling.build(
            img_output=input,
            img_shape=input_size,
            overlap=self._tiling_overlap,
            grid_size=self._tiling_grid_size,
            resize_mode=input_resize_mode,
            global_detection=self._tiling_global_detection,
            grid_matrix=self._tiling_grid_matrix,
            nn_shape=nn_size,
        )
        self.tiling.setFrameType(self._img_frame_type)
        self._logger.debug("Building TilesPatcher")
        self.patcher = self._pipeline.create(TilesPatcher).build(
            img_frames=input,
            nn=self.nn.out,
            conf_thresh=self._confidence_threshold or 0.0,
            iou_thresh=self._tiling_iou_threshold,
        )
        return self.patcher.out

    def setTilingGridSize(self, grid_size: Tuple[int, int]) -> None:
        """Set grid size for tiling.

        Only used if tiling is enabled.
        """
        self._tiling_grid_size = grid_size
        if self.tiling is not None:
            self.tiling.setGridSize(grid_size)
        self._logger.debug(f"Tiling grid size set to {self._tiling_grid_size}")

    def setTilingOverlap(self, overlap: float) -> None:
        """Set tile overlap.

        Only used if tiling is enabled.
        """
        self._tiling_overlap = overlap
        if self.tiling is not None:
            self.tiling.setOverlap(overlap)
        self._logger.debug(f"Tiling overlap set to {self._tiling_overlap}")

    def setTilingGlobalDetection(self, global_detection: bool) -> None:
        """Set global detection flag for tiling.

        Only used if tiling is enabled.
        """
        self._tiling_global_detection = global_detection
        if self.tiling is not None:
            self.tiling.setGlobalDetection(global_detection)
        self._logger.debug(
            f"Tiling global detection set to {self._tiling_global_detection}"
        )

    def setTilingGridMatrix(self, grid_matrix: Union[np.ndarray, List, None]) -> None:
        """Set grid matrix for tiling.

        Only used if tiling is enabled.
        """
        self._tiling_grid_matrix = grid_matrix
        if self.tiling is not None:
            self.tiling.setGridMatrix(grid_matrix)
        self._logger.debug(f"Tiling grid matrix set to {self._tiling_grid_matrix}")

    def setLabels(self, labels: Optional[List[int]], keep: bool) -> None:
        """Set labels to keep or reject."""
        if keep:
            self._labels_to_keep = labels
        else:
            self._labels_to_reject = labels
        if self.detections_filter is not None:
            self.detections_filter.setLabels(labels, keep)  # type: ignore
        # Fixed: previously always logged the keep-list, even when updating
        # the reject-list.
        self._logger.debug(f"Labels to {'keep' if keep else 'reject'} set to {labels}")

    def setMaxDetections(self, max_detections: int) -> None:
        """Set maximum number of detections to keep."""
        self._max_detections = max_detections
        if self.detections_filter is not None:
            self.detections_filter.setMaxDetections(max_detections)
        self._logger.debug(f"Max detections set to {self._max_detections}")

    def setConfidenceThreshold(self, confidence_threshold: float) -> None:
        """Set confidence threshold."""
        self._confidence_threshold = confidence_threshold
        if self.detections_filter is not None:
            self.detections_filter.setConfidenceThreshold(confidence_threshold)
        if self.patcher is not None:
            self.patcher.setConfidenceThreshold(confidence_threshold)
        self._logger.debug(f"Confidence threshold set to {self._confidence_threshold}")

    @property
    def out(self):
        # Filtered output when detection filtering is enabled, raw remapped
        # NN output otherwise; set in build().
        if self._out is None:
            raise RuntimeError("ExtendedNeuralNetwork not initialized")
        return self._out

    @property
    def nn_passthrough(self):
        # Passthrough of the frames actually fed to the NN.
        if self.nn is None:
            raise RuntimeError("ExtendedNeuralNetwork not initialized")
        return self.nn.passthrough
Uh oh!
There was an error while loading. Please reload this page.