Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 16 additions & 9 deletions luxonis_train/attached_modules/visualizers/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ Visualizer for **bounding box detection task**.
| `width` | `int` | `1` | The width of the bounding box lines |
| `font` | `str \| None` | `None` | A filename containing a `TrueType` font |
| `font_size` | `int \| None` | `None` | Font size used for the labels |
| `scale` | `float` | `1.0` | Scales the canvas and the annotations by a given factor |

**Example:**

Expand All @@ -50,6 +51,8 @@ Visualizer for **instance keypoint detection task**.
| `connectivity` | `list[tuple[int, int]] \| None` | `None` | List of tuples of keypoint indices that define the connections in the skeleton |
| `visible_color` | `str \| tuple[int, int, int]` | `"red"` | Color of visible keypoints |
| `nonvisible_color` | `str \| tuple[int, int, int] \| None` | `None` | Color of non-visible keypoints. If `None`, non-visible keypoints are not drawn |
| `radius` | `int \| None` | `None` | Radius of drawn keypoint dots. If `None`, dynamically determine this based on image dimensions |
| `scale` | `float` | `1.0` | Scales the canvas and the annotations by a given factor |

**Example:**

Expand All @@ -61,10 +64,11 @@ Visualizer for **segmentation tasks**.

**Parameters:**

| Key | Type | Default value | Description |
| ------- | ----------------------------- | ------------- | ------------------------------------- |
| `color` | `str \| tuple[int, int, int]` | `"#5050FF"` | Color of the segmentation masks |
| `alpha` | `float` | `0.6` | Alpha value of the segmentation masks |
| Key | Type | Default value | Description |
| ------- | ----------------------------- | ------------- | ------------------------------------------------------- |
| `color` | `str \| tuple[int, int, int]` | `"#5050FF"` | Color of the segmentation masks |
| `alpha` | `float` | `0.6` | Alpha value of the segmentation masks |
| `scale` | `float` | `1.0` | Scales the canvas and the annotations by a given factor |

**Example:**

Expand All @@ -83,6 +87,7 @@ Visualizer for **classification tasks**.
| `font_scale` | `float` | `1.0` | Scale of the font |
| `thickness` | `int` | `1` | Line thickness of the font |
| `multi_label` | `bool` | `False` | Set to `True` for multi-label classification, otherwise `False` for single-label |
| `scale` | `float` | `1.0` | Scales the canvas and the annotations by a given factor |

**Example:**

Expand All @@ -108,11 +113,12 @@ Visualizer for **OCR tasks**.

**Parameters:**

| Key | Type | Default value | Description |
| ------------ | ---------------------- | ------------- | ------------------------------------------- |
| `font_scale` | `float` | `0.5` | Font scale of the text. Defaults to `0.5`. |
| `color` | `tuple[int, int, int]` | `(0, 0, 0)` | Color of the text. Defaults to `(0, 0, 0)`. |
| `thickness` | `int` | `1` | Thickness of the text. Defaults to `1`. |
| Key | Type | Default value | Description |
| ------------ | ---------------------- | ------------- | ------------------------------------------------------- |
| `font_scale` | `float` | `0.5` | Font scale of the text. Defaults to `0.5`. |
| `color` | `tuple[int, int, int]` | `(0, 0, 0)` | Color of the text. Defaults to `(0, 0, 0)`. |
| `thickness` | `int` | `1` | Thickness of the text. Defaults to `1`. |
| `scale` | `float` | `1.0` | Scales the canvas and the annotations by a given factor |

**Example:**

Expand All @@ -132,6 +138,7 @@ Visualizer for **instance segmentation tasks**.
| `width` | `int` | `1` | The width of the bounding box lines |
| `font` | `str \| None` | `None` | A filename containing a `TrueType` font |
| `font_size` | `int \| None` | `None` | Font size used for the labels |
| `scale` | `float` | `1.0` | Scales the canvas and the annotations by a given factor |

**Example:**

Expand Down
17 changes: 17 additions & 0 deletions luxonis_train/attached_modules/visualizers/base_visualizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from functools import cached_property
from inspect import Parameter

import torch.nn.functional as F
from torch import Tensor
from typing_extensions import TypeVarTuple, Unpack

Expand All @@ -20,6 +21,19 @@ class BaseVisualizer(BaseAttachedModule, register=False, registry=VISUALIZERS):
L{VISUALIZERS} registry.
"""

def __init__(self, *args, scale: float = 1.0, **kwargs) -> None:
"""Store the visualization scale and forward everything else to the parent.

