Commit ae67e94

improved duplicate filtering

1 parent dc75e61

File tree

1 file changed: +52 -15


luxonis_ml/data/datasets/luxonis_dataset.py

Lines changed: 52 additions & 15 deletions
@@ -199,14 +199,30 @@ def _load_df_offline(
         dfs = [pl.read_parquet(file) for file in path.glob("*.parquet")]
         return pl.concat(dfs) if dfs else None
 
-    def _get_file_index(self) -> Optional[pl.DataFrame]:
+    @overload
+    def _get_file_index(
+        self, lazy: Literal[False] = ...
+    ) -> Optional[pl.DataFrame]: ...
+
+    @overload
+    def _get_file_index(
+        self, lazy: Literal[True] = ...
+    ) -> Optional[pl.LazyFrame]: ...
+
+    def _get_file_index(
+        self, lazy: bool = False
+    ) -> Optional[Union[pl.DataFrame, pl.LazyFrame]]:
         path = get_file(
             self.fs, "metadata/file_index.parquet", self.media_path
         )
         if path is not None and path.exists():
-            return pl.read_parquet(path).select(
-                pl.all().exclude("^__index_level_.*$")
-            )
+            if not lazy:
+                df = pl.read_parquet(path)
+            else:
+                df = pl.scan_parquet(path)
+
+            return df.select(pl.all().exclude("^__index_level_.*$"))
+
         return None
 
     def _write_index(
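Note: the new `lazy` flag on `_get_file_index` is typed with `@overload` and `typing.Literal`, so a type checker resolves the return type (`DataFrame` vs. `LazyFrame`) from the argument instead of forcing callers to narrow a union. A minimal sketch of the same pattern on a free function; `load_index` and the sample file are illustrative, not part of luxonis-ml:

    from typing import Literal, Optional, Union, overload

    import polars as pl


    @overload
    def load_index(path: str, lazy: Literal[False] = ...) -> Optional[pl.DataFrame]: ...
    @overload
    def load_index(path: str, lazy: Literal[True] = ...) -> Optional[pl.LazyFrame]: ...


    def load_index(
        path: str, lazy: bool = False
    ) -> Optional[Union[pl.DataFrame, pl.LazyFrame]]:
        # scan_parquet builds a deferred query plan; read_parquet loads eagerly.
        return pl.scan_parquet(path) if lazy else pl.read_parquet(path)


    pl.DataFrame({"uuid": ["a"], "file": ["img.png"]}).write_parquet("index.parquet")
    df = load_index("index.parquet")             # checker infers Optional[pl.DataFrame]
    lf = load_index("index.parquet", lazy=True)  # checker infers Optional[pl.LazyFrame]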
@@ -514,7 +530,9 @@ def add(
 
         batch_data: list[DatasetRecord] = []
 
-        classes_per_task: Dict[str, OrderedSet[str]] = defaultdict(OrderedSet)
+        classes_per_task: Dict[str, OrderedSet[str]] = defaultdict(
+            lambda: OrderedSet([])
+        )
         num_kpts_per_task: Dict[str, int] = {}
 
         annotations_path = get_dir(
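Note: `defaultdict` always calls its default factory with zero arguments. Swapping `defaultdict(OrderedSet)` for `defaultdict(lambda: OrderedSet([]))` keeps the factory zero-argument even if the `OrderedSet` constructor in use expects an iterable; whether that constructor actually rejects a no-argument call is an assumption here, and the lambda is harmless either way. A sketch with a hypothetical stand-in class:

    from collections import defaultdict


    class NeedsIterable:
        """Hypothetical stand-in for a set type whose __init__ requires an iterable."""

        def __init__(self, iterable):
            self.items = list(iterable)


    # defaultdict(NeedsIterable)["detection"] would raise TypeError, because
    # the factory is invoked as NeedsIterable() with no arguments.
    classes = defaultdict(lambda: NeedsIterable([]))
    classes["detection"].items.append("car")  # the lambda supplied the empty []
    print(classes["detection"].items)  # ['car']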
@@ -584,36 +602,55 @@ def add(
 
     def _warn_on_duplicates(self) -> None:
         df = self._load_df_offline(lazy=True)
-        if df is None:
+        index_df = self._get_file_index(lazy=True)
+        if df is None or index_df is None:
             return
+        df = df.join(index_df, on="uuid").drop("file_right")
         # Warn on duplicate UUIDs
         duplicates_paired = (
             df.group_by("uuid")
             .agg(pl.col("file").n_unique().alias("file_count"))
             .filter(pl.col("file_count") > 1)
             .join(df, on="uuid")
-            .select(["uuid", "file"])
+            .select("uuid", "file")
             .unique()
             .group_by("uuid")
-            .agg([pl.col("file").alias("files")])
+            .agg(pl.col("file").alias("files"))
             .filter(pl.col("files").len() > 1)
+            .collect()
         )
-        duplicates_paired_df = duplicates_paired.collect()
-        for uuid, files in duplicates_paired_df.iter_rows():
+        for uuid, files in duplicates_paired.iter_rows():
             self.logger.warning(
                 f"UUID: {uuid} has multiple file names: {files}"
             )
 
         # Warn on duplicate annotations
         duplicate_annotation = (
-            df.group_by(["file", "annotation"])
+            df.group_by(
+                "original_filepath",
+                "task",
+                "type",
+                "annotation",
+                "instance_id",
+            )
             .agg(pl.len().alias("count"))
             .filter(pl.col("count") > 1)
-        )
-        duplicate_annotation_df = duplicate_annotation.collect()
-        for file_name, annotation, _ in duplicate_annotation_df.iter_rows():
+            .filter(pl.col("annotation") != "{}")
+            .drop("instance_id")
+        ).collect()
+
+        for (
+            file_name,
+            task,
+            type_,
+            annotation,
+            count,
+        ) in duplicate_annotation.iter_rows():
+            if "RLE" in type_ or "Mask" in type_:
+                annotation = "<binary mask>"
             self.logger.warning(
-                f"File '{file_name}' has the same annotation '{annotation}' added multiple times."
+                f"File '{file_name}' has the same '{type_}' annotation "
+                f"'{annotation}' ({task=}) added {count} times."
             )
 
     def get_splits(self) -> Optional[Dict[str, List[str]]]:
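Note: after joining the annotation frame with the file index, both warning queries stay lazy end to end and `collect()` exactly once. A self-contained sketch of the duplicate-UUID query on invented data (column names mirror the diff; the values and file names are made up):

    import polars as pl

    lf = pl.LazyFrame(
        {
            "uuid": ["a", "a", "b", "b"],
            "file": ["img1.png", "img1_copy.png", "img2.png", "img2.png"],
        }
    )

    duplicates = (
        lf.group_by("uuid")
        .agg(pl.col("file").n_unique().alias("file_count"))
        .filter(pl.col("file_count") > 1)    # uuid stored under more than one name
        .join(lf, on="uuid")
        .select("uuid", "file")
        .unique()
        .group_by("uuid")
        .agg(pl.col("file").alias("files"))  # gather the conflicting names per uuid
        .collect()                           # materialize the whole lazy plan once
    )
    print(duplicates)  # one row: uuid "a" with files ["img1.png", "img1_copy.png"]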
