
Commit 21807bc

yolov8 instance segmentation and keypoints and more tests
1 parent b144881 commit 21807bc

File tree

8 files changed: 594 additions & 78 deletions


luxonis_ml/data/datasets/luxonis_dataset.py

Lines changed: 8 additions & 5 deletions
@@ -31,6 +31,8 @@
     YoloV4Exporter,
     YoloV6Exporter,
     YoloV8Exporter,
+    YoloV8InstanceSegmentationExporter,
+    YoloV8KeypointsExporter,
 )
 from luxonis_ml.data.exporters.exporter_utils import (
     ExporterSpec,
@@ -1528,11 +1530,12 @@ def export(
                 "skeletons": getattr(self.metadata, "skeletons", None),
             },
         ),
-        DatasetType.YOLOV8: ExporterSpec(
-            YoloV8Exporter,
-            {
-                "skeletons": getattr(self.metadata, "skeletons", None),
-            },
+        DatasetType.YOLOV8: ExporterSpec(YoloV8Exporter, {}),
+        DatasetType.YOLOV8INSTANCESEGMENTATION: ExporterSpec(
+            YoloV8InstanceSegmentationExporter, {}
+        ),
+        DatasetType.YOLOV8KEYPOINTS: ExporterSpec(
+            YoloV8KeypointsExporter, {}
         ),
         DatasetType.YOLOV6: ExporterSpec(YoloV6Exporter, {}),
         DatasetType.YOLOV4: ExporterSpec(YoloV4Exporter, {}),
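These registry entries make the new exporters reachable through the same export() path as the existing YOLO formats. A minimal usage sketch, assuming hypothetical import paths and an export() signature that takes an output path and a DatasetType (neither is confirmed by this diff):

# Hypothetical usage; import paths and export() parameters are assumptions.
from luxonis_ml.data import LuxonisDataset
from luxonis_ml.enums import DatasetType

dataset = LuxonisDataset("my_dataset")  # hypothetical dataset name
# The registry routes this to YoloV8InstanceSegmentationExporter:
dataset.export("exports/", dataset_type=DatasetType.YOLOV8INSTANCESEGMENTATION)
# ...and this to YoloV8KeypointsExporter:
dataset.export("exports/", dataset_type=DatasetType.YOLOV8KEYPOINTS)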

luxonis_ml/data/exporters/__init__.py

Lines changed: 6 additions & 0 deletions
@@ -13,6 +13,10 @@
 from .yolov4_exporter import YoloV4Exporter
 from .yolov6_exporter import YoloV6Exporter
 from .yolov8_exporter import YoloV8Exporter
+from .yolov8_instance_segmentation_exporter import (
+    YoloV8InstanceSegmentationExporter,
+)
+from .yolov8_keypoints_exporter import YoloV8KeypointsExporter
 
 __all__ = [
     "BaseExporter",
@@ -28,4 +32,6 @@
     "YoloV4Exporter",
     "YoloV6Exporter",
     "YoloV8Exporter",
+    "YoloV8InstanceSegmentationExporter",
+    "YoloV8KeypointsExporter",
 ]

luxonis_ml/data/exporters/exporter_utils.py

Lines changed: 39 additions & 59 deletions
@@ -7,7 +7,7 @@
 import numpy as np
 import polars as pl
 from loguru import logger
-from pycocotools import mask
+from pycocotools import mask as maskUtils
 
 if TYPE_CHECKING:
     from luxonis_ml.data.datasets.luxonis_dataset import LuxonisDataset
@@ -181,55 +181,6 @@ def get_single_skeleton(
         skeleton_1_based = [[a + 1, b + 1] for a, b in edges]
         return labels, skeleton_1_based
 
-    @staticmethod
-    def rle_to_yolo_polygon(rle: str, height: int, width: int) -> list:
-        # Decode RLE to binary mask
-        m = mask.decode({"size": [height, width], "counts": rle})
-
-        # Each contour = one polygon
-        contours, _ = cv2.findContours(
-            m, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE
-        )
-
-        polygons = []
-        for contour in contours:
-            contour = contour.squeeze()
-            if len(contour.shape) != 2:
-                continue
-            polygon = []
-            for x, y in contour:
-                polygon.extend([x / width, y / height])
-            polygons.append(polygon)
-
-        return polygons
-
-    @staticmethod
-    def _bbox_from_poly(
-        coords: list[float],
-    ) -> tuple[float, float, float, float]:
-        xs = coords[0::2]
-        ys = coords[1::2]
-        x_min, x_max = min(xs), max(xs)
-        y_min, y_max = min(ys), max(ys)
-        return x_min, y_min, (x_max - x_min), (y_max - y_min)
-
-    @staticmethod
-    def _iou_xywh(
-        a: tuple[float, float, float, float],
-        b: tuple[float, float, float, float],
-    ) -> float:
-        ax, ay, aw, ah = a
-        bx, by, bw, bh = b
-        ax2, ay2 = ax + aw, ay + ah
-        bx2, by2 = bx + bw, by + bh
-        inter_w = max(0.0, min(ax2, bx2) - max(ax, bx))
-        inter_h = max(0.0, min(ay2, by2) - max(ay, by))
-        inter = inter_w * inter_h
-        if inter <= 0.0:
-            return 0.0
-        union = aw * ah + bw * bh - inter
-        return inter / union if union > 0.0 else 0.0
-
     @staticmethod
     def decode_rle_with_pycoco(ann: dict[str, Any]) -> np.ndarray:
         h = int(ann["height"])
@@ -239,14 +190,43 @@ def decode_rle_with_pycoco(ann: dict[str, Any]) -> np.ndarray:
         # pycocotools expects an RLE object with 'size' and 'counts'
         rle = {"size": [h, w], "counts": counts.encode("utf-8")}
 
-        m = mask.decode(rle)  # type: ignore[arg-type]
+        m = maskUtils.decode(rle)  # type: ignore[arg-type]
         return np.array(m, dtype=np.uint8, order="C")
 
-    def _normalize(
-        self, xs: list[float], ys: list[float], w: float, h: float
-    ) -> list[float]:
-        out: list[float] = []
-        for x, y in zip(xs, ys, strict=True):
-            out.append(max(0.0, min(1.0, x / w)))
-            out.append(max(0.0, min(1.0, y / h)))
-        return out
+    @staticmethod
+    def annotation_to_polygons(
+        ann: dict[str, Any], file_path: Path
+    ) -> list[list[tuple[float, float]]]:
+        polygons: list[list[tuple[float, float]]] = []
+
+        # COCO RLE -> decode to mask -> contours -> polygons
+        if "counts" in ann:
+            H = int(ann["height"])
+            W = int(ann["width"])
+            rle = {"size": [H, W], "counts": ann["counts"]}
+            try:
+                mask = maskUtils.decode(rle)  # type: ignore
+                if mask.ndim == 3:
+                    mask = mask[:, :, 0]
+                mask = (mask > 0).astype(np.uint8)
+
+                contours, _ = cv2.findContours(
+                    mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE
+                )
+                for cnt in contours:
+                    if len(cnt) < 3:
+                        continue
+                    cnt = cnt.squeeze(1)
+                    poly = [
+                        (float(x) / W, float(y) / H) for x, y in cnt.tolist()
+                    ]
+                    if len(poly) >= 3:
+                        polygons.append(poly)
+            except Exception:
+                logger.warning(
+                    "Failed to decode COCO RLE; skipping this instance."
+                )
+            return polygons
+
+        return polygons
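annotation_to_polygons consolidates the removed rle_to_yolo_polygon and _normalize helpers into a single RLE -> binary mask -> contours -> normalized polygons path; it drops degenerate contours with fewer than three points and leaves clamping to the exporter. A minimal round-trip sketch under the module layout in this commit, with a toy mask standing in for a real stored annotation:

# Sketch: encode a toy mask to COCO RLE, then recover normalized polygons.
from pathlib import Path

import numpy as np
from pycocotools import mask as maskUtils

from luxonis_ml.data.exporters.exporter_utils import ExporterUtils

m = np.zeros((16, 16), dtype=np.uint8, order="F")  # pycocotools expects Fortran order
m[4:12, 4:12] = 1  # an 8x8 square
rle = maskUtils.encode(m)

ann = {"height": 16, "width": 16, "counts": rle["counts"]}
# file_path is unused on the RLE branch; any placeholder path satisfies it.
polygons = ExporterUtils.annotation_to_polygons(ann, Path("unused.png"))
# polygons holds lists of (x, y) tuples normalized to [0, 1] by width/height.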

luxonis_ml/data/exporters/yolov8_exporter.py

Lines changed: 0 additions & 6 deletions
@@ -17,21 +17,16 @@ def __init__(
         dataset_identifier: str,
         output_path: Path,
         max_partition_size_gb: float | None,
-        *,
-        skeletons: dict[str, Any] | None = None,
     ):
         super().__init__(
            dataset_identifier, output_path, max_partition_size_gb
         )
         self.class_to_id: dict[str, int] = {}
         self.class_names: list[str] = []
-        self.skeletons = skeletons  # for later keypoint export implementation
 
-    # v8 uses "val"
     def get_split_names(self) -> dict[str, str]:
         return {"train": "train", "val": "val", "test": "test"}
 
-    # v8 dataset.yaml
     def _yaml_filename(self) -> str:
         return "dataset.yaml"
 
@@ -44,7 +39,6 @@ def transform(self, prepared_ldf: PreparedLDF) -> None:
             prepared_ldf, self.supported_ann_types()
         )
 
-        # dict[split][image_name] -> list of tuples: (cid, ...) bbox or polygon
         annotation_splits: dict[str, dict[str, list[tuple]]] = {
             k: {} for k in self.get_split_names().values()
         }
luxonis_ml/data/exporters/yolov8_instance_segmentation_exporter.py

Lines changed: 186 additions & 0 deletions
@@ -0,0 +1,186 @@
+from __future__ import annotations
+
+import json
+from pathlib import Path
+from typing import Any, cast
+
+from luxonis_ml.data.exporters.exporter_utils import ExporterUtils, PreparedLDF
+
+from .base_exporter import BaseExporter
+
+
+class YoloV8InstanceSegmentationExporter(BaseExporter):
+    def __init__(
+        self,
+        dataset_identifier: str,
+        output_path: Path,
+        max_partition_size_gb: float | None,
+    ):
+        super().__init__(
+            dataset_identifier, output_path, max_partition_size_gb
+        )
+        self.class_to_id: dict[str, int] = {}
+        self.class_names: list[str] = []
+
+    def get_split_names(self) -> dict[str, str]:
+        return {"train": "train", "val": "val", "test": "test"}
+
+    def _yaml_filename(self) -> str:
+        return "dataset.yaml"
+
+    def supported_ann_types(self) -> list[str]:
+        return ["instance_segmentation"]
+
+    def transform(self, prepared_ldf: PreparedLDF) -> None:
+        ExporterUtils.check_group_file_correspondence(prepared_ldf)
+        ExporterUtils.exporter_specific_annotation_warning(
+            prepared_ldf, self.supported_ann_types()
+        )
+
+        annotation_splits: dict[str, dict[str, list[str]]] = {
+            k: {} for k in self.get_split_names().values()
+        }
+
+        df = prepared_ldf.processed_df
+        grouped = df.group_by(["file", "group_id"], maintain_order=True)
+        copied_files: set[Path] = set()
+
+        for key, group_df in grouped:
+            file_name, group_id = cast(tuple[str, Any], key)
+            logical_split = ExporterUtils.split_of_group(
+                prepared_ldf, group_id
+            )
+            split = self.get_split_names()[logical_split]
+
+            file_path = Path(str(file_name))
+            idx = self.image_indices.setdefault(
+                file_path, len(self.image_indices)
+            )
+            new_name = f"{idx}{file_path.suffix}"
+
+            label_lines: list[str] = []
+
+            for row in group_df.iter_rows(named=True):
+                ttype = row["task_type"]
+                ann_str = row["annotation"]
+                cname = row["class_name"]
+
+                if ann_str is None:
+                    continue
+                if ttype != "instance_segmentation":
+                    continue
+
+                if cname and cname not in self.class_to_id:
+                    self.class_to_id[cname] = len(self.class_to_id)
+                    self.class_names.append(cname)
+                if not cname or cname not in self.class_to_id:
+                    continue
+
+                ann = json.loads(ann_str)
+
+                cid = self.class_to_id[cname]
+                polygons = ExporterUtils.annotation_to_polygons(ann, file_path)
+
+                for poly in polygons:
+                    if len(poly) < 3:
+                        continue
+                    parts = []
+                    for x, y in poly:
+                        x_ = 0.0 if x < 0 else 1.0 if x > 1 else x
+                        y_ = 0.0 if y < 0 else 1.0 if y > 1 else y
+                        parts.append(f"{x_:.12f} {y_:.12f}")
+                    line = f"{cid} " + " ".join(parts)
+                    label_lines.append(line)
+
+            annotation_splits[split][new_name] = label_lines
+
+            ann_size_estimate = sum(len(s) + 1 for s in label_lines)
+            img_size = file_path.stat().st_size
+            annotation_splits = self._maybe_roll_partition(
+                annotation_splits, ann_size_estimate + img_size
+            )
+
+            data_path = self._get_data_path(self.output_path, split, self.part)
+            data_path.mkdir(parents=True, exist_ok=True)
+            dest = data_path / new_name
+            if file_path not in copied_files:
+                copied_files.add(file_path)
+                if dest != file_path:
+                    dest.write_bytes(file_path.read_bytes())
+                self.current_size += img_size
+
+        self._dump_annotations(annotation_splits, self.output_path, self.part)
+
+    def _maybe_roll_partition(
+        self,
+        annotation_splits: dict[str, dict[str, list[str]]],
+        additional_size: int,
+    ) -> dict[str, dict[str, list[str]]]:
+        if (
+            self.max_partition_size
+            and self.part is not None
+            and (self.current_size + additional_size) > self.max_partition_size
+        ):
+            self._dump_annotations(
+                annotation_splits, self.output_path, self.part
+            )
+            self.current_size = 0
+            self.part += 1
+            return {k: {} for k in self.get_split_names().values()}
+        return annotation_splits
+
+    def _dump_annotations(
+        self,
+        annotation_splits: dict[str, dict[str, list[str]]],
+        output_path: Path,
+        part: int | None = None,
+    ) -> None:
+        base = (
+            output_path / f"{self.dataset_identifier}_part{part}"
+            if part is not None
+            else output_path / self.dataset_identifier
+        )
+
+        for split_name in self.get_split_names().values():
+            labels_dir = base / "labels" / split_name
+            labels_dir.mkdir(parents=True, exist_ok=True)
+            images_dir = base / "images" / split_name
+            images_dir.mkdir(parents=True, exist_ok=True)
+
+            for img_name, lines in annotation_splits.get(
+                split_name, {}
+            ).items():
+                (labels_dir / f"{Path(img_name).stem}.txt").write_text(
+                    "\n".join(lines), encoding="utf-8"
+                )
+
+        yaml_filename = self._yaml_filename()
+        if yaml_filename:
+            split_dirs = self.get_split_names()
+            yaml_obj = {
+                "train": str(Path("images") / split_dirs["train"]),
+                "val": str(Path("images") / split_dirs["val"]),
+                "test": str(Path("images") / split_dirs["test"]),
+                "nc": len(self.class_names),
+                "names": self.class_names,
+            }
+            (base / yaml_filename).write_text(
+                self._to_yaml(yaml_obj), encoding="utf-8"
+            )
+
+    def _get_data_path(
+        self, output_path: Path, split: str, part: int | None = None
+    ) -> Path:
+        base = (
+            output_path / f"{self.dataset_identifier}_part{part}"
+            if part is not None
+            else output_path / self.dataset_identifier
+        )
+        return base / "images" / split
+
+    @staticmethod
+    def _to_yaml(d: dict[str, Any]) -> str:
+        lines: list[str] = []
+        for k, v in d.items():
+            lines.append(f"{k}: {v}")
+        return "\n".join(lines) + "\n"
