26 changes: 19 additions & 7 deletions luxonis_ml/data/datasets/luxonis_dataset.py
@@ -1294,10 +1294,12 @@ def make_splits(
         if definitions is not None:
             n_files = sum(map(len, definitions.values()))
             if n_files > len(self):
-                raise ValueError(
+                logger.warning(
                     "Dataset size is smaller than the total number of files in the definitions. "
-                    f"Dataset size: {len(self)}, Definitions: {n_files}."
+                    f"Dataset size: {len(self)}, Definitions: {n_files}. "
+                    "Duplicate files will be filtered out and extra files in definitions will be ignored."
                 )
+                self.remove_duplicates()
 
         splits_to_update: list[str] = []
         new_splits: dict[str, list[str]] = {}
@@ -1363,12 +1365,19 @@ def make_splits(
                         raise TypeError(
                             "Must provide splits as a list of filepaths"
                         )
-                    ids = [
-                        find_filepath_group_id(
-                            filepath, index, raise_on_missing=True
+                    ids: list[str] = []
+                    for filepath in filepaths:
+                        group_id = find_filepath_group_id(
+                            filepath, index, raise_on_missing=False
                         )
-                        for filepath in filepaths
-                    ]
+
+                        if group_id is None:
+                            logger.warning(
+                                f"No group ID found for '{filepath}' in definitions; skipping."
+                            )
+                            continue
+                        ids.append(group_id)
+
                 new_splits[split] = list(set(ids))
 
         for split, group_ids in new_splits.items():
@@ -1825,3 +1834,6 @@ def remove_duplicates(self) -> None:
             remote_dir="annotations",
             copy_contents=True,
         )
+        logger.info(
+            "Successfully removed duplicate files and annotations from the dataset."
+        )
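
Taken together, the two make_splits hunks and the new remove_duplicates log line change the failure mode: definitions that list more files than the dataset contains now log a warning, trigger deduplication, and skip entries without a group ID instead of raising. A minimal usage sketch of the new behavior follows; the dataset name, file paths, and import path are hypothetical, and only the calls visible in this diff are assumed.

    from luxonis_ml.data import LuxonisDataset

    # Hypothetical dataset and file names, for illustration only.
    dataset = LuxonisDataset("example_dataset")

    # Before this change, passing more files than the dataset contains raised a
    # ValueError. Now it logs a warning, calls remove_duplicates() internally,
    # and ignores extra or unknown entries.
    dataset.make_splits(
        {
            "train": ["img_0.jpg", "img_1.jpg", "img_1.jpg"],  # duplicate on purpose
            "val": ["img_2.jpg", "not_in_dataset.jpg"],  # unknown file is skipped with a warning
        }
    )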
17 changes: 14 additions & 3 deletions luxonis_ml/data/utils/data_utils.py
@@ -255,11 +255,22 @@ def warn_on_duplicates(df: pl.LazyFrame) -> None:
             logger.warning(
                 f"UUID {item['uuid']} is the same for multiple files: {item['files']}"
             )
+        logger.warning(
+            "Duplicate files detected. "
+            "To clean them up, call `dataset.remove_duplicates()` "
+            "or run the CLI command `luxonis_ml data sanitize`."
+        )
 
-    for item in duplicates_info["duplicate_annotations"]:
+    if duplicates_info["duplicate_annotations"]:
+        for item in duplicates_info["duplicate_annotations"]:
+            logger.warning(
+                f"File '{item['file_name']}' of task '{item['task_name']}' has the "
+                f"same '{item['task_type']}' annotation '{item['annotation']}' repeated {item['count']} times."
+            )
         logger.warning(
-            f"File '{item['file_name']}' of task '{item['task_name']}' has the "
-            f"same '{item['task_type']}' annotation '{item['annotation']}' repeated {item['count']} times."
+            "Duplicate annotations detected. "
+            "To clean them up, call `dataset.remove_duplicates()` "
+            "or run the CLI command `luxonis_ml data sanitize`."
         )
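
Both new warnings point at the same cleanup paths. A short sketch of each is shown below, assuming only the API call and CLI command quoted verbatim in the warning text; the dataset name is a placeholder.

    from luxonis_ml.data import LuxonisDataset

    dataset = LuxonisDataset("example_dataset")  # hypothetical dataset name

    # Programmatic cleanup, as suggested by the warning message:
    dataset.remove_duplicates()

    # Or from the shell, again as quoted in the warning
    # (any required arguments are not shown in this diff):
    #   luxonis_ml data sanitize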


5 changes: 0 additions & 5 deletions tests/test_data/test_dataset.py
@@ -271,11 +271,6 @@ def generator(step: int = 15) -> DatasetIterator:
     with pytest.raises(ValueError, match="Ratios must sum to 1.0"):
         dataset.make_splits({"train": 1.5})
 
-    with pytest.raises(ValueError, match="Dataset size is smaller than"):
-        dataset.make_splits(
-            {split: defs * 2 for split, defs in splits.items()}
-        )
-
     dataset.add(generator(10))
     dataset.make_splits({"custom_split": 1.0})
     splits = dataset.get_splits()
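
The removed assertion expected a ValueError that make_splits no longer raises. If a replacement check were wanted, it could simply verify that the oversized-definitions call now succeeds. A rough sketch, reusing the `dataset` and `splits` locals from the surrounding test (hypothetical, not part of this PR):

    # Oversized definitions should now only log a warning, so the call is
    # expected to succeed rather than raise.
    dataset.make_splits({split: defs * 2 for split, defs in splits.items()})
    assert dataset.get_splits() is not None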