Skip to content
17 changes: 17 additions & 0 deletions luxonis_train/attached_modules/visualizers/base_visualizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from functools import cached_property
from inspect import Parameter

import torch.nn.functional as F
from torch import Tensor
from typing_extensions import TypeVarTuple, Unpack

Expand All @@ -20,6 +21,19 @@ class BaseVisualizer(BaseAttachedModule, register=False, registry=VISUALIZERS):
L{VISUALIZERS} registry.
"""

def __init__(self, *args, scale: float = 1.0, **kwargs) -> None:
    """Initialize the visualizer and remember the canvas scale.

    All positional and remaining keyword arguments are forwarded
    unchanged to the parent constructor.

    @type scale: float
    @param scale: Keyword-only factor stored on the instance as
        C{self.scale}. Defaults to 1.0.
    """
    super().__init__(*args, **kwargs)
    self.scale = scale

@staticmethod
def scale_canvas(canvas: Tensor, scale: float = 1.0) -> Tensor:
return F.interpolate(
canvas,
scale_factor=scale,
mode="bilinear",
align_corners=False,
)

@abstractmethod
def forward(
self,
Expand Down Expand Up @@ -75,6 +89,9 @@ def run(
inputs: Packet[Tensor],
labels: Labels | None,
) -> Tensor | tuple[Tensor, Tensor] | tuple[Tensor, list[Tensor]]:
prediction_canvas = self.scale_canvas(prediction_canvas, self.scale)
target_canvas = self.scale_canvas(target_canvas, self.scale)

return self(
target_canvas,
prediction_canvas,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
draw_bounding_boxes,
draw_segmentation_targets,
get_color,
potentially_upscale_masks,
)


Expand Down Expand Up @@ -90,6 +91,7 @@ def draw_predictions(
color_dict: dict[str, Color],
draw_labels: bool,
alpha: float,
scale: float = 1.0,
) -> Tensor:
viz = torch.zeros_like(canvas)

Expand All @@ -99,6 +101,12 @@ def draw_predictions(
image_masks = pred_masks[i]
prediction_classes = image_bboxes[..., 5].int()

if scale is not None and scale != 1:
image_bboxes = image_bboxes.clone()
image_bboxes[:, :4] *= scale

image_masks = potentially_upscale_masks(image_masks, scale)

cls_labels = (
[label_dict[int(c)] for c in prediction_classes]
if draw_labels and label_dict is not None
Expand Down Expand Up @@ -143,6 +151,7 @@ def draw_targets(
color_dict: dict[str, Color],
draw_labels: bool,
alpha: float,
scale: float = 1.0,
) -> Tensor:
viz = torch.zeros_like(canvas)

Expand All @@ -152,6 +161,8 @@ def draw_targets(
image_masks = target_masks[target_bboxes[:, 0] == i]
target_classes = image_bboxes[:, 1].int()

image_masks = potentially_upscale_masks(image_masks, scale)

cls_labels = (
[label_dict[int(c)] for c in target_classes]
if draw_labels and label_dict is not None
Expand Down Expand Up @@ -219,10 +230,10 @@ def forward(
self.colors,
self.draw_labels,
self.alpha,
self.scale,
)
if target_boundingbox is None or target_instance_segmentation is None:
return predictions_viz

targets_viz = self.draw_targets(
target_canvas,
target_boundingbox,
Expand All @@ -232,5 +243,6 @@ def forward(
self.colors,
self.draw_labels,
self.alpha,
self.scale,
)
return targets_viz, predictions_viz
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,23 @@ def __init__(
self.visible_color = visible_color
self.nonvisible_color = nonvisible_color

@staticmethod
def _get_radius(canvas: Tensor) -> int:
"""Determine keypoint radius based on image size.

