Commit 59c90bc

Clone/Merge Specified Splits (#342)
1 parent aee6bf9 commit 59c90bc

File tree

3 files changed (+249 −9 lines)


luxonis_ml/data/__main__.py

Lines changed: 58 additions & 2 deletions
@@ -819,6 +819,19 @@ def clone(
         ),
     ] = True,
     bucket_storage: BucketStorage = bucket_option,
+    splits_to_clone: Annotated[
+        str | None,
+        typer.Option(
+            "--split",
+            "-s",
+            help=(
+                "Comma separated list of split names to clone, "
+                "e.g. `-s val,test` or just `-s train`. "
+                "If omitted, clones all splits."
+            ),
+            show_default=False,
+        ),
+    ] = None,
     team_id: Annotated[
         str | None,
         typer.Option(
@@ -843,10 +856,20 @@ def clone(
     ):
         raise typer.Exit

+    if splits_to_clone:
+        split_list = [
+            s.strip() for s in splits_to_clone.split(",") if s.strip()
+        ]
+    else:
+        split_list = None
+
     print(f"Cloning dataset '{name}' to '{new_name}'...")
     dataset = LuxonisDataset(name, bucket_storage=bucket_storage)
     dataset.clone(
-        new_dataset_name=new_name, push_to_cloud=push_to_cloud, team_id=team_id
+        new_dataset_name=new_name,
+        push_to_cloud=push_to_cloud,
+        splits_to_clone=split_list,
+        team_id=team_id,
     )
     print(f"[green]Dataset '{name}' successfully cloned to '{new_name}'.")

@@ -879,7 +902,29 @@ def merge(
             show_default=False,
         ),
     ] = None,
+    splits_to_merge: Annotated[
+        str | None,
+        typer.Option(
+            "--split",
+            "-s",
+            help=(
+                "Comma separated list of split names to merge, "
+                "e.g. `-s val,test` or just `-s train`. "
+                "If omitted, merges all splits."
+            ),
+            show_default=False,
+        ),
+    ] = None,
     bucket_storage: BucketStorage = bucket_option,
+    team_id: Annotated[
+        str | None,
+        typer.Option(
+            "--team-id",
+            "-t",
+            help="Team ID to use for the new dataset. If not provided, the dataset's current team ID will be used.",
+            show_default=False,
+        ),
+    ] = None,
 ):
     """Merge two datasets stored in the same type of bucket."""
     check_exists(source_name, bucket_storage)
@@ -900,6 +945,13 @@ def merge(
     ):
         raise typer.Exit

+    if splits_to_merge:
+        split_list = [
+            s.strip() for s in splits_to_merge.split(",") if s.strip()
+        ]
+    else:
+        split_list = None
+
     source_dataset = LuxonisDataset(source_name, bucket_storage=bucket_storage)
     target_dataset = LuxonisDataset(target_name, bucket_storage=bucket_storage)

@@ -909,7 +961,11 @@ def merge(
     )

     _ = target_dataset.merge_with(
-        source_dataset, inplace=inplace, new_dataset_name=new_name
+        source_dataset,
+        inplace=inplace,
+        new_dataset_name=new_name,
+        splits_to_merge=split_list,
+        team_id=team_id,
     )

     if inplace:
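
For reference, Typer hands the new `--split`/`-s` option to the command as a single comma-separated string; both `clone` and `merge` then turn it into a list (or `None` when the option is omitted) before calling the dataset API. Below is a minimal standalone sketch of that parsing step; the helper name is illustrative and not part of the codebase:

```python
def parse_split_option(raw: str | None) -> list[str] | None:
    """Mimics the CLI parsing above: a missing or empty value means 'all splits'."""
    if not raw:
        return None
    return [s.strip() for s in raw.split(",") if s.strip()]


# `-s val,test` -> ["val", "test"]; `-s train` -> ["train"]; omitted -> None
assert parse_split_option("val, test") == ["val", "test"]
assert parse_split_option("train") == ["train"]
assert parse_split_option(None) is None
```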

luxonis_ml/data/datasets/luxonis_dataset.py

Lines changed: 74 additions & 7 deletions
@@ -7,6 +7,7 @@
 from concurrent.futures import ThreadPoolExecutor
 from contextlib import suppress
 from functools import cached_property
+from os import PathLike
 from pathlib import Path, PurePosixPath
 from typing import Any, Literal, overload

@@ -273,6 +274,7 @@ def clone(
         self,
         new_dataset_name: str,
         push_to_cloud: bool = True,
+        splits_to_clone: list[str] | None = None,
         team_id: str | None = None,
     ) -> "LuxonisDataset":
         """Create a new LuxonisDataset that is a local copy of the
@@ -284,6 +286,11 @@ def clone(
         @type push_to_cloud: bool
         @param push_to_cloud: Whether to push the new dataset to the
             cloud. Only if the current dataset is remote.
+        @type splits_to_clone: list[str] | None
+        @param splits_to_clone: Optional list of split names to clone.
+            If None, all data will be cloned.
+        @type team_id: str | None
+        @param team_id: Optional team identifier.
         """
         if team_id is None:
             team_id = self.team_id
@@ -302,10 +309,33 @@ def clone(

         new_dataset_path = Path(new_dataset.local_path)
         new_dataset_path.mkdir(parents=True, exist_ok=True)
+
+        if splits_to_clone is not None:
+            df_self = self._load_df_offline(raise_when_empty=True)
+            splits_self = self._load_splits(self.metadata_path)
+            uuids_to_clone = {
+                uid
+                for split in splits_to_clone
+                for uid in splits_self.get(split, [])
+            }
+            df_self = df_self.filter(df_self["uuid"].is_in(uuids_to_clone))
+            splits_self = {
+                k: v for k, v in splits_self.items() if k in splits_to_clone
+            }
+
         shutil.copytree(
-            self.local_path, new_dataset.local_path, dirs_exist_ok=True
+            self.local_path,
+            new_dataset.local_path,
+            dirs_exist_ok=True,
+            ignore=lambda d, n: self._ignore_files_not_in_uuid_set(
+                d, n, uuids_to_clone if splits_to_clone else set()
+            ),
         )

+        if splits_to_clone is not None:
+            new_dataset._save_df_offline(df_self)
+            new_dataset._save_splits(splits_self)
+
         new_dataset._init_paths()
         new_dataset._metadata = self._get_metadata()

@@ -331,6 +361,8 @@ def merge_with(
         other: "LuxonisDataset",
         inplace: bool = True,
         new_dataset_name: str | None = None,
+        splits_to_merge: list[str] | None = None,
+        team_id: str | None = None,
     ) -> "LuxonisDataset":
         """Merge all data from `other` LuxonisDataset into the current
         dataset (in-place or in a new dataset).
@@ -343,6 +375,10 @@ def merge_with(
         @type new_dataset_name: str
         @param new_dataset_name: The name of the new dataset to create
             if inplace is False.
+        @type splits_to_merge: list[str] | None
+        @param splits_to_merge: Optional list of split names to merge.
+        @type team_id: str | None
+        @param team_id: Optional team identifier.
         """
         if inplace:
             target_dataset = self
@@ -351,7 +387,9 @@ def merge_with(
                 raise ValueError(
                     "Cannot merge datasets with different bucket storage types."
                 )
-            target_dataset = self.clone(new_dataset_name, push_to_cloud=False)
+            target_dataset = self.clone(
+                new_dataset_name, push_to_cloud=False, team_id=team_id
+            )
         else:
             raise ValueError(
                 "You must specify a name for the new dataset "
@@ -374,13 +412,22 @@ def merge_with(
             ~df_other["uuid"].is_in(duplicate_uuids)
         )

+        splits_self = self._load_splits(self.metadata_path)
+        splits_other = self._load_splits(other.metadata_path)
+        if splits_to_merge is not None:
+            uuids_to_merge = {
+                uuid
+                for split_name in splits_to_merge
+                for uuid in splits_other.get(split_name, [])
+            }
+            df_other = df_other.filter(df_other["uuid"].is_in(uuids_to_merge))
+            splits_other = {
+                k: v for k, v in splits_other.items() if k in splits_to_merge
+            }
+
         df_merged = pl.concat([df_self, df_other])
         target_dataset._save_df_offline(df_merged)

-        splits_self = self._load_splits(self.metadata_path)
-        splits_other = self._load_splits(
-            other.metadata_path
-        )  # dict of split names to list of uuids
         splits_other = {
             split_name: [uuid for uuid in uuids if uuid not in duplicate_uuids]
             for split_name, uuids in splits_other.items()
@@ -390,7 +437,12 @@ def merge_with(

         if self.is_remote:
             shutil.copytree(
-                other.media_path, target_dataset.media_path, dirs_exist_ok=True
+                other.media_path,
+                target_dataset.media_path,
+                dirs_exist_ok=True,
+                ignore=lambda d, n: self._ignore_files_not_in_uuid_set(
+                    d, n, uuids_to_merge if splits_to_merge else set()
+                ),
             )
             target_dataset.push_to_cloud(
                 bucket_storage=target_dataset.bucket_storage,
@@ -418,6 +470,21 @@ def _load_splits(self, path: Path) -> dict[str, list[str]]:
         with open(splits_path) as f:
             return json.load(f)

+    def _ignore_files_not_in_uuid_set(
+        self,
+        dir_path: PathLike[str] | str,
+        names: list[str],
+        uuids_to_keep: set[str],
+    ) -> set[str]:
+        if not uuids_to_keep:
+            return set()
+        ignored: set[str] = set()
+        for name in names:
+            full = Path(dir_path) / name
+            if full.is_file() and full.stem not in uuids_to_keep:
+                ignored.add(name)
+        return ignored
+
     def _merge_splits(
         self,
         splits_self: dict[str, list[str]],
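
The split filtering above relies on the `ignore` hook of `shutil.copytree`: the callback is invoked once per visited directory with that directory's entry names, and any names it returns are skipped. A self-contained sketch of the same pattern follows; the directory layout and UUIDs are made up for illustration and are not taken from the library:

```python
import shutil
import tempfile
from pathlib import Path


def ignore_files_not_in_uuid_set(dir_path, names, uuids_to_keep):
    """Return the entries that copytree should skip in this directory."""
    if not uuids_to_keep:  # an empty allow-set means "keep everything"
        return set()
    ignored = set()
    for name in names:
        full = Path(dir_path) / name
        # Directories are never skipped; a file is kept only if its stem
        # (the UUID part of the file name) is in the allow-set.
        if full.is_file() and full.stem not in uuids_to_keep:
            ignored.add(name)
    return ignored


with tempfile.TemporaryDirectory() as tmp:
    src, dst = Path(tmp) / "src", Path(tmp) / "dst"
    src.mkdir()
    for uuid in ("aaa", "bbb", "ccc"):
        (src / f"{uuid}.png").touch()

    keep = {"aaa", "ccc"}
    shutil.copytree(
        src, dst, ignore=lambda d, n: ignore_files_not_in_uuid_set(d, n, keep)
    )
    assert sorted(p.name for p in dst.iterdir()) == ["aaa.png", "ccc.png"]
```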

tests/test_data/test_dataset.py

Lines changed: 117 additions & 0 deletions
@@ -770,6 +770,123 @@ def generator2() -> DatasetIterator:
     )


+@pytest.mark.dependency(name="test_dataset[BucketStorage.LOCAL]")
+def test_merge_datasets_specific_split(
+    bucket_storage: BucketStorage,
+    dataset_name: str,
+    tempdir: Path,
+):
+    dataset_name = f"{dataset_name}_{bucket_storage.value}"
+    dataset1_name = f"{dataset_name}_1"
+    dataset2_name = f"{dataset_name}_2"
+
+    def generator1() -> DatasetIterator:
+        for i in range(3):
+            img = create_image(i, tempdir)
+            yield {
+                "file": img,
+                "annotation": {
+                    "class": "person",
+                    "boundingbox": {"x": 0.1, "y": 0.1, "w": 0.1, "h": 0.1},
+                },
+            }
+
+    def generator2() -> DatasetIterator:
+        for i in range(3, 6):
+            img = create_image(i, tempdir)
+            yield {
+                "file": img,
+                "annotation": {
+                    "class": "dog",
+                    "boundingbox": {"x": 0.2, "y": 0.2, "w": 0.2, "h": 0.2},
+                },
+            }
+
+    dataset1 = create_dataset(
+        dataset1_name,
+        generator1(),
+        bucket_storage,
+        splits={"train": 0.6, "val": 0.4},
+    )
+
+    dataset2 = create_dataset(
+        dataset2_name,
+        generator2(),
+        bucket_storage,
+        splits={"train": 0.6, "val": 0.4},
+    )
+
+    merged_dataset = dataset1.merge_with(
+        dataset2,
+        inplace=False,
+        new_dataset_name=f"{dataset1_name}_{dataset2_name}_merged",
+        splits_to_merge=["train"],
+    )
+
+    merged_stats = merged_dataset.get_statistics()
+    assert {
+        (item["count"], item["class_name"])
+        for item in merged_stats["class_distributions"][""]["boundingbox"]
+    } == {(3, "person"), (2, "dog")}
+    merged_splits = merged_dataset.get_splits()
+    dataset1_splits = dataset1.get_splits()
+    dataset2_splits = dataset2.get_splits()
+    assert merged_splits is not None
+    assert dataset1_splits is not None
+    assert dataset2_splits is not None
+    assert set(merged_splits["train"]) == set(dataset1_splits["train"]) | set(
+        dataset2_splits["train"]
+    )
+    assert set(merged_splits["val"]) == set(dataset1_splits["val"])
+
+    dataset1.delete_dataset(delete_local=True, delete_remote=True)
+    dataset2.delete_dataset(delete_local=True, delete_remote=True)
+    merged_dataset.delete_dataset(delete_local=True, delete_remote=True)
+
+
+@pytest.mark.dependency(name="test_dataset[BucketStorage.LOCAL]")
+def test_clone_dataset_specific_split(
+    bucket_storage: BucketStorage,
+    dataset_name: str,
+    tempdir: Path,
+):
+    def generator() -> DatasetIterator:
+        for i in range(3):
+            img = create_image(i, tempdir)
+            yield {
+                "file": img,
+                "annotation": {
+                    "class": "person",
+                    "boundingbox": {"x": 0.1, "y": 0.1, "w": 0.1, "h": 0.1},
+                },
+            }
+
+    dataset = create_dataset(
+        dataset_name,
+        generator(),
+        bucket_storage,
+        splits={"train": 0.6, "val": 0.4},
+    )
+    cloned_dataset = dataset.clone(
+        new_dataset_name=f"{dataset_name}_cloned",
+        splits_to_clone=["train"],
+    )
+    dataset_splits = dataset.get_splits()
+    cloned_splits = cloned_dataset.get_splits()
+    assert cloned_splits is not None
+    assert dataset_splits is not None
+    assert set(cloned_splits["train"]) == set(dataset_splits["train"])
+    assert "val" not in cloned_splits
+
+    cloned_stats = cloned_dataset.get_statistics()
+    assert {
+        (item["count"], item["class_name"])
+        for item in cloned_stats["class_distributions"][""]["boundingbox"]
+    } == {(2, "person")}
+
+    cloned_dataset.delete_dataset(delete_local=True, delete_remote=True)
+
+
 @pytest.mark.dependency(name="test_dataset[BucketStorage.LOCAL]")
 def test_classes_per_task(dataset_name: str, tempdir: Path):
     def generator() -> DatasetIterator:
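
Taken together, the tests above exercise the same calls a user would make directly. A condensed, hypothetical usage sketch follows; dataset names are placeholders, both datasets are assumed to already exist locally, and the import path follows the package's public API:

```python
from luxonis_ml.data import LuxonisDataset

dataset = LuxonisDataset("my_dataset")  # placeholder name

# Clone only the "train" split into a new local dataset.
train_only = dataset.clone(
    new_dataset_name="my_dataset_train_only",
    splits_to_clone=["train"],
)

# Merge only the "train" split of another dataset into a new dataset.
other = LuxonisDataset("other_dataset")  # placeholder name
merged = dataset.merge_with(
    other,
    inplace=False,
    new_dataset_name="my_dataset_plus_other_train",
    splits_to_merge=["train"],
)
```

On the CLI, the same behaviour is exposed through the new `--split`/`-s` option on the `clone` and `merge` commands.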
