Skip to content

Commit 459cdaa

Browse files
committed
Added automatic clipping for normalized annotations
1 parent 4b84d04 commit 459cdaa

File tree

1 file changed

+57
-0
lines changed

1 file changed

+57
-0
lines changed

luxonis_ml/data/datasets/annotation.py

Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import json
2+
import logging
23
from abc import ABC, abstractmethod
34
from datetime import datetime
45
from typing import Any, ClassVar, Dict, List, Literal, Optional, Tuple, TypedDict, Union
@@ -20,6 +21,8 @@
2021

2122
from ..utils.enums import LabelType
2223

24+
logger = logging.getLogger(__name__)
25+
2326
KeypointVisibility: TypeAlias = Literal[0, 1, 2]
2427
NormalizedFloat: TypeAlias = Annotated[float, Field(ge=0, le=1)]
2528
"""C{NormalizedFloat} is a float that is restricted to the range [0, 1]."""
@@ -137,6 +140,20 @@ class BBoxAnnotation(Annotation):
137140

138141
_label_type = LabelType.BOUNDINGBOX
139142

143+
@model_validator(mode="before")
144+
@classmethod
145+
def validate_values(cls, values: Dict[str, Any]) -> Dict[str, Any]:
146+
warn = False
147+
for key in ["x", "y", "w", "h"]:
148+
if not (0 <= values[key] <= 1):
149+
warn = True
150+
values[key] = max(0, min(1, values[key]))
151+
if warn:
152+
logger.warning(
153+
"BBox annotation has values outside of [0,1] range. Clipping them to [0,1]."
154+
)
155+
return values
156+
140157
def to_numpy(self, class_mapping: Dict[str, int]) -> np.ndarray:
141158
class_ = class_mapping.get(self.class_, 0)
142159
return np.array([class_, self.x, self.y, self.w, self.h])
@@ -170,6 +187,26 @@ class KeypointAnnotation(Annotation):
170187

171188
_label_type = LabelType.KEYPOINTS
172189

190+
@model_validator(mode="before")
191+
@classmethod
192+
def validate_values(cls, values: Dict[str, Any]) -> Dict[str, Any]:
193+
warn = False
194+
for i, keypoint in enumerate(values["keypoints"]):
195+
new_keypoint = list(keypoint)
196+
if not (0 <= keypoint[0] <= 1):
197+
new_keypoint[0] = max(0, min(1, keypoint[0]))
198+
warn = True
199+
if not (0 <= keypoint[1] <= 1):
200+
new_keypoint[1] = max(0, min(1, keypoint[1]))
201+
warn = True
202+
values["keypoints"][i] = tuple(new_keypoint)
203+
204+
if warn:
205+
logger.warning(
206+
"Keypoint annotation has values outside of [0,1] range. Clipping them to [0,1]."
207+
)
208+
return values
209+
173210
def to_numpy(self, class_mapping: Dict[str, int]) -> np.ndarray:
174211
class_ = class_mapping.get(self.class_, 0)
175212
kps = np.array(self.keypoints).reshape((-1, 3)).astype(np.float32)
@@ -340,6 +377,26 @@ class PolylineSegmentationAnnotation(SegmentationAnnotation):
340377

341378
points: List[Tuple[NormalizedFloat, NormalizedFloat]] = Field(min_length=3)
342379

380+
@model_validator(mode="before")
381+
@classmethod
382+
def validate_values(cls, values: Dict[str, Any]) -> Dict[str, Any]:
383+
warn = False
384+
for i, point in enumerate(values["points"]):
385+
new_point = list(point)
386+
if not (0 <= point[0] <= 1):
387+
new_point[0] = max(0, min(1, point[0]))
388+
warn = True
389+
if not (0 <= point[1] <= 1):
390+
new_point[1] = max(0, min(1, point[1]))
391+
warn = True
392+
values["points"][i] = tuple(new_point)
393+
394+
if warn:
395+
logger.warning(
396+
"Polyline annotation has values outside of [0,1] range. Clipping them to [0,1]."
397+
)
398+
return values
399+
343400
def to_numpy(self, _: Dict[str, int], width: int, height: int) -> np.ndarray:
344401
polyline = [(round(x * width), round(y * height)) for x, y in self.points]
345402
mask = Image.new("L", (width, height), 0)

0 commit comments

Comments
 (0)