|
1 | 1 | import json |
| 2 | +import logging |
2 | 3 | from abc import ABC, abstractmethod |
3 | 4 | from datetime import datetime |
4 | 5 | from typing import Any, ClassVar, Dict, List, Literal, Optional, Tuple, TypedDict, Union |
|
20 | 21 |
|
21 | 22 | from ..utils.enums import LabelType |
22 | 23 |
|
| 24 | +logger = logging.getLogger(__name__) |
| 25 | + |
23 | 26 | KeypointVisibility: TypeAlias = Literal[0, 1, 2] |
24 | 27 | NormalizedFloat: TypeAlias = Annotated[float, Field(ge=0, le=1)] |
25 | 28 | """C{NormalizedFloat} is a float that is restricted to the range [0, 1].""" |
@@ -137,6 +140,25 @@ class BBoxAnnotation(Annotation): |
137 | 140 |
|
138 | 141 | _label_type = LabelType.BOUNDINGBOX |
139 | 142 |
|
| 143 | + @model_validator(mode="before") |
| 144 | + @classmethod |
| 145 | + def validate_values(cls, values: Dict[str, Any]) -> Dict[str, Any]: |
| 146 | + warn = False |
| 147 | + for key in ["x", "y", "w", "h"]: |
| 148 | + if values[key] < -2 or values[key] > 2: |
| 149 | + raise ValueError( |
| 150 | + "BBox annotation has value outside of automatic clipping range ([-2, 2]). " |
| 151 | + "Values should be normalized based on image size to range [0, 1]." |
| 152 | + ) |
| 153 | + if not (0 <= values[key] <= 1): |
| 154 | + warn = True |
| 155 | + values[key] = max(0, min(1, values[key])) |
| 156 | + if warn: |
| 157 | + logger.warning( |
| 158 | + "BBox annotation has values outside of [0, 1] range. Clipping them to [0, 1]." |
| 159 | + ) |
| 160 | + return values |
| 161 | + |
140 | 162 | def to_numpy(self, class_mapping: Dict[str, int]) -> np.ndarray: |
141 | 163 | class_ = class_mapping.get(self.class_, 0) |
142 | 164 | return np.array([class_, self.x, self.y, self.w, self.h]) |
@@ -170,6 +192,33 @@ class KeypointAnnotation(Annotation): |
170 | 192 |
|
171 | 193 | _label_type = LabelType.KEYPOINTS |
172 | 194 |
|
| 195 | + @model_validator(mode="before") |
| 196 | + @classmethod |
| 197 | + def validate_values(cls, values: Dict[str, Any]) -> Dict[str, Any]: |
| 198 | + warn = False |
| 199 | + for i, keypoint in enumerate(values["keypoints"]): |
| 200 | + if (keypoint[0] < -2 or keypoint[0] > 2) or ( |
| 201 | + keypoint[1] < -2 or keypoint[1] > 2 |
| 202 | + ): |
| 203 | + raise ValueError( |
| 204 | + "Keypoint annotation has value outside of automatic clipping range ([-2, 2]). " |
| 205 | + "Values should be normalized based on image size to range [0, 1]." |
| 206 | + ) |
| 207 | + new_keypoint = list(keypoint) |
| 208 | + if not (0 <= keypoint[0] <= 1): |
| 209 | + new_keypoint[0] = max(0, min(1, keypoint[0])) |
| 210 | + warn = True |
| 211 | + if not (0 <= keypoint[1] <= 1): |
| 212 | + new_keypoint[1] = max(0, min(1, keypoint[1])) |
| 213 | + warn = True |
| 214 | + values["keypoints"][i] = tuple(new_keypoint) |
| 215 | + |
| 216 | + if warn: |
| 217 | + logger.warning( |
| 218 | + "Keypoint annotation has values outside of [0, 1] range. Clipping them to [0, 1]." |
| 219 | + ) |
| 220 | + return values |
| 221 | + |
173 | 222 | def to_numpy(self, class_mapping: Dict[str, int]) -> np.ndarray: |
174 | 223 | class_ = class_mapping.get(self.class_, 0) |
175 | 224 | kps = np.array(self.keypoints).reshape((-1, 3)).astype(np.float32) |
@@ -340,6 +389,31 @@ class PolylineSegmentationAnnotation(SegmentationAnnotation): |
340 | 389 |
|
341 | 390 | points: List[Tuple[NormalizedFloat, NormalizedFloat]] = Field(min_length=3) |
342 | 391 |
|
| 392 | + @model_validator(mode="before") |
| 393 | + @classmethod |
| 394 | + def validate_values(cls, values: Dict[str, Any]) -> Dict[str, Any]: |
| 395 | + warn = False |
| 396 | + for i, point in enumerate(values["points"]): |
| 397 | + if (point[0] < -2 or point[0] > 2) or (point[1] < -2 or point[1] > 2): |
| 398 | + raise ValueError( |
| 399 | + "Polyline annotation has value outside of automatic clipping range ([-2, 2]). " |
| 400 | + "Values should be normalized based on image size to range [0, 1]." |
| 401 | + ) |
| 402 | + new_point = list(point) |
| 403 | + if not (0 <= point[0] <= 1): |
| 404 | + new_point[0] = max(0, min(1, point[0])) |
| 405 | + warn = True |
| 406 | + if not (0 <= point[1] <= 1): |
| 407 | + new_point[1] = max(0, min(1, point[1])) |
| 408 | + warn = True |
| 409 | + values["points"][i] = tuple(new_point) |
| 410 | + |
| 411 | + if warn: |
| 412 | + logger.warning( |
| 413 | + "Polyline annotation has values outside of [0, 1] range. Clipping them to [0, 1]." |
| 414 | + ) |
| 415 | + return values |
| 416 | + |
343 | 417 | def to_numpy(self, _: Dict[str, int], width: int, height: int) -> np.ndarray: |
344 | 418 | polyline = [(round(x * width), round(y * height)) for x, y in self.points] |
345 | 419 | mask = Image.new("L", (width, height), 0) |
|
0 commit comments