Skip to content

Commit 0b12b41

Browse files
authored
Add class reordering in dataset (#351)
1 parent 64d29f7 commit 0b12b41

File tree

4 files changed

+220
-1
lines changed

4 files changed

+220
-1
lines changed

luxonis_ml/data/datasets/README.md

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -99,6 +99,22 @@ The `push_to_cloud()` method is used to upload a local dataset to the specified
9999
| `update_mode` | `UpdateMode` | `UpdateMode.MISSING` | Whether to always push (overwrite) the dataset’s media folder to the cloud or only upload missing files. |
100100
| `bucket_storage` | `BucketStorage` | Required | The cloud storage destination to which local media files should be uploaded (e.g., GCS, S3, Azure). |
101101

102+
### Setting Class Order per Task
103+
104+
The `set_class_order_per_task()` method allows you to define a specific ordering of classes for one or more tasks, without rewriting the dataset’s metadata.
105+
106+
#### Parameters
107+
108+
| Parameter | Type | Default | Description |
109+
| ---------------------- | ---------------------- | -------- | ------------------------------------------------------------------------------------------------------------ |
110+
| `class_order_per_task` | `dict[str, list[str]]` | Required | Mapping of task names to ordered lists of class names. Class names must exactly match the dataset’s classes. |
111+
112+
#### Persistence & Usage Notes
113+
114+
- **View-only ordering**: This method does *not* rewrite the dataset’s stored metadata (since `rewrite_metadata=False`). Instead, it applies the new class order as a view on the dataset object.
115+
- **New classes**: If new classes are added to the dataset, you must call `set_class_order_per_task()` again to include and order them.
116+
- **Loader initialization**: For `LuxonisLoader`, apply class ordering *before* passing the dataset into the loader to avoid unintended reordering during loader setup.
117+
102118
## In-Depth Explanation of luxonis-ml Dataset Storage
103119

104120
### File Structure

luxonis_ml/data/datasets/luxonis_dataset.py

Lines changed: 49 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -744,6 +744,7 @@ def set_classes(
744744
self,
745745
classes: list[str] | dict[str, int],
746746
task: str | None = None,
747+
rewrite_metadata: bool = True,
747748
) -> None:
748749
if task is None:
749750
tasks = self.get_task_names()
@@ -753,7 +754,8 @@ def set_classes(
753754
for t in tasks:
754755
self._metadata.set_classes(classes, t)
755756

756-
self._write_metadata()
757+
if rewrite_metadata:
758+
self._write_metadata()
757759

758760
@override
759761
def get_classes(self) -> dict[str, dict[str, int]]:
@@ -1837,3 +1839,49 @@ def remove_duplicates(self) -> None:
18371839
logger.info(
18381840
"Successfully removed duplicate files and annotations from the dataset."
18391841
)
1842+
1843+
def set_class_order_per_task(
1844+
self, class_order_per_task: dict[str, list[str]]
1845+
) -> None:
1846+
"""Sets the class order for provided tasks. This method checks
1847+
if the provided class order matches the dataset's classes and
1848+
updates the dataset accordingly.
1849+
1850+
@type class_order_per_task: dict[str, list[str]]
1851+
@param class_order_per_task: A dictionary mapping task names to
1852+
a list of class names. The class names must match the
1853+
dataset's classes for the respective tasks.
1854+
@raises ValueError: If the task name is not found in the dataset
1855+
tasks or if the provided class names do not match the
1856+
dataset's classes.
1857+
"""
1858+
for task_name, task_classes in class_order_per_task.items():
1859+
if task_name not in self.get_tasks():
1860+
raise ValueError(
1861+
f"Task {task_name} not found in dataset tasks. "
1862+
f"Available tasks: {list(self.get_tasks().keys())}"
1863+
)
1864+
if set(task_classes) != set(self.get_classes()[task_name].keys()):
1865+
raise ValueError(
1866+
f"Classes for task {task_name} do not match "
1867+
f"the classes in the dataset. "
1868+
f"Expected: {set(self.get_classes()[task_name].keys())}, "
1869+
f"Got: {set(task_classes)}."
1870+
)
1871+
1872+
current_classes = list(self.get_classes()[task_name].keys())
1873+
if task_classes != current_classes:
1874+
logger.warning(
1875+
f"Reordering classes for task {task_name}. "
1876+
f"Original order: {current_classes}, "
1877+
f"New order: {task_classes}."
1878+
)
1879+
1880+
self.set_classes(
1881+
classes={
1882+
class_name: i
1883+
for i, class_name in enumerate(task_classes)
1884+
},
1885+
task=task_name,
1886+
rewrite_metadata=False,
1887+
)

luxonis_ml/data/loaders/luxonis_loader.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -137,6 +137,7 @@ def __init__(
137137
self.width = width
138138

139139
self.dataset = dataset
140+
140141
self.sync_mode = self.dataset.is_remote
141142
self.keep_categorical_as_strings = keep_categorical_as_strings
142143
self.filter_task_names = filter_task_names

tests/test_data/test_dataset.py

Lines changed: 154 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1118,3 +1118,157 @@ def generator(start: int, end: int) -> DatasetIterator:
11181118
)
11191119
== 6
11201120
)
1121+
1122+
1123+
def create_test_dataset_with_classes(
1124+
tempdir: Path, task_classes: dict[str, dict[str, int]]
1125+
) -> LuxonisDataset:
1126+
"""Helper function to create a test dataset with specific class
1127+
mappings."""
1128+
1129+
def generator() -> DatasetIterator:
1130+
for i in range(5):
1131+
img = create_image(i, tempdir)
1132+
yield {
1133+
"file": img,
1134+
"annotation": {
1135+
"class": list(task_classes["classification"].keys())[
1136+
i % len(task_classes["classification"])
1137+
],
1138+
"boundingbox": {
1139+
"x": 0.1 + i * 0.1,
1140+
"y": 0.1 + i * 0.1,
1141+
"w": 0.2,
1142+
"h": 0.2,
1143+
},
1144+
},
1145+
"task_name": "classification",
1146+
}
1147+
1148+
dataset = LuxonisDataset(
1149+
"test_class_order",
1150+
delete_local=True,
1151+
delete_remote=True,
1152+
bucket_storage=BucketStorage.LOCAL,
1153+
).add(generator())
1154+
1155+
# Set the classes for the dataset
1156+
for task_name, classes in task_classes.items():
1157+
dataset.set_classes(classes, task=task_name)
1158+
1159+
dataset.make_splits(ratios=(1, 0, 0))
1160+
return dataset
1161+
1162+
1163+
def test_class_order_per_task_valid_reordering(tempdir: Path):
1164+
"""Test valid class reordering for a task."""
1165+
original_classes = {"classification": {"cat": 0, "dog": 1, "bird": 2}}
1166+
1167+
dataset = create_test_dataset_with_classes(tempdir, original_classes)
1168+
1169+
# Define new class order
1170+
class_order_per_task = {"classification": ["dog", "bird", "cat"]}
1171+
dataset.set_class_order_per_task(class_order_per_task)
1172+
1173+
# Verify that classes were reordered
1174+
expected_classes = {"dog": 0, "bird": 1, "cat": 2}
1175+
assert dataset.get_classes()["classification"] == expected_classes
1176+
1177+
1178+
def test_class_order_per_task_multiple_tasks(tempdir: Path):
1179+
"""Test class reordering for multiple tasks."""
1180+
original_classes = {
1181+
"classification": {"cat": 0, "dog": 1, "bird": 2},
1182+
"detection": {"person": 0, "car": 1, "bike": 2},
1183+
}
1184+
1185+
# Create a more complex dataset with multiple tasks
1186+
def generator() -> DatasetIterator:
1187+
for i in range(5):
1188+
img = create_image(i, tempdir)
1189+
yield {
1190+
"file": img,
1191+
"annotation": {
1192+
"class": list(original_classes["classification"].keys())[
1193+
i % 3
1194+
],
1195+
},
1196+
"task_name": "classification",
1197+
}
1198+
1199+
for i in range(5, 10):
1200+
img = create_image(i, tempdir)
1201+
yield {
1202+
"file": img,
1203+
"annotation": {
1204+
"class": list(original_classes["detection"].keys())[
1205+
(i - 5) % 3
1206+
],
1207+
},
1208+
"task_name": "classification_1",
1209+
}
1210+
1211+
dataset = LuxonisDataset(
1212+
"test_multi_task_class_order",
1213+
delete_local=True,
1214+
delete_remote=True,
1215+
bucket_storage=BucketStorage.LOCAL,
1216+
).add(generator())
1217+
1218+
for task_name, classes in original_classes.items():
1219+
dataset.set_classes(classes, task=task_name)
1220+
1221+
dataset.make_splits(ratios=(1, 0, 0))
1222+
1223+
# Define new class orders for both tasks
1224+
class_order_per_task = {
1225+
"classification": ["bird", "cat", "dog"],
1226+
"classification_1": ["bike", "person", "car"],
1227+
}
1228+
1229+
dataset.set_class_order_per_task(class_order_per_task)
1230+
1231+
# Verify both tasks were reordered correctly
1232+
expected_classification = {"bird": 0, "cat": 1, "dog": 2}
1233+
expected_detection = {"bike": 0, "person": 1, "car": 2}
1234+
1235+
assert dataset.get_classes()["classification"] == expected_classification
1236+
assert dataset.get_classes()["classification_1"] == expected_detection
1237+
1238+
1239+
def test_class_order_per_task_invalid_task_name(tempdir: Path):
1240+
"""Test error when providing an invalid task name."""
1241+
original_classes = {"classification": {"cat": 0, "dog": 1, "bird": 2}}
1242+
1243+
dataset = create_test_dataset_with_classes(tempdir, original_classes)
1244+
1245+
# Define class order for non-existent task
1246+
class_order_per_task = {"invalid_task": ["cat", "dog", "bird"]}
1247+
1248+
with pytest.raises(
1249+
ValueError,
1250+
match=r"Task invalid_task not found in dataset tasks\. Available tasks: \['classification'\]",
1251+
):
1252+
dataset.set_class_order_per_task(class_order_per_task)
1253+
1254+
1255+
def test_class_order_per_task_mismatched_classes(tempdir: Path):
1256+
"""Test error when provided classes don't match dataset classes."""
1257+
original_classes = {"classification": {"cat": 0, "dog": 1, "bird": 2}}
1258+
1259+
dataset = create_test_dataset_with_classes(tempdir, original_classes)
1260+
1261+
# Define class order with wrong class names
1262+
class_order_per_task = {
1263+
"classification": [
1264+
"cat",
1265+
"dog",
1266+
"fish",
1267+
] # "fish" is not in original classes
1268+
}
1269+
1270+
with pytest.raises(
1271+
ValueError,
1272+
match=r"Classes for task classification do not match the classes in the dataset.",
1273+
):
1274+
dataset.set_class_order_per_task(class_order_per_task)

0 commit comments

Comments
 (0)