Skip to content
2 changes: 2 additions & 0 deletions depthai_nodes/node/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from .apply_colormap import ApplyColormap
from .base_host_node import BaseHostNode
from .depth_merger import DepthMerger
from .host_crop_config_creator import CropConfigsCreator
from .host_parsing_neural_network import HostParsingNeuralNetwork
from .host_spatials_calc import HostSpatialsCalc
from .img_detections_bridge import ImgDetectionsBridge
Expand Down Expand Up @@ -63,6 +64,7 @@
"BaseParser",
"DetectionParser",
"EmbeddingsParser",
"CropConfigsCreator",
"ImgFrameOverlay",
"ImgDetectionsBridge",
"ImgDetectionsFilter",
Expand Down
260 changes: 260 additions & 0 deletions depthai_nodes/node/host_crop_config_creator.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,260 @@
from typing import Optional, Tuple

import depthai as dai

from depthai_nodes import ImgDetectionExtended, ImgDetectionsExtended


class CropConfigsCreator(dai.node.HostNode):
"""A node to create and send a dai.ImageManipConfigV2 crop configuration for each
detection in a list of detections. An optional target size and resize mode can be
set to ensure uniform crop sizes.

To ensure correct synchronization between the crop configurations and the image,
ensure "inputConfig.setReusePreviousMessage" is set to False in the dai.ImageManipV2 node.

Attributes
----------
detections_input : dai.Input
The input link for the ImageDetectionsExtended | dai.ImgDetections message.
config_output : dai.Output
The output link for the ImageManipConfigV2 messages.
detections_output : dai.Output
The output link for the ImgDetectionsExtended message.
source_size : Tuple[int, int]
The size of the source image (width, height).
target_size : Optional[Tuple[int, int]] = None
The size of the target image (width, height). If None, crop sizes will not be uniform.
resize_mode : dai.ImageManipConfigV2.ResizeMode = dai.ImageManipConfigV2.ResizeMode.STRETCH
The resize mode to use when target size is set. Options are: CENTER_CROP, LETTERBOX, NONE, STRETCH.
"""

def __init__(self) -> None:
"""Initializes the node."""
super().__init__()
self.config_output = self.createOutput(
possibleDatatypes=[
dai.Node.DatatypeHierarchy(dai.DatatypeEnum.ImageManipConfigV2, True)
]
)
self.detections_output = self.createOutput(
possibleDatatypes=[
dai.Node.DatatypeHierarchy(dai.DatatypeEnum.Buffer, True)
]
)
self._w: int = None
self._h: int = None
self._target_w: int = None
self._target_h: int = None
self.resize_mode: dai.ImageManipConfigV2.ResizeMode = None

@property
def w(self) -> int:
"""Returns the width of the source image.

@return: Width of the source image.
@rtype: int
"""
return self._w

@property
def h(self) -> int:
"""Returns the height of the source image.

@return: Height of the source image.
@rtype: int
"""
return self._h

@property
def target_w(self) -> int:
"""Returns the width of the target image.

@return: Width of the target image.
@rtype: int
"""
return self._target_w

@property
def target_h(self) -> int:
"""Returns the height of the target image.

@return: Height of the target image.
@rtype: int
"""
return self._target_h

@w.setter
def w(self, w: int):
"""Sets the width of the source image.

@param w: Width of the source image.
@type w: int
@raise TypeError: If w is not an integer.
@raise ValueError: If w is less than 1.
"""
self._validate_positive_integer(w)
self._w = w

@h.setter
def h(self, h: int):
"""Sets the height of the source image.

@param h: Height of the source image.
@type h: int
@raise TypeError: If h is not an integer.
@raise ValueError: If h is less than 1.
"""
self._validate_positive_integer(h)
self._h = h

@target_w.setter
def target_w(self, target_w: int):
"""Sets the width of the target image.

@param target_w: Width of the target image.
@type target_w: int
@raise TypeError: If target_w is not an integer.
@raise ValueError: If target_w is less than 1.
"""
self._validate_positive_integer(target_w)
self._target_w = target_w

@target_h.setter
def target_h(self, target_h: int):
"""Sets the height of the target image.

@param target_h: Height of the target image.
@type target_h: int
@raise TypeError: If target_h is not an integer.
@raise ValueError: If target_h is less than 1.
"""
self._validate_positive_integer(target_h)
self._target_h = target_h

