60 changes: 58 additions & 2 deletions luxonis_ml/data/__main__.py
@@ -832,6 +832,19 @@ def clone(
),
] = True,
bucket_storage: BucketStorage = bucket_option,
splits_to_clone: Annotated[
str | None,
typer.Option(
"--split",
"-s",
help=(
"Comma separated list of split names to clone, "
"e.g. `-s val,test` or just `-s train`. "
"If omitted, clones all splits."
),
show_default=False,
),
] = None,
team_id: Annotated[
str | None,
typer.Option(
@@ -856,10 +869,20 @@ def clone(
):
raise typer.Exit

if splits_to_clone:
split_list = [
s.strip() for s in splits_to_clone.split(",") if s.strip()
]
else:
split_list = None

print(f"Cloning dataset '{name}' to '{new_name}'...")
dataset = LuxonisDataset(name, bucket_storage=bucket_storage)
dataset.clone(
-new_dataset_name=new_name, push_to_cloud=push_to_cloud, team_id=team_id
+new_dataset_name=new_name,
+push_to_cloud=push_to_cloud,
+splits_to_clone=split_list,
+team_id=team_id,
)
print(f"[green]Dataset '{name}' successfully cloned to '{new_name}'.")

@@ -892,7 +915,29 @@ def merge(
show_default=False,
),
] = None,
splits_to_merge: Annotated[
str | None,
typer.Option(
"--split",
"-s",
help=(
"Comma separated list of split names to merge, "
"e.g. `-s val,test` or just `-s train`. "
"If omitted, merges all splits."
),
show_default=False,
),
] = None,
bucket_storage: BucketStorage = bucket_option,
team_id: Annotated[
str | None,
typer.Option(
"--team-id",
"-t",
help="Team ID to use for the new dataset. If not provided, the dataset's current team ID will be used.",
show_default=False,
),
] = None,
):
"""Merge two datasets stored in the same type of bucket."""
check_exists(source_name, bucket_storage)
@@ -913,6 +958,13 @@ def merge(
):
raise typer.Exit

if splits_to_merge:
split_list = [
s.strip() for s in splits_to_merge.split(",") if s.strip()
]
else:
split_list = None

source_dataset = LuxonisDataset(source_name, bucket_storage=bucket_storage)
target_dataset = LuxonisDataset(target_name, bucket_storage=bucket_storage)

@@ -922,7 +974,11 @@
)

_ = target_dataset.merge_with(
-source_dataset, inplace=inplace, new_dataset_name=new_name
+source_dataset,
+inplace=inplace,
+new_dataset_name=new_name,
+splits_to_merge=split_list,
+team_id=team_id,
)

if inplace:
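Analogously, a split-restricted merge from the command line should look roughly like `luxonis_ml data merge dataset_a dataset_b -s train`, optionally with `-t <team-id>` when the result is written to a new dataset; again the entry point, argument order, and dataset names are assumptions for illustration.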
81 changes: 74 additions & 7 deletions luxonis_ml/data/datasets/luxonis_dataset.py
@@ -7,6 +7,7 @@
from concurrent.futures import ThreadPoolExecutor
from contextlib import suppress
from functools import cached_property
from os import PathLike
from pathlib import Path, PurePosixPath
from typing import Any, Literal, overload

@@ -273,6 +274,7 @@ def clone(
self,
new_dataset_name: str,
push_to_cloud: bool = True,
splits_to_clone: list[str] | None = None,
team_id: str | None = None,
) -> "LuxonisDataset":
"""Create a new LuxonisDataset that is a local copy of the
@@ -284,6 +286,11 @@ def clone(
@type push_to_cloud: bool
@param push_to_cloud: Whether to push the new dataset to the
cloud. Only if the current dataset is remote.
@type splits_to_clone: list[str] | None
@param splits_to_clone: Optional list of split names to clone. If
None, all data will be cloned.
@type team_id: str | None
@param team_id: Optional team identifier.
"""
if team_id is None:
team_id = self.team_id
@@ -302,10 +309,33 @@

new_dataset_path = Path(new_dataset.local_path)
new_dataset_path.mkdir(parents=True, exist_ok=True)

if splits_to_clone is not None:
df_self = self._load_df_offline(raise_when_empty=True)
splits_self = self._load_splits(self.metadata_path)
uuids_to_clone = {
uid
for split in splits_to_clone
for uid in splits_self.get(split, [])
}
df_self = df_self.filter(df_self["uuid"].is_in(uuids_to_clone))
splits_self = {
k: v for k, v in splits_self.items() if k in splits_to_clone
}

shutil.copytree(
-self.local_path, new_dataset.local_path, dirs_exist_ok=True
+self.local_path,
+new_dataset.local_path,
+dirs_exist_ok=True,
+ignore=lambda d, n: self._ignore_files_not_in_uuid_set(
+d, n, uuids_to_clone if splits_to_clone else set()
+),
)

if splits_to_clone is not None:
new_dataset._save_df_offline(df_self)
new_dataset._save_splits(splits_self)

new_dataset._init_paths()
new_dataset._metadata = self._get_metadata()

@@ -331,6 +361,8 @@ def merge_with(
other: "LuxonisDataset",
inplace: bool = True,
new_dataset_name: str | None = None,
splits_to_merge: list[str] | None = None,
team_id: str | None = None,
) -> "LuxonisDataset":
"""Merge all data from `other` LuxonisDataset into the current
dataset (in-place or in a new dataset).
@@ -343,6 +375,10 @@
@type new_dataset_name: str
@param new_dataset_name: The name of the new dataset to create
if inplace is False.
@type splits_to_merge: list[str] | None
@param splits_to_merge: Optional list of split names to merge. If None, all splits are merged.
@type team_id: str | None
@param team_id: Optional team identifier.
"""
if inplace:
target_dataset = self
@@ -351,7 +387,9 @@
raise ValueError(
"Cannot merge datasets with different bucket storage types."
)
-target_dataset = self.clone(new_dataset_name, push_to_cloud=False)
+target_dataset = self.clone(
+new_dataset_name, push_to_cloud=False, team_id=team_id
+)
else:
raise ValueError(
"You must specify a name for the new dataset "
@@ -374,13 +412,22 @@
~df_other["uuid"].is_in(duplicate_uuids)
)

splits_self = self._load_splits(self.metadata_path)
splits_other = self._load_splits(other.metadata_path)
if splits_to_merge is not None:
uuids_to_merge = {
uuid
for split_name in splits_to_merge
for uuid in splits_other.get(split_name, [])
}
df_other = df_other.filter(df_other["uuid"].is_in(uuids_to_merge))
splits_other = {
k: v for k, v in splits_other.items() if k in splits_to_merge
}

df_merged = pl.concat([df_self, df_other])
target_dataset._save_df_offline(df_merged)

-splits_self = self._load_splits(self.metadata_path)
-splits_other = self._load_splits(
-other.metadata_path
-) # dict of split names to list of uuids
splits_other = {
split_name: [uuid for uuid in uuids if uuid not in duplicate_uuids]
for split_name, uuids in splits_other.items()
@@ -390,7 +437,12 @@

if self.is_remote:
shutil.copytree(
-other.media_path, target_dataset.media_path, dirs_exist_ok=True
+other.media_path,
+target_dataset.media_path,
+dirs_exist_ok=True,
+ignore=lambda d, n: self._ignore_files_not_in_uuid_set(
+d, n, uuids_to_merge if splits_to_merge else set()
+),
)
target_dataset.push_to_cloud(
bucket_storage=target_dataset.bucket_storage,
@@ -418,6 +470,21 @@ def _load_splits(self, path: Path) -> dict[str, list[str]]:
with open(splits_path) as f:
return json.load(f)

def _ignore_files_not_in_uuid_set(
self,
dir_path: PathLike[str] | str,
names: list[str],
uuids_to_keep: set[str],
) -> set[str]:
if not uuids_to_keep:
return set()
ignored: set[str] = set()
for name in names:
full = Path(dir_path) / name
if full.is_file() and full.stem not in uuids_to_keep:
ignored.add(name)
return ignored
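For context, `shutil.copytree` calls the `ignore` callable once for every directory it visits, passing that directory's path and the list of entry names, and skips any names the callable returns; that is how the helper above is hooked in from `clone` and `merge_with`. A standalone sketch of the same filtering idea, with made-up paths and UUIDs:

import shutil
from pathlib import Path

uuids_to_keep = {"1f2e3d4c", "5a6b7c8d"}  # hypothetical media UUIDs

shutil.copytree(
    "datasets/source/media",   # hypothetical source directory
    "datasets/clone/media",    # hypothetical destination
    dirs_exist_ok=True,
    ignore=lambda d, names: {
        n
        for n in names
        if (Path(d) / n).is_file() and (Path(d) / n).stem not in uuids_to_keep
    },
)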

def _merge_splits(
self,
splits_self: dict[str, list[str]],
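As with `clone`, the extended `merge_with` signature can be exercised directly; a brief sketch (import path and dataset names are assumptions for illustration, and `new_dataset_name` is required because `inplace=False`):

from luxonis_ml.data import LuxonisDataset

source = LuxonisDataset("dataset_a")
target = LuxonisDataset("dataset_b")

# Merge only the 'train' split of dataset_a into a brand-new dataset,
# leaving both originals untouched.
merged = target.merge_with(
    source,
    inplace=False,
    new_dataset_name="merged_train_only",
    splits_to_merge=["train"],
)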
117 changes: 117 additions & 0 deletions tests/test_data/test_dataset.py
@@ -767,6 +767,123 @@ def generator2() -> DatasetIterator:
)


@pytest.mark.dependency(name="test_dataset[BucketStorage.LOCAL]")
def test_merge_datasets_specific_split(
bucket_storage: BucketStorage,
dataset_name: str,
tempdir: Path,
):
dataset_name = f"{dataset_name}_{bucket_storage.value}"
dataset1_name = f"{dataset_name}_1"
dataset2_name = f"{dataset_name}_2"

def generator1() -> DatasetIterator:
for i in range(3):
img = create_image(i, tempdir)
yield {
"file": img,
"annotation": {
"class": "person",
"boundingbox": {"x": 0.1, "y": 0.1, "w": 0.1, "h": 0.1},
},
}

def generator2() -> DatasetIterator:
for i in range(3, 6):
img = create_image(i, tempdir)
yield {
"file": img,
"annotation": {
"class": "dog",
"boundingbox": {"x": 0.2, "y": 0.2, "w": 0.2, "h": 0.2},
},
}

dataset1 = create_dataset(
dataset1_name,
generator1(),
bucket_storage,
splits={"train": 0.6, "val": 0.4},
)

dataset2 = create_dataset(
dataset2_name,
generator2(),
bucket_storage,
splits={"train": 0.6, "val": 0.4},
)

merged_dataset = dataset1.merge_with(
dataset2,
inplace=False,
new_dataset_name=f"{dataset1_name}_{dataset2_name}_merged",
splits_to_merge=["train"],
)

merged_stats = merged_dataset.get_statistics()
assert {
(item["count"], item["class_name"])
for item in merged_stats["class_distributions"][""]["boundingbox"]
} == {(3, "person"), (2, "dog")}
merged_splits = merged_dataset.get_splits()
dataset1_splits = dataset1.get_splits()
dataset2_splits = dataset2.get_splits()
assert merged_splits is not None
assert dataset1_splits is not None
assert dataset2_splits is not None
assert set(merged_splits["train"]) == set(dataset1_splits["train"]) | set(
dataset2_splits["train"]
)
assert set(merged_splits["val"]) == set(dataset1_splits["val"])

dataset1.delete_dataset(delete_local=True, delete_remote=True)
dataset2.delete_dataset(delete_local=True, delete_remote=True)
merged_dataset.delete_dataset(delete_local=True, delete_remote=True)


@pytest.mark.dependency(name="test_dataset[BucketStorage.LOCAL]")
def test_clone_dataset_specific_split(
bucket_storage: BucketStorage,
dataset_name: str,
tempdir: Path,
):
def generator() -> DatasetIterator:
for i in range(3):
img = create_image(i, tempdir)
yield {
"file": img,
"annotation": {
"class": "person",
"boundingbox": {"x": 0.1, "y": 0.1, "w": 0.1, "h": 0.1},
},
}

dataset = create_dataset(
dataset_name,
generator(),
bucket_storage,
splits={"train": 0.6, "val": 0.4},
)
cloned_dataset = dataset.clone(
new_dataset_name=f"{dataset_name}_cloned",
splits_to_clone=["train"],
)
dataset_splits = dataset.get_splits()
cloned_splits = cloned_dataset.get_splits()
assert cloned_splits is not None
assert dataset_splits is not None
assert set(cloned_splits["train"]) == set(dataset_splits["train"])
assert "val" not in cloned_splits

cloned_stats = cloned_dataset.get_statistics()
assert {
(item["count"], item["class_name"])
for item in cloned_stats["class_distributions"][""]["boundingbox"]
} == {(2, "person")}

cloned_dataset.delete_dataset(delete_local=True, delete_remote=True)


@pytest.mark.dependency(name="test_dataset[BucketStorage.LOCAL]")
def test_classes_per_task(dataset_name: str, tempdir: Path):
def generator() -> DatasetIterator: