Skip to content

Commit d7015f5

Browse files
authored
Add class_order_per_task param to the loader (#265)
1 parent fbec717 commit d7015f5

File tree

2 files changed

+18
-11
lines changed

2 files changed

+18
-11
lines changed

luxonis_train/loaders/README.md

Lines changed: 12 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -12,17 +12,18 @@ The default loader used with `LuxonisTrain`. It can either load data from an alr
1212

1313
**Parameters:**
1414

15-
| Key | Type | Default value | Description |
16-
| --------------------- | --------------------------------------------------------------------------------------------------------- | ------------- | ---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- |
17-
| `dataset_name` | `str` | `None` | Name of the dataset to load. If not provided, the `dataset_dir` must be provided instead |
18-
| `dataset_dir` | `str` | `None` | Path to the directory containing the dataset. If not provided, the `dataset_name` must be provided instead. Can be a path to a local directory or a URL. The data can be in a zip archive. New `LuxonisDataset` will be created using data from this directory and saved under the provided `dataset_name` |
19-
| `dataset_type` | `Literal["coco", "voc", "darknet", "yolov6", "yolov4", "createml", "tfcsv", "clsdir", "segmask"] \| None` | `None` | Type of the dataset. If not provided, the type will be inferred from the directory structure |
20-
| `team_id` | `str \| None` | `None` | Optional unique team identifier for the cloud |
21-
| `bucket_storage` | `Literal["local", "s3", "gcs"]` | `"local"` | Type of the bucket storage |
22-
| `delete_existing` | `bool` | `False` | Whether to delete the existing dataset with the same name. Only relevant if `dataset_dir` is provided. Use if you want to reparse the directory in case the data changed |
23-
| `update_mode` | `Literal["all", "missing"]` | `all` | Select whether to download all remote dataset media files or only those missing locally |
24-
| `min_bbox_visibility` | `float` | `0.0` | Minimum fraction of the original bounding box that must remain visible after augmentation. |
25-
| `bbox_area_threshold` | `float` | `0.0004` | Minimum area threshold for bounding boxes to be considered valid (relative units in [0,1]). Boxes (and their related keypoints) with area below this will be filtered out. |
15+
| Key | Type | Default value | Description |
16+
| ---------------------- | --------------------------------------------------------------------------------------------------------- | ------------- | ---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- |
17+
| `dataset_name` | `str` | `None` | Name of the dataset to load. If not provided, the `dataset_dir` must be provided instead |
18+
| `dataset_dir` | `str` | `None` | Path to the directory containing the dataset. If not provided, the `dataset_name` must be provided instead. Can be a path to a local directory or a URL. The data can be in a zip archive. New `LuxonisDataset` will be created using data from this directory and saved under the provided `dataset_name` |
19+
| `dataset_type` | `Literal["coco", "voc", "darknet", "yolov6", "yolov4", "createml", "tfcsv", "clsdir", "segmask"] \| None` | `None` | Type of the dataset. If not provided, the type will be inferred from the directory structure |
20+
| `team_id` | `str \| None` | `None` | Optional unique team identifier for the cloud |
21+
| `bucket_storage` | `Literal["local", "s3", "gcs"]` | `"local"` | Type of the bucket storage |
22+
| `delete_existing` | `bool` | `False` | Whether to delete the existing dataset with the same name. Only relevant if `dataset_dir` is provided. Use if you want to reparse the directory in case the data changed |
23+
| `update_mode` | `Literal["all", "missing"]` | `all` | Select whether to download all remote dataset media files or only those missing locally |
24+
| `min_bbox_visibility` | `float` | `0.0` | Minimum fraction of the original bounding box that must remain visible after augmentation. |
25+
| `bbox_area_threshold` | `float` | `0.0004` | Minimum area threshold for bounding boxes to be considered valid (relative units in [0,1]). Boxes (and their related keypoints) with area below this will be filtered out. |
26+
| `class_order_per_task` | `dict[str, List[str]] \| None` | `None` | If provided, the classes for the specified tasks will be reordered in the dataset. |
2627

2728
**Data Shape Definitions:**
2829

luxonis_train/loaders/luxonis_loader_torch.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@ def __init__(
2929
filter_task_names: list[str] | None = None,
3030
min_bbox_visibility: float = 0.0,
3131
bbox_area_threshold: float = 0.0004,
32+
class_order_per_task: dict[str, list[str]] | None = None,
3233
seed: int | None = None,
3334
**kwargs,
3435
):
@@ -81,6 +82,9 @@ def __init__(
8182
@type bbox_area_threshold: float
8283
@param bbox_area_threshold: Minimum area threshold for bounding boxes to be considered valid. In the range [0, 1].
8384
Default is 0.0004, which corresponds to a small area threshold to remove invalid bboxes and respective keypoints.
85+
@type class_order_per_task: dict[str, list[str]] | None
86+
@param class_order_per_task: Dictionary mapping task names to a list of class names.
87+
If provided, the classes for the specified tasks will be reordered.
8488
@type seed: Optional[int]
8589
@param seed: The random seed to use for the augmentations.
8690
"""
@@ -100,6 +104,8 @@ def __init__(
100104
bucket_type=bucket_type,
101105
bucket_storage=bucket_storage,
102106
)
107+
if class_order_per_task is not None:
108+
self.dataset.set_class_order_per_task(class_order_per_task)
103109
self.loader = LuxonisLoader(
104110
dataset=self.dataset,
105111
view=self.view,

0 commit comments

Comments
 (0)