Skip to content

Commit 90d87ef

Browse files
committed
Add RandomBackgroundLines custom augmentation.
1 parent 0c6b120 commit 90d87ef

File tree

2 files changed

+148
-0
lines changed

2 files changed

+148
-0
lines changed

luxonis_ml/data/augmentations/custom/__init__.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
from .letterbox_resize import LetterboxResize
66
from .mixup import MixUp
77
from .mosaic import Mosaic4
8+
from .random_background_lines import RandomBackgroundLines
89
from .symetric_keypoints_flip import (
910
HorizontalSymetricKeypointsFlip,
1011
TransposeSymmetricKeypoints,
@@ -21,13 +22,15 @@
2122
TRANSFORMATIONS.register(module=HorizontalSymetricKeypointsFlip)
2223
TRANSFORMATIONS.register(module=VerticalSymetricKeypointsFlip)
2324
TRANSFORMATIONS.register(module=TransposeSymmetricKeypoints)
25+
TRANSFORMATIONS.register(module=RandomBackgroundLines)
2426

2527
__all__ = [
2628
"TRANSFORMATIONS",
2729
"HorizontalSymetricKeypointsFlip",
2830
"LetterboxResize",
2931
"MixUp",
3032
"Mosaic4",
33+
"RandomBackgroundLines",
3134
"TransposeSymmetricKeypoints",
3235
"VerticalSymetricKeypointsFlip",
3336
]
Lines changed: 145 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,145 @@
1+
import random
2+
from typing import Any
3+
4+
import albumentations as A
5+
import cv2
6+
import numpy as np
7+
from typing_extensions import override
8+
9+
10+
class RandomBackgroundLines(A.DualTransform):
    """Randomly draws black lines on the background of an image,
    avoiding foreground objects.

    The drawable area is derived from the first channel of the
    C{"_segmentation"} target: pixels whose value is at least 0.5 are
    treated as background, every other pixel as foreground. A candidate
    line is only drawn if none of its pixels touch the foreground.

    @type num_lines: tuple[int, int]
    @param num_lines: Range of number of lines to draw. Defaults to (3, 10).
    @type line_thickness: tuple[int, int]
    @param line_thickness: Range of line thickness. Defaults to (10, 50).
    @type line_length: tuple[float, float]
    @param line_length: Range of line lengths as a fraction of the diagonal of the image. Defaults to (0.1, 0.5).
    @type p: float
    @param p: Probability of applying the transform. Defaults to 0.5.
    """

    # Directions a line may take from its starting point (radians).
    _ANGLES = (0, np.pi / 4, np.pi / 2, 3 * np.pi / 4, np.pi)
    # How many random placements are tried per line before giving up.
    _MAX_ATTEMPTS = 20

    def __init__(
        self,
        num_lines: tuple[int, int] = (3, 10),
        line_thickness: tuple[int, int] = (10, 50),
        line_length: tuple[float, float] = (0.1, 0.5),
        p: float = 0.5,
    ):
        super().__init__(p=p)
        self.num_lines = num_lines
        self.line_thickness = line_thickness
        self.line_length = line_length

    @override
    def get_params_dependent_on_data(
        self, params: dict[str, Any], data: dict[str, Any]
    ) -> dict[str, Any]:
        """Extracts the segmentation mask needed to place the lines.

        @param params: The existing augmentation parameters dictionary.
        @type params: Dict[str, Any]
        @param data: The data dictionary. Must contain a
            C{"_segmentation"} entry.
        @type data: Dict[str, Any]
        @return: Additional parameters for the augmentation, i.e. the
            2D mask under the C{"seg_mask"} key.
        @rtype: Dict[str, Any]
        @raise ValueError: If no segmentation mask is present.
        """
        seg_mask = data.get("_segmentation")
        # Validate before touching the array so that a missing mask
        # produces a clear error instead of an AttributeError.
        if seg_mask is None:
            raise ValueError("Mask is None. Please provide a valid mask.")
        return {"seg_mask": self._first_channel(seg_mask)}

    @override
    def apply(
        self, image: np.ndarray, seg_mask: np.ndarray, **params
    ) -> np.ndarray:
        """Applies the random background lines augmentation to the image.

        @type image: np.ndarray
        @param image: The input image.
        @type seg_mask: np.ndarray
        @param seg_mask: The segmentation mask. Values of at least 0.5
            mark pixels where lines may be drawn.
        @return: A copy of the image with lines drawn on the background.
        @rtype: np.ndarray
        @raise ValueError: If the mask is missing or its spatial size
            differs from the image.
        """
        if seg_mask is None:
            raise ValueError("Mask is None. Please provide a valid mask.")

        result = image.copy()
        h, w = image.shape[:2]
        diagonal = np.sqrt(h**2 + w**2)

        seg_mask = self._first_channel(seg_mask)
        if seg_mask.shape[:2] != (h, w):
            raise ValueError(
                f"Mask shape {seg_mask.shape[:2]} does not match "
                f"image shape {(h, w)}."
            )

        # Both masks are invariant across all lines and attempts, so
        # they are computed once up front.
        foreground_mask = seg_mask < 0.5
        ys, xs = np.nonzero(seg_mask >= 0.5)
        if len(ys) == 0:
            # No background pixel to start a line from.
            return result

        num_lines = random.randint(self.num_lines[0], self.num_lines[1])

        for _ in range(num_lines):
            thickness = random.randint(
                self.line_thickness[0], self.line_thickness[1]
            )
            length = (
                random.uniform(self.line_length[0], self.line_length[1])
                * diagonal
            )

            for _ in range(self._MAX_ATTEMPTS):
                idx = random.randint(0, len(ys) - 1)
                # Plain ints: OpenCV point arguments are not guaranteed
                # to accept NumPy scalar types.
                x1, y1 = int(xs[idx]), int(ys[idx])
                angle = random.choice(self._ANGLES)
                x2, y2 = self._line_end(x1, y1, length, angle, w, h)

                # Rasterise the candidate on a scratch mask first so the
                # full thick line can be tested against the foreground.
                line_mask = np.zeros((h, w), dtype=np.uint8)
                cv2.line(line_mask, (x1, y1), (x2, y2), 1, thickness)
                if np.any(np.logical_and(line_mask > 0, foreground_mask)):
                    continue

                cv2.line(result, (x1, y1), (x2, y2), (0, 0, 0), thickness)
                break

        return result

    # NOTE(review): the identity overrides below exist because the
    # inherited DualTransform handlers are believed to either forward
    # masks to ``apply`` (which would paint lines into them) or to be
    # unimplemented for bboxes/keypoints -- confirm against the pinned
    # albumentations version. This transform only changes image pixels.
    @override
    def apply_to_mask(self, mask: np.ndarray, *args, **params) -> np.ndarray:
        """Returns the mask unchanged."""
        return mask

    @override
    def apply_to_bboxes(
        self, bboxes: np.ndarray, *args, **params
    ) -> np.ndarray:
        """Returns the bounding boxes unchanged."""
        return bboxes

    @override
    def apply_to_keypoints(
        self, keypoints: np.ndarray, *args, **params
    ) -> np.ndarray:
        """Returns the keypoints unchanged."""
        return keypoints

    @staticmethod
    def _first_channel(seg_mask: np.ndarray) -> np.ndarray:
        """Reduces a channelled (3D) mask to its first channel; 2D masks
        are returned as-is."""
        if seg_mask.ndim == 3:
            return seg_mask[:, :, 0]
        return seg_mask

    @staticmethod
    def _line_end(
        x1: int, y1: int, length: float, angle: float, width: int, height: int
    ) -> tuple[int, int]:
        """Computes the end point of a line starting at C{(x1, y1)} and
        clips it to the image bounds along the line's direction."""
        x2 = int(x1 + length * np.cos(angle))
        y2 = int(y1 + length * np.sin(angle))
        tan = np.tan(angle)

        if x2 < 0 or x2 >= width:
            x_edge = 0 if x2 < 0 else width - 1
            # A vertical line has no usable slope; keep the start row.
            if angle != np.pi / 2:
                y2 = int(y1 + (x_edge - x1) * tan)
            else:
                y2 = y1
            x2 = x_edge

        if y2 < 0 or y2 >= height:
            y_edge = 0 if y2 < 0 else height - 1
            # A horizontal line never changes row; keep the start column.
            if tan != 0:
                x2 = int(x1 + (y_edge - y1) / tan)
            else:
                x2 = x1
            y2 = y_edge

        return x2, y2

0 commit comments

Comments
 (0)