If the image is under 128 in both height and width: 1.
If the image is more than 512 in either height or width: 6.
Otherwise: 3.
"""
height = canvas.size(-2)
width = canvas.size(-1)

if height < 128 and width < 128:
return 1
if height > 512 or width > 512:
return 6
return 3

@staticmethod
def draw_predictions(
canvas: Tensor,
Expand All @@ -53,6 +70,8 @@ def draw_predictions(
**kwargs,
) -> Tensor:
viz = torch.zeros_like(canvas)
radius = KeypointVisualizer._get_radius(canvas)

for i in range(len(canvas)):
prediction = predictions[i]
mask = prediction[..., 2] < visibility_threshold
Expand All @@ -63,27 +82,46 @@ def draw_predictions(
visible_kpts[..., 1] = visible_kpts[..., 1].clamp(
0, canvas.size(-2) - 1
)

_kwargs = deepcopy(kwargs)
_kwargs.setdefault("radius", radius)

viz[i] = draw_keypoints(
canvas[i].clone(), visible_kpts[..., :2].int(), **kwargs
canvas[i].clone(),
visible_kpts[..., :2].int(),
**_kwargs,
)

if nonvisible_color is not None:
_kwargs = deepcopy(kwargs)
_kwargs.setdefault("radius", radius)
_kwargs["colors"] = nonvisible_color
nonvisible_kpts = (
prediction[..., :2] * mask.unsqueeze(-1).float()
)
viz[i] = draw_keypoints(
viz[i].clone(), nonvisible_kpts[..., :2], **_kwargs
viz[i].clone(),
nonvisible_kpts[..., :2],
**_kwargs,
)

return viz

@staticmethod
def draw_targets(canvas: Tensor, targets: Tensor, **kwargs) -> Tensor:
    """Draw target keypoints onto a copy of every canvas in the batch.

    @type canvas: Tensor
    @param canvas: Batch of canvases; one visualization is produced per
        entry along the first dimension.
    @type targets: Tensor
    @param targets: Keypoint targets for the whole batch. Column 0 is
        matched against the image index and the remaining columns are
        handed to C{draw_keypoint_labels}.
    @param kwargs: Extra options forwarded to C{draw_keypoint_labels}.
        When C{radius} is not supplied, a value derived from the canvas
        size is used.
    @rtype: Tensor
    @return: Tensor shaped like C{canvas} holding the drawn images.
    """
    viz = torch.zeros_like(canvas)

    # The caller's kwargs are copied once so that adding the default
    # radius never leaks back to the caller; the same options apply to
    # every image in the batch.
    _kwargs = deepcopy(kwargs)
    _kwargs.setdefault("radius", KeypointVisualizer._get_radius(canvas))

    for i in range(len(canvas)):
        # Keep only the rows of image i and drop the index column.
        target = targets[targets[:, 0] == i][:, 1:]
        # Each image is drawn exactly once, with the radius-aware
        # options; the superseded call using the raw kwargs is gone.
        viz[i] = draw_keypoint_labels(
            canvas[i].clone(),
            target,
            **_kwargs,
        )

    return viz

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from luxonis_train.utils import seg_output_to_bool

from .base_visualizer import BaseVisualizer
from .utils import Color, draw_segmentation_targets
from .utils import Color, draw_segmentation_targets, potentially_upscale_masks

log_disable = False

Expand Down Expand Up @@ -53,24 +53,34 @@ def __init__(

@staticmethod
def draw_predictions(
    canvas: Tensor,
    predictions: Tensor,
    alpha: float,
    colors: list[Color],
    scale: float = 1.0,
) -> Tensor:
    """Overlay predicted segmentation masks on every canvas in the batch.

    @type canvas: Tensor
    @param canvas: Batch of canvases; one visualization is produced per
        entry along the first dimension.
    @type predictions: Tensor
    @param predictions: Per-image segmentation outputs, indexed in step
        with C{canvas}. Each entry is converted to boolean masks with
        C{seg_output_to_bool}.
    @type alpha: float
    @param alpha: Blending value forwarded to
        C{draw_segmentation_targets}.
    @type colors: list[Color]
    @param colors: Colors forwarded to C{draw_segmentation_targets}.
    @type scale: float
    @param scale: Factor passed to C{potentially_upscale_masks} so the
        masks can be resized before drawing. Defaults to 1.0.
    @rtype: Tensor
    @return: Tensor shaped like C{canvas} holding the drawn images.
    """
    viz = torch.zeros_like(canvas)
    for i in range(len(canvas)):
        mask = seg_output_to_bool(predictions[i])
        # Resize the masks by the requested factor before they are
        # drawn onto the canvas.
        mask = potentially_upscale_masks(mask, scale)
        viz[i] = draw_segmentation_targets(
            canvas[i].clone(), mask, alpha=alpha, colors=colors
        ).to(canvas.device)
    return viz

@staticmethod
def draw_targets(
canvas: Tensor, targets: Tensor, alpha: float, colors: list[Color]
canvas: Tensor,
targets: Tensor,
alpha: float,
colors: list[Color],
scale: float = 1.0,
) -> Tensor:
viz = torch.zeros_like(canvas)
for i in range(len(viz)):
target = targets[i]
target = targets[i].bool()
target = potentially_upscale_masks(target, scale)
viz[i] = draw_segmentation_targets(
canvas[i].clone(), target, alpha=alpha, colors=colors
).to(canvas.device)
Expand Down Expand Up @@ -107,6 +117,7 @@ def forward(
predictions,
alpha=self.alpha,
colors=colors,
scale=self.scale,
)
if target is None:
return predictions_vis
Expand All @@ -116,6 +127,7 @@ def forward(
target,
alpha=self.alpha,
colors=colors,
scale=self.scale,
)
return targets_vis, predictions_vis

Expand Down
26 changes: 24 additions & 2 deletions luxonis_train/attached_modules/visualizers/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
import numpy as np
import numpy.typing as npt
import torch
import torchvision.transforms.functional as F
import torch.nn.functional as F
import torchvision.transforms.functional as TF
from matplotlib.figure import Figure
from PIL import Image
Expand Down Expand Up @@ -201,7 +201,7 @@ def denormalize(
std_tensor = torch.tensor(std, device=img.device)
new_mean = -mean_tensor / std_tensor
new_std = 1 / std_tensor
out_img = F.normalize(img, mean=new_mean.tolist(), std=new_std.tolist())
out_img = TF.normalize(img, mean=new_mean.tolist(), std=new_std.tolist())
if to_uint8:
out_img = out_img.mul_(255).clamp_(0, 255).to(torch.uint8)
return out_img
Expand Down Expand Up @@ -264,6 +264,28 @@ def dynamically_determine_font_scale(
return computed_font_scale, thickness


def potentially_upscale_masks(
    image_masks: Tensor, scale: float = 1.0
) -> Tensor:
    """Resize a stack of segmentation masks when a non-unit scale is given.

    @type image_masks: Tensor
    @param image_masks: Stack of masks. A singleton axis is inserted at
        position 1 before resizing and removed afterwards, so the input
        is presumably C{[N, H, W]} - confirm against the callers.
    @type scale: float
    @param scale: Factor applied to the last two dimensions; each new
        size is the scaled size truncated to an integer. C{None} or
        C{1} disables resizing. Defaults to 1.0.
    @rtype: Tensor
    @return: The input tensor itself when no resizing is requested,
        otherwise a boolean tensor resized with nearest-neighbour
        interpolation.
    """
    if scale is None or scale == 1:
        return image_masks

    stacked = image_masks.unsqueeze(1)
    src_h, src_w = stacked.shape[-2:]
    target_size = (int(src_h * scale), int(src_w * scale))

    # Interpolation works on floats; nearest mode keeps the values
    # binary, so converting back to bool is lossless.
    resized = F.interpolate(stacked.float(), size=target_size, mode="nearest")
    return resized.bool().squeeze(1)


# TODO: Support native visualizations
# NOTE: Ignore for now, native visualizations not a priority.
#
Expand Down
Loading