@type scale: float
@param scale: Factor by which the canvases are resized before drawing.
Keyword-only. Defaults to C{1.0}.
"""
# All remaining positional and keyword arguments are passed through
# unchanged to the parent initializer.
super().__init__(*args, **kwargs)
self.scale = scale

@staticmethod
def scale_canvas(canvas: Tensor, scale: float = 1.0) -> Tensor:
return F.interpolate(
canvas,
scale_factor=scale,
mode="bilinear",
align_corners=False,
)

@abstractmethod
def forward(
self,
Expand Down Expand Up @@ -75,6 +89,9 @@ def run(
inputs: Packet[Tensor],
labels: Labels | None,
) -> Tensor | tuple[Tensor, Tensor] | tuple[Tensor, list[Tensor]]:
prediction_canvas = self.scale_canvas(prediction_canvas, self.scale)
target_canvas = self.scale_canvas(target_canvas, self.scale)

return self(
target_canvas,
prediction_canvas,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
draw_bounding_boxes,
draw_segmentation_targets,
get_color,
potentially_upscale_masks,
)


Expand Down Expand Up @@ -90,6 +91,7 @@ def draw_predictions(
color_dict: dict[str, Color],
draw_labels: bool,
alpha: float,
scale: float = 1.0,
) -> Tensor:
viz = torch.zeros_like(canvas)

Expand All @@ -99,6 +101,12 @@ def draw_predictions(
image_masks = pred_masks[i]
prediction_classes = image_bboxes[..., 5].int()

if scale is not None and scale != 1:
image_bboxes = image_bboxes.clone()
image_bboxes[:, :4] *= scale

image_masks = potentially_upscale_masks(image_masks, scale)

cls_labels = (
[label_dict[int(c)] for c in prediction_classes]
if draw_labels and label_dict is not None
Expand Down Expand Up @@ -143,6 +151,7 @@ def draw_targets(
color_dict: dict[str, Color],
draw_labels: bool,
alpha: float,
scale: float = 1.0,
) -> Tensor:
viz = torch.zeros_like(canvas)

Expand All @@ -152,6 +161,8 @@ def draw_targets(
image_masks = target_masks[target_bboxes[:, 0] == i]
target_classes = image_bboxes[:, 1].int()

image_masks = potentially_upscale_masks(image_masks, scale)

cls_labels = (
[label_dict[int(c)] for c in target_classes]
if draw_labels and label_dict is not None
Expand Down Expand Up @@ -219,10 +230,10 @@ def forward(
self.colors,
self.draw_labels,
self.alpha,
self.scale,
)
if target_boundingbox is None or target_instance_segmentation is None:
return predictions_viz

targets_viz = self.draw_targets(
target_canvas,
target_boundingbox,
Expand All @@ -232,5 +243,6 @@ def forward(
self.colors,
self.draw_labels,
self.alpha,
self.scale,
)
return targets_viz, predictions_viz
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ def __init__(
connectivity: list[tuple[int, int]] | None = None,
visible_color: Color = "red",
nonvisible_color: Color | None = None,
radius: int | None = None,
**kwargs,
):
"""Visualizer for keypoints.
Expand All @@ -37,22 +38,44 @@ def __init__(
@param nonvisible_color: Color of nonvisible keypoints. If
C{None}, nonvisible keypoints are not drawn. Defaults to
C{None}.
@type radius: int | None
@param radius: the radius of drawn keypoints
"""
super().__init__(**kwargs)
self.visibility_threshold = visibility_threshold
self.connectivity = connectivity
self.visible_color = visible_color
self.nonvisible_color = nonvisible_color
self.radius = radius

@staticmethod
def _get_radius(canvas: Tensor) -> int:
"""Determine keypoint radius based on image size.

Height and width are read from the last two dimensions of C{canvas}.

If the image is under 96 in both height and width: 1.
If the image is more than 512 in either height or width: 5.
Otherwise: 2.

@type canvas: Tensor
@param canvas: Canvas whose last two dimensions are taken as height
and width.
@rtype: int
@return: The keypoint radius: C{1}, C{2} or C{5}.
"""
height = canvas.size(-2)
width = canvas.size(-1)

# Small canvases get the smallest dot.
if height < 96 and width < 96:
return 1
# A single large dimension is enough to select the largest dot.
if height > 512 or width > 512:
return 5
return 2

@staticmethod
def draw_predictions(
canvas: Tensor,
predictions: list[Tensor],
nonvisible_color: Color | None = None,
visibility_threshold: float = 0.5,
radius: int | None = None,
**kwargs,
) -> Tensor:
viz = torch.zeros_like(canvas)

for i in range(len(canvas)):
prediction = predictions[i]
mask = prediction[..., 2] < visibility_threshold
Expand All @@ -63,27 +86,42 @@ def draw_predictions(
visible_kpts[..., 1] = visible_kpts[..., 1].clamp(
0, canvas.size(-2) - 1
)

_kwargs = deepcopy(kwargs)
_kwargs.setdefault("radius", radius)

viz[i] = draw_keypoints(
canvas[i].clone(), visible_kpts[..., :2].int(), **kwargs
canvas[i].clone(),
visible_kpts[..., :2].int(),
**_kwargs,
)

if nonvisible_color is not None:
_kwargs = deepcopy(kwargs)
_kwargs.setdefault("radius", radius)
_kwargs["colors"] = nonvisible_color
nonvisible_kpts = (
prediction[..., :2] * mask.unsqueeze(-1).float()
)
viz[i] = draw_keypoints(
viz[i].clone(), nonvisible_kpts[..., :2], **_kwargs
viz[i].clone(),
nonvisible_kpts[..., :2],
**_kwargs,
)

return viz

@staticmethod
def draw_targets(canvas: Tensor, targets: Tensor, **kwargs) -> Tensor:
viz = torch.zeros_like(canvas)

for i in range(len(canvas)):
target = targets[targets[:, 0] == i][:, 1:]
viz[i] = draw_keypoint_labels(canvas[i].clone(), target, **kwargs)
viz[i] = draw_keypoint_labels(
canvas[i].clone(),
target,
**kwargs,
)

return viz

Expand All @@ -99,13 +137,25 @@ def forward(
) -> tuple[Tensor, Tensor] | Tensor:
pred_viz = super().draw_predictions(prediction_canvas, boundingbox)

prediction_radius = (
KeypointVisualizer._get_radius(prediction_canvas)
if self.radius is None
else self.radius
)
target_radius = (
KeypointVisualizer._get_radius(target_canvas)
if self.radius is None
else self.radius
)

pred_viz = self.draw_predictions(
pred_viz,
keypoints,
connectivity=self.connectivity,
colors=self.visible_color,
nonvisible_color=self.nonvisible_color,
visibility_threshold=self.visibility_threshold,
radius=prediction_radius,
**kwargs,
)

Expand All @@ -123,6 +173,7 @@ def forward(
target_viz = self.draw_targets(
target_viz,
target_keypoints,
radius=target_radius,
colors=self.visible_color,
connectivity=self.connectivity,
**kwargs,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from luxonis_train.utils import seg_output_to_bool

from .base_visualizer import BaseVisualizer
from .utils import Color, draw_segmentation_targets
from .utils import Color, draw_segmentation_targets, potentially_upscale_masks

log_disable = False

Expand Down Expand Up @@ -53,24 +53,34 @@ def __init__(

@staticmethod
def draw_predictions(
canvas: Tensor, predictions: Tensor, alpha: float, colors: list[Color]
canvas: Tensor,
predictions: Tensor,
alpha: float,
colors: list[Color],
scale: float = 1.0,
) -> Tensor:
viz = torch.zeros_like(canvas)
for i in range(len(canvas)):
prediction = predictions[i]
mask = seg_output_to_bool(prediction)
mask = potentially_upscale_masks(mask, scale)
viz[i] = draw_segmentation_targets(
canvas[i].clone(), mask, alpha=alpha, colors=colors
).to(canvas.device)
return viz

@staticmethod
def draw_targets(
canvas: Tensor, targets: Tensor, alpha: float, colors: list[Color]
canvas: Tensor,
targets: Tensor,
alpha: float,
colors: list[Color],
scale: float = 1.0,
) -> Tensor:
viz = torch.zeros_like(canvas)
for i in range(len(viz)):
target = targets[i]
target = targets[i].bool()
target = potentially_upscale_masks(target, scale)
viz[i] = draw_segmentation_targets(
canvas[i].clone(), target, alpha=alpha, colors=colors
).to(canvas.device)
Expand Down Expand Up @@ -107,6 +117,7 @@ def forward(
predictions,
alpha=self.alpha,
colors=colors,
scale=self.scale,
)
if target is None:
return predictions_vis
Expand All @@ -116,6 +127,7 @@ def forward(
target,
alpha=self.alpha,
colors=colors,
scale=self.scale,
)
return targets_vis, predictions_vis

Expand Down
Loading
Loading