Skip to content

Commit aa9ee35

Browse files
authored
.make_splits() supports more flexible split definitions. (#352)
1 parent 3b52066 commit aa9ee35

File tree

3 files changed

+33
-15
lines changed

3 files changed

+33
-15
lines changed

luxonis_ml/data/datasets/luxonis_dataset.py

Lines changed: 19 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1294,10 +1294,12 @@ def make_splits(
         if definitions is not None:
             n_files = sum(map(len, definitions.values()))
             if n_files > len(self):
-                raise ValueError(
+                logger.warning(
                     "Dataset size is smaller than the total number of files in the definitions. "
-                    f"Dataset size: {len(self)}, Definitions: {n_files}."
+                    f"Dataset size: {len(self)}, Definitions: {n_files}. "
+                    "Duplicate files will be filtered out and extra files in definitions will be ignored."
                 )
+                self.remove_duplicates()

         splits_to_update: list[str] = []
         new_splits: dict[str, list[str]] = {}
@@ -1363,12 +1365,19 @@ def make_splits(
                     raise TypeError(
                         "Must provide splits as a list of filepaths"
                     )
-                ids = [
-                    find_filepath_group_id(
-                        filepath, index, raise_on_missing=True
+                ids: list[str] = []
+                for filepath in filepaths:
+                    group_id = find_filepath_group_id(
+                        filepath, index, raise_on_missing=False
                     )
-                    for filepath in filepaths
-                ]
+
+                    if group_id is None:
+                        logger.warning(
+                            f"No group ID found for '{filepath}' in definitions; skipping."
+                        )
+                        continue
+                    ids.append(group_id)
+
                 new_splits[split] = list(set(ids))

         for split, group_ids in new_splits.items():
@@ -1825,3 +1834,6 @@ def remove_duplicates(self) -> None:
             remote_dir="annotations",
             copy_contents=True,
         )
+        logger.info(
+            "Successfully removed duplicate files and annotations from the dataset."
+        )

luxonis_ml/data/utils/data_utils.py

Lines changed: 14 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -255,11 +255,22 @@ def warn_on_duplicates(df: pl.LazyFrame) -> None:
         logger.warning(
             f"UUID {item['uuid']} is the same for multiple files: {item['files']}"
         )
+        logger.warning(
+            "Duplicate files detected. "
+            "To clean them up, call `dataset.remove_duplicates()` "
+            "or run the CLI command `luxonis_ml data sanitize`."
+        )

-    for item in duplicates_info["duplicate_annotations"]:
+    if duplicates_info["duplicate_annotations"]:
+        for item in duplicates_info["duplicate_annotations"]:
+            logger.warning(
+                f"File '{item['file_name']}' of task '{item['task_name']}' has the "
+                f"same '{item['task_type']}' annotation '{item['annotation']}' repeated {item['count']} times."
+            )
         logger.warning(
-            f"File '{item['file_name']}' of task '{item['task_name']}' has the "
-            f"same '{item['task_type']}' annotation '{item['annotation']}' repeated {item['count']} times."
+            "Duplicate annotations detected. "
+            "To clean them up, call `dataset.remove_duplicates()` "
+            "or run the CLI command `luxonis_ml data sanitize`."
         )

tests/test_data/test_dataset.py

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -271,11 +271,6 @@ def generator(step: int = 15) -> DatasetIterator:
         with pytest.raises(ValueError, match="Ratios must sum to 1.0"):
             dataset.make_splits({"train": 1.5})

-        with pytest.raises(ValueError, match="Dataset size is smaller than"):
-            dataset.make_splits(
-                {split: defs * 2 for split, defs in splits.items()}
-            )
-
         dataset.add(generator(10))
         dataset.make_splits({"custom_split": 1.0})
         splits = dataset.get_splits()

0 commit comments

Comments (0)