|
1 | 1 | import json |
| 2 | +import logging |
2 | 3 | from abc import ABC, abstractmethod |
3 | 4 | from datetime import datetime |
4 | 5 | from typing import Any, ClassVar, Dict, List, Literal, Optional, Tuple, TypedDict, Union |
|
20 | 21 |
|
21 | 22 | from ..utils.enums import LabelType |
22 | 23 |
|
| 24 | +logger = logging.getLogger(__name__) |
| 25 | + |
23 | 26 | KeypointVisibility: TypeAlias = Literal[0, 1, 2] |
24 | 27 | NormalizedFloat: TypeAlias = Annotated[float, Field(ge=0, le=1)] |
25 | 28 | """C{NormalizedFloat} is a float that is restricted to the range [0, 1].""" |
@@ -137,6 +140,20 @@ class BBoxAnnotation(Annotation): |
137 | 140 |
|
138 | 141 | _label_type = LabelType.BOUNDINGBOX |
139 | 142 |
|
| 143 | + @model_validator(mode="before") |
| 144 | + @classmethod |
| 145 | + def validate_values(cls, values: Dict[str, Any]) -> Dict[str, Any]: |
| 146 | + warn = False |
| 147 | + for key in ["x", "y", "w", "h"]: |
| 148 | + if not (0 <= values[key] <= 1): |
| 149 | + warn = True |
| 150 | + values[key] = max(0, min(1, values[key])) |
| 151 | + if warn: |
| 152 | + logger.warning( |
| 153 | + "BBox annotation has values outside of [0,1] range. Clipping them to [0,1]." |
| 154 | + ) |
| 155 | + return values |
| 156 | + |
140 | 157 | def to_numpy(self, class_mapping: Dict[str, int]) -> np.ndarray: |
141 | 158 | class_ = class_mapping.get(self.class_, 0) |
142 | 159 | return np.array([class_, self.x, self.y, self.w, self.h]) |
@@ -170,6 +187,26 @@ class KeypointAnnotation(Annotation): |
170 | 187 |
|
171 | 188 | _label_type = LabelType.KEYPOINTS |
172 | 189 |
|
| 190 | + @model_validator(mode="before") |
| 191 | + @classmethod |
| 192 | + def validate_values(cls, values: Dict[str, Any]) -> Dict[str, Any]: |
| 193 | + warn = False |
| 194 | + for i, keypoint in enumerate(values["keypoints"]): |
| 195 | + new_keypoint = list(keypoint) |
| 196 | + if not (0 <= keypoint[0] <= 1): |
| 197 | + new_keypoint[0] = max(0, min(1, keypoint[0])) |
| 198 | + warn = True |
| 199 | + if not (0 <= keypoint[1] <= 1): |
| 200 | + new_keypoint[1] = max(0, min(1, keypoint[1])) |
| 201 | + warn = True |
| 202 | + values["keypoints"][i] = tuple(new_keypoint) |
| 203 | + |
| 204 | + if warn: |
| 205 | + logger.warning( |
| 206 | + "Keypoint annotation has values outside of [0,1] range. Clipping them to [0,1]." |
| 207 | + ) |
| 208 | + return values |
| 209 | + |
173 | 210 | def to_numpy(self, class_mapping: Dict[str, int]) -> np.ndarray: |
174 | 211 | class_ = class_mapping.get(self.class_, 0) |
175 | 212 | kps = np.array(self.keypoints).reshape((-1, 3)).astype(np.float32) |
@@ -340,6 +377,26 @@ class PolylineSegmentationAnnotation(SegmentationAnnotation): |
340 | 377 |
|
341 | 378 | points: List[Tuple[NormalizedFloat, NormalizedFloat]] = Field(min_length=3) |
342 | 379 |
|
| 380 | + @model_validator(mode="before") |
| 381 | + @classmethod |
| 382 | + def validate_values(cls, values: Dict[str, Any]) -> Dict[str, Any]: |
| 383 | + warn = False |
| 384 | + for i, point in enumerate(values["points"]): |
| 385 | + new_point = list(point) |
| 386 | + if not (0 <= point[0] <= 1): |
| 387 | + new_point[0] = max(0, min(1, point[0])) |
| 388 | + warn = True |
| 389 | + if not (0 <= point[1] <= 1): |
| 390 | + new_point[1] = max(0, min(1, point[1])) |
| 391 | + warn = True |
| 392 | + values["points"][i] = tuple(new_point) |
| 393 | + |
| 394 | + if warn: |
| 395 | + logger.warning( |
| 396 | + "Polyline annotation has values outside of [0,1] range. Clipping them to [0,1]." |
| 397 | + ) |
| 398 | + return values |
| 399 | + |
343 | 400 | def to_numpy(self, _: Dict[str, int], width: int, height: int) -> np.ndarray: |
344 | 401 | polyline = [(round(x * width), round(y * height)) for x, y in self.points] |
345 | 402 | mask = Image.new("L", (width, height), 0) |
|
0 commit comments