def build(
self,
detections_input: dai.Node.Output,
source_size: Tuple[int, int],
target_size: Optional[Tuple[int, int]] = None,
resize_mode: dai.ImageManipConfigV2.ResizeMode = dai.ImageManipConfigV2.ResizeMode.STRETCH,
) -> "CropConfigsCreator":
"""Link the node input and set the correct source and target image sizes.

Parameters
----------
detections_input : dai.Node.Output
The input link for the ImgDetectionsExtended message
source_size : Tuple[int, int]
The size of the source image (width, height).
target_size : Optional[Tuple[int, int]]
The size of the target image (width, height). If None, crop sizes will not be uniform.
resize_mode : dai.ImageManipConfigV2.ResizeMode = dai.ImageManipConfigV2.ResizeMode.STRETCH
The resize mode to use when target size is set. Options are: CENTER_CROP, LETTERBOX, NONE, STRETCH.
"""

self.w = source_size[0]
self.h = source_size[1]

if target_size is not None:
self.target_w = target_size[0]
self.target_h = target_size[1]

self.resize_mode = resize_mode

self.link_args(detections_input)

return self

def process(self, detections_input: dai.Buffer) -> None:
"""Process the input detections and create crop configurations. This function is
ran every time a new ImgDetectionsExtended or dai.ImgDetections message is
received.

Sends len(detections) number of crop configurations to the config_output link.
In addition sends an ImgDetectionsExtended object containing the corresponding
detections to the detections_output link.
"""

assert isinstance(detections_input, (ImgDetectionsExtended, dai.ImgDetections))

sequence_num = detections_input.getSequenceNum()
timestamp = detections_input.getTimestamp()

if isinstance(detections_input, dai.ImgDetections):
detections_msg = self._convert_to_extended(detections_input)
else:
detections_msg = detections_input

detections = detections_msg.detections

# Skip the current frame / load new frame
cfg = dai.ImageManipConfigV2()
cfg.setSkipCurrentImage(True)
cfg.setTimestamp(timestamp)
cfg.setSequenceNum(sequence_num)
send_status = False
while not send_status:
send_status = self.config_output.trySend(cfg)

for i in range(len(detections)):
cfg = dai.ImageManipConfigV2()
detection: ImgDetectionExtended = detections[i]
rect = detection.rotated_rect
rect = rect.denormalize(self.w, self.h)

cfg.addCropRotatedRect(rect, normalizedCoords=False)

if self.target_w is not None and self.target_h is not None:
cfg.setOutputSize(self.target_w, self.target_h, self.resize_mode)

cfg.setReusePreviousImage(True)
cfg.setTimestamp(timestamp)
cfg.setSequenceNum(sequence_num)

send_status = False
while not send_status:
send_status = self.config_output.trySend(cfg)

self.detections_output.send(detections_msg)

def _convert_to_extended(
self, detections: dai.ImgDetections
) -> ImgDetectionsExtended:
rotated_rectangle_detections = []
for det in detections.detections:
img_detection = ImgDetectionExtended()
img_detection.label = det.label
img_detection.confidence = det.confidence

x_center = (det.xmin + det.xmax) / 2
y_center = (det.ymin + det.ymax) / 2
width = det.xmax - det.xmin
height = det.ymax - det.ymin

img_detection.rotated_rect = (x_center, y_center, width, height, 0.0)

rotated_rectangle_detections.append(img_detection)

img_detections_extended = ImgDetectionsExtended()
img_detections_extended.setSequenceNum(detections.getSequenceNum())
img_detections_extended.setTimestamp(detections.getTimestamp())
img_detections_extended.detections = rotated_rectangle_detections
transformation = detections.getTransformation()
if transformation is not None:
img_detections_extended.setTransformation(transformation)

return img_detections_extended

def _validate_positive_integer(self, value: int):
"""Validates that the set size is a positive integer.

@param value: The value to validate.
@type value: int
@raise TypeError: If value is not an integer.
@raise ValueError: If value is less than 1.
"""
if not isinstance(value, int):
raise TypeError("Value must be an integer.")
if value < 1:
raise ValueError("Value must be greater than 1.")
15 changes: 14 additions & 1 deletion tests/unittests/test_nodes/test_host_nodes/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,9 @@ def __init__(self):
def get(self):
return self._messages.pop(0)

def get_all(self):
return self._messages

def send(self, item):
self._messages.append(item)

Expand All @@ -32,6 +35,14 @@ def send(self, message):
for queue in self._queues:
queue.send(message)

def trySend(self, message):
for queue in self._queues:
queue.send(message)
return True

def returnQueue(self):
return self._queues

def createOutputQueue(self):
queue = Queue()
self._queues.append(queue)
Expand Down Expand Up @@ -84,7 +95,9 @@ def output(self, output):
self._output = output

def createOutput(self, possibleDatatypes: List[Tuple[dai.DatatypeEnum, bool]]):
return self._output
o = Output()
o.setPossibleDatatypes(possibleDatatypes)
return o

def sendProcessingToPipeline(self, send: bool):
self._sendProcessingToPipeline = send
Expand Down
Loading
Loading