Skip to content

Commit b830e8e

Browse files
klemen1999kozlov721
authored andcommitted
Added automatic clipping for normalized annotations (#177)
1 parent 3b77638 commit b830e8e

File tree

1 file changed

+74
-0
lines changed

1 file changed

+74
-0
lines changed

luxonis_ml/data/datasets/annotation.py

Lines changed: 74 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import json
2+
import logging
23
from abc import ABC, abstractmethod
34
from datetime import datetime
45
from typing import Any, ClassVar, Dict, List, Literal, Optional, Tuple, TypedDict, Union
@@ -20,6 +21,8 @@
2021

2122
from ..utils.enums import LabelType
2223

24+
logger = logging.getLogger(__name__)
25+
2326
KeypointVisibility: TypeAlias = Literal[0, 1, 2]
2427
NormalizedFloat: TypeAlias = Annotated[float, Field(ge=0, le=1)]
2528
"""C{NormalizedFloat} is a float that is restricted to the range [0, 1]."""
@@ -137,6 +140,25 @@ class BBoxAnnotation(Annotation):
137140

138141
_label_type = LabelType.BOUNDINGBOX
139142

143+
@model_validator(mode="before")
144+
@classmethod
145+
def validate_values(cls, values: Dict[str, Any]) -> Dict[str, Any]:
146+
warn = False
147+
for key in ["x", "y", "w", "h"]:
148+
if values[key] < -2 or values[key] > 2:
149+
raise ValueError(
150+
"BBox annotation has value outside of automatic clipping range ([-2, 2]). "
151+
"Values should be normalized based on image size to range [0, 1]."
152+
)
153+
if not (0 <= values[key] <= 1):
154+
warn = True
155+
values[key] = max(0, min(1, values[key]))
156+
if warn:
157+
logger.warning(
158+
"BBox annotation has values outside of [0, 1] range. Clipping them to [0, 1]."
159+
)
160+
return values
161+
140162
def to_numpy(self, class_mapping: Dict[str, int]) -> np.ndarray:
141163
class_ = class_mapping.get(self.class_, 0)
142164
return np.array([class_, self.x, self.y, self.w, self.h])
@@ -170,6 +192,33 @@ class KeypointAnnotation(Annotation):
170192

171193
_label_type = LabelType.KEYPOINTS
172194

195+
@model_validator(mode="before")
196+
@classmethod
197+
def validate_values(cls, values: Dict[str, Any]) -> Dict[str, Any]:
198+
warn = False
199+
for i, keypoint in enumerate(values["keypoints"]):
200+
if (keypoint[0] < -2 or keypoint[0] > 2) or (
201+
keypoint[1] < -2 or keypoint[1] > 2
202+
):
203+
raise ValueError(
204+
"Keypoint annotation has value outside of automatic clipping range ([-2, 2]). "
205+
"Values should be normalized based on image size to range [0, 1]."
206+
)
207+
new_keypoint = list(keypoint)
208+
if not (0 <= keypoint[0] <= 1):
209+
new_keypoint[0] = max(0, min(1, keypoint[0]))
210+
warn = True
211+
if not (0 <= keypoint[1] <= 1):
212+
new_keypoint[1] = max(0, min(1, keypoint[1]))
213+
warn = True
214+
values["keypoints"][i] = tuple(new_keypoint)
215+
216+
if warn:
217+
logger.warning(
218+
"Keypoint annotation has values outside of [0, 1] range. Clipping them to [0, 1]."
219+
)
220+
return values
221+
173222
def to_numpy(self, class_mapping: Dict[str, int]) -> np.ndarray:
174223
class_ = class_mapping.get(self.class_, 0)
175224
kps = np.array(self.keypoints).reshape((-1, 3)).astype(np.float32)
@@ -340,6 +389,31 @@ class PolylineSegmentationAnnotation(SegmentationAnnotation):
340389

341390
points: List[Tuple[NormalizedFloat, NormalizedFloat]] = Field(min_length=3)
342391

392+
@model_validator(mode="before")
393+
@classmethod
394+
def validate_values(cls, values: Dict[str, Any]) -> Dict[str, Any]:
395+
warn = False
396+
for i, point in enumerate(values["points"]):
397+
if (point[0] < -2 or point[0] > 2) or (point[1] < -2 or point[1] > 2):
398+
raise ValueError(
399+
"Polyline annotation has value outside of automatic clipping range ([-2, 2]). "
400+
"Values should be normalized based on image size to range [0, 1]."
401+
)
402+
new_point = list(point)
403+
if not (0 <= point[0] <= 1):
404+
new_point[0] = max(0, min(1, point[0]))
405+
warn = True
406+
if not (0 <= point[1] <= 1):
407+
new_point[1] = max(0, min(1, point[1]))
408+
warn = True
409+
values["points"][i] = tuple(new_point)
410+
411+
if warn:
412+
logger.warning(
413+
"Polyline annotation has values outside of [0, 1] range. Clipping them to [0, 1]."
414+
)
415+
return values
416+
343417
def to_numpy(self, _: Dict[str, int], width: int, height: int) -> np.ndarray:
344418
polyline = [(round(x * width), round(y * height)) for x, y in self.points]
345419
mask = Image.new("L", (width, height), 0)

0 commit comments

Comments
 (0)