
Commit 92d4d38

Fix the dataset merge, export on Machine B (#327)

1 parent 6a7ac25 commit 92d4d38

4 files changed (+105 −9 lines)

luxonis_ml/data/__main__.py

Lines changed: 12 additions & 1 deletion

@@ -830,6 +830,15 @@ def clone(
         ),
     ] = True,
     bucket_storage: BucketStorage = bucket_option,
+    team_id: Annotated[
+        str | None,
+        typer.Option(
+            "--team-id",
+            "-t",
+            help="Team ID to use for the new dataset. If not provided, the dataset's current team ID will be used.",
+            show_default=False,
+        ),
+    ] = None,
 ):
     """Clone an existing dataset with a new name.

@@ -847,7 +856,9 @@ def clone(

     print(f"Cloning dataset '{name}' to '{new_name}'...")
     dataset = LuxonisDataset(name, bucket_storage=bucket_storage)
-    dataset.clone(new_dataset_name=new_name, push_to_cloud=push_to_cloud)
+    dataset.clone(
+        new_dataset_name=new_name, push_to_cloud=push_to_cloud, team_id=team_id
+    )
     print(f"[green]Dataset '{name}' successfully cloned to '{new_name}'.")

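The new `--team-id`/`-t` option simply threads a team identifier through to `LuxonisDataset.clone`. A minimal usage sketch, assuming the typer app in this file is exposed as the `luxonis_ml data` CLI; the dataset and team names below are made up:

# CLI form (hypothetical names):
#   luxonis_ml data clone parking_lot parking_lot_copy --team-id acme-team
# Programmatic equivalent via the method the command delegates to:
from luxonis_ml.data import LuxonisDataset

dataset = LuxonisDataset("parking_lot")
dataset.clone(
    new_dataset_name="parking_lot_copy",
    push_to_cloud=True,
    team_id="acme-team",  # omitted -> falls back to the source dataset's team
)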
luxonis_ml/data/datasets/luxonis_dataset.py

Lines changed: 31 additions & 6 deletions

@@ -1,6 +1,7 @@
 import json
 import math
 import shutil
+import sys
 from collections import defaultdict
 from collections.abc import Iterable, Mapping, Sequence
 from concurrent.futures import ThreadPoolExecutor
@@ -269,7 +270,10 @@ def _merge_metadata_with(self, other: "LuxonisDataset") -> None:
         self._write_metadata()

     def clone(
-        self, new_dataset_name: str, push_to_cloud: bool = True
+        self,
+        new_dataset_name: str,
+        push_to_cloud: bool = True,
+        team_id: str | None = None,
     ) -> "LuxonisDataset":
         """Create a new LuxonisDataset that is a local copy of the
         current dataset. Cloned dataset will overwrite the existing
@@ -281,10 +285,12 @@ def clone(
         @param push_to_cloud: Whether to push the new dataset to the
             cloud. Only if the current dataset is remote.
         """
+        if team_id is None:
+            team_id = self.team_id

         new_dataset = LuxonisDataset(
             dataset_name=new_dataset_name,
-            team_id=self.team_id,
+            team_id=team_id,
             bucket_type=self.bucket_type,
             bucket_storage=self.bucket_storage,
             delete_local=True,
@@ -391,6 +397,18 @@ def merge_with(
             update_mode=UpdateMode.MISSING,
         )

+        for entry in (
+            df_other.select(["uuid", "file"])
+            .unique(subset=["uuid"])
+            .to_dicts()
+        ):
+            uid, rel_file = entry["uuid"], entry["file"]
+            src_path = other.media_path / f"{uid}{Path(rel_file).suffix}"
+            dst_path = target_dataset.media_path / src_path.name
+            if src_path.exists() and not dst_path.exists():
+                dst_path.parent.mkdir(parents=True, exist_ok=True)
+                shutil.copy(src_path, dst_path)
+
         target_dataset._merge_metadata_with(other)

         return target_dataset
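This loop closes the "Machine B" gap: media pulled from the cloud is stored under `media_path` as `<uuid><original suffix>`, and before this change `merge_with` never copied those files into the target dataset. A self-contained sketch of the same copy step, assuming a polars DataFrame with `uuid` and `file` columns; the function name and paths are illustrative:

# Minimal sketch of the media-copy step, outside the class.
import shutil
from pathlib import Path

import polars as pl

def copy_media(df_other: pl.DataFrame, src_media: Path, dst_media: Path) -> None:
    for entry in df_other.select(["uuid", "file"]).unique(subset=["uuid"]).to_dicts():
        # Media is stored as "<uuid><original suffix>", e.g. "ab12cd34.jpg".
        src = src_media / f"{entry['uuid']}{Path(entry['file']).suffix}"
        dst = dst_media / src.name
        if src.exists() and not dst.exists():
            dst.parent.mkdir(parents=True, exist_ok=True)
            shutil.copy(src, dst)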
@@ -1422,12 +1440,11 @@ def _dump_annotations(
             description="Exporting ...",
         ):
             uuid = row[7]
-            if self.is_remote:
+            file = Path(row[-1])
+            if self.is_remote or not file.exists():
                 file_extension = row[0].rsplit(".", 1)[-1]
                 file = self.media_path / f"{uuid}.{file_extension}"
                 assert file.exists()
-            else:
-                file = Path(row[-1])

             split = None
             for s, uuids in splits.items():
@@ -1445,9 +1462,13 @@ def _dump_annotations(

             if file not in image_indices:
                 file_size = file.stat().st_size
+                annotations_size = sum(
+                    sys.getsizeof(lst) for lst in annotations.values()
+                )
                 if (
                     max_partition_size
-                    and current_size + file_size > max_partition_size
+                    and current_size + file_size + annotations_size
+                    > max_partition_size
                 ):
                     _dump_annotations(
                         annotations, output_path, self.identifier, part
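The partition check now budgets the in-memory annotation buffer alongside the image bytes, so a partition is flushed before it overshoots `max_partition_size`. Worth noting that `sys.getsizeof` on a list is shallow, so the estimate is a coarse lower bound; a small illustration with made-up records:

import sys

records = [{"file": f"img_{i}.jpg"} for i in range(1000)]
print(sys.getsizeof(records))                   # size of the list object only
print(sum(sys.getsizeof(r) for r in records))   # the contained dicts are far larger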
@@ -1511,6 +1532,10 @@ def _dump_annotations(
                 record["annotation"][task_type] = data
                 annotations[split].append(record)

+            elif task_type == "metadata/text":
+                record["annotation"]["metadata"] = {"text": data}
+                annotations[split].append(record)
+
         _dump_annotations(annotations, output_path, self.identifier, part)

         if zip_output:
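The new `metadata/text` branch keeps free-text metadata from being silently dropped during export. A sketch of the record shape it produces; the surrounding fields are inferred from this diff, and the values are made up:

record = {
    "file": "ab12cd34.jpg",  # hypothetical; other record fields omitted
    "annotation": {
        "class": "person",
        "metadata": {"text": "captured at night"},  # written by the new branch
    },
}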

tests/conftest.py

Lines changed: 4 additions & 2 deletions

@@ -139,7 +139,7 @@ def base_tempdir(worker_id: str):


 @pytest.fixture
-def tempdir(base_tempdir: Path, randint: int) -> Path:
+def tempdir(base_tempdir: Path, randint: int) -> Generator[Path, None, None]:
     t = time.time()
     unique_id = randint
     while True:
@@ -155,7 +155,9 @@ def tempdir(base_tempdir: Path, randint: int) -> Path:

     path.mkdir(exist_ok=True)

-    return path
+    yield path
+
+    shutil.rmtree(path, ignore_errors=True)


 @pytest.fixture(scope="session")
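Switching the fixture from `return` to `yield` means the code after the `yield` runs at teardown, so each test's directory is removed even when the test fails. The general pattern, as a minimal standalone sketch (`workdir` is a made-up fixture; `tmp_path` is pytest's built-in):

import shutil
from collections.abc import Generator
from pathlib import Path

import pytest

@pytest.fixture
def workdir(tmp_path: Path) -> Generator[Path, None, None]:
    path = tmp_path / "work"
    path.mkdir(exist_ok=True)
    yield path                               # the test runs here
    shutil.rmtree(path, ignore_errors=True)  # teardown, even on test failure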

tests/test_data/test_dataset.py

Lines changed: 58 additions & 0 deletions

@@ -929,3 +929,61 @@ def generator(start: int, end: int) -> DatasetIterator:
     loader = LuxonisLoader(cloud_dataset)
     assert sum(1 for _ in loader) == 3
     assert cloud_dataset.get_statistics() == original_stats
+
+
+@pytest.mark.dependency(name="test_dataset[BucketStorage.LOCAL]")
+def test_merge_on_different_machines(dataset_name: str, tempdir: Path):
+    def generator(start: int, end: int) -> DatasetIterator:
+        """Generate sample dataset items with bounding boxes."""
+        for i in range(start, end):
+            img = create_image(i, tempdir)
+            yield {
+                "file": img,
+                "annotation": {
+                    "class": "person",
+                    "boundingbox": {"x": 0.1, "y": 0.1, "w": 0.1, "h": 0.1},
+                    "instance_id": i,
+                },
+            }
+
+    dataset1 = create_dataset(
+        dataset_name + "_1",
+        generator(0, 3),
+        bucket_storage=BucketStorage.GCS,
+        delete_local=True,
+        delete_remote=True,
+        splits=(1, 0, 0),
+    )
+    dataset2 = create_dataset(
+        dataset_name + "_2",
+        generator(3, 6),
+        bucket_storage=BucketStorage.GCS,
+        delete_local=True,
+        delete_remote=True,
+        splits=(1, 0, 0),
+    )
+    shutil.rmtree(tempdir)
+    dataset1.pull_from_cloud()
+    dataset2.pull_from_cloud()
+    dataset1.delete_dataset(delete_remote=True)
+    dataset2.delete_dataset(delete_remote=True)
+    dataset1 = LuxonisDataset(dataset_name + "_1")
+    dataset2 = LuxonisDataset(dataset_name + "_2")
+    assert len(list(dataset1.media_path.glob("*"))) == 3
+    assert len(list(dataset2.media_path.glob("*"))) == 3
+    dataset3 = dataset1.merge_with(
+        dataset2, inplace=False, new_dataset_name=dataset_name
+    )
+    loader = LuxonisLoader(dataset3)
+    assert sum(1 for _ in loader) == 6
+    dataset3.export(tempdir)
+    assert (
+        len(
+            list(
+                Path.cwd().glob(
+                    f"{tempdir}/{dataset3.dataset_name}/train/images/*"
+                )
+            )
+        )
+        == 6
+    )
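End to end, the test recreates the two-machine scenario from the commit title: build and push two datasets, wipe the local state, pull them back as if on another machine, merge, and export. A condensed sketch of that flow with made-up names (a configured GCS bucket is assumed):

from luxonis_ml.data import LuxonisDataset

a = LuxonisDataset("site_a")  # local copies rebuilt from pulled media
b = LuxonisDataset("site_b")
merged = a.merge_with(b, inplace=False, new_dataset_name="site_merged")
merged.export("exported/")    # images land under <name>/train/images/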
