
Commit 061fe61

add tests
1 parent 6f37f5f commit 061fe61

1 file changed: +117 −0 lines changed


tests/test_data/test_dataset.py

Lines changed: 117 additions & 0 deletions
@@ -767,6 +767,123 @@ def generator2() -> DatasetIterator:
 )


+@pytest.mark.dependency(name="test_dataset[BucketStorage.LOCAL]")
+def test_merge_datasets_specific_split(
+    bucket_storage: BucketStorage,
+    dataset_name: str,
+    tempdir: Path,
+):
+    dataset_name = f"{dataset_name}_{bucket_storage.value}"
+    dataset1_name = f"{dataset_name}_1"
+    dataset2_name = f"{dataset_name}_2"
+
+    def generator1() -> DatasetIterator:
+        for i in range(3):
+            img = create_image(i, tempdir)
+            yield {
+                "file": img,
+                "annotation": {
+                    "class": "person",
+                    "boundingbox": {"x": 0.1, "y": 0.1, "w": 0.1, "h": 0.1},
+                },
+            }
+
+    def generator2() -> DatasetIterator:
+        for i in range(3, 6):
+            img = create_image(i, tempdir)
+            yield {
+                "file": img,
+                "annotation": {
+                    "class": "dog",
+                    "boundingbox": {"x": 0.2, "y": 0.2, "w": 0.2, "h": 0.2},
+                },
+            }
+
+    dataset1 = create_dataset(
+        dataset1_name,
+        generator1(),
+        bucket_storage,
+        splits={"train": 0.6, "val": 0.4},
+    )
+
+    dataset2 = create_dataset(
+        dataset2_name,
+        generator2(),
+        bucket_storage,
+        splits={"train": 0.6, "val": 0.4},
+    )
+
+    merged_dataset = dataset1.merge_with(
+        dataset2,
+        inplace=False,
+        new_dataset_name=f"{dataset1_name}_{dataset2_name}_merged",
+        splits_to_merge=["train"],
+    )
+
+    merged_stats = merged_dataset.get_statistics()
+    assert {
+        (item["count"], item["class_name"])
+        for item in merged_stats["class_distributions"][""]["boundingbox"]
+    } == {(3, "person"), (2, "dog")}
+    merged_splits = merged_dataset.get_splits()
+    dataset1_splits = dataset1.get_splits()
+    dataset2_splits = dataset2.get_splits()
+    assert merged_splits is not None
+    assert dataset1_splits is not None
+    assert dataset2_splits is not None
+    assert set(merged_splits["train"]) == set(dataset1_splits["train"]) | set(
+        dataset2_splits["train"]
+    )
+    assert set(merged_splits["val"]) == set(dataset1_splits["val"])
+
+    dataset1.delete_dataset(delete_local=True, delete_remote=True)
+    dataset2.delete_dataset(delete_local=True, delete_remote=True)
+    merged_dataset.delete_dataset(delete_local=True, delete_remote=True)
+
+
+@pytest.mark.dependency(name="test_dataset[BucketStorage.LOCAL]")
+def test_clone_dataset_specific_split(
+    bucket_storage: BucketStorage,
+    dataset_name: str,
+    tempdir: Path,
+):
+    def generator() -> DatasetIterator:
+        for i in range(3):
+            img = create_image(i, tempdir)
+            yield {
+                "file": img,
+                "annotation": {
+                    "class": "person",
+                    "boundingbox": {"x": 0.1, "y": 0.1, "w": 0.1, "h": 0.1},
+                },
+            }
+
+    dataset = create_dataset(
+        dataset_name,
+        generator(),
+        bucket_storage,
+        splits={"train": 0.6, "val": 0.4},
+    )
+    cloned_dataset = dataset.clone(
+        new_dataset_name=f"{dataset_name}_cloned",
+        splits_to_clone=["train"],
+    )
+    dataset_splits = dataset.get_splits()
+    cloned_splits = cloned_dataset.get_splits()
+    assert cloned_splits is not None
+    assert dataset_splits is not None
+    assert set(cloned_splits["train"]) == set(dataset_splits["train"])
+    assert "val" not in cloned_splits
+
+    cloned_stats = cloned_dataset.get_statistics()
+    assert {
+        (item["count"], item["class_name"])
+        for item in cloned_stats["class_distributions"][""]["boundingbox"]
+    } == {(2, "person")}
+
+    cloned_dataset.delete_dataset(delete_local=True, delete_remote=True)
+
+
 @pytest.mark.dependency(name="test_dataset[BucketStorage.LOCAL]")
 def test_classes_per_task(dataset_name: str, tempdir: Path):
     def generator() -> DatasetIterator:

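Aside: the invariants the two new tests assert reduce to simple set algebra over split membership. Below is a minimal, self-contained sketch of that logic using plain dictionaries in place of the real dataset objects; merge_splits and clone_splits are hypothetical stand-ins for illustration, not the library's API.

def merge_splits(
    splits1: dict[str, set[str]],
    splits2: dict[str, set[str]],
    splits_to_merge: list[str],
) -> dict[str, set[str]]:
    # Union only the listed splits; every other split comes from the first dataset.
    merged = {name: set(uuids) for name, uuids in splits1.items()}
    for name in splits_to_merge:
        merged[name] = splits1.get(name, set()) | splits2.get(name, set())
    return merged

def clone_splits(
    splits: dict[str, set[str]],
    splits_to_clone: list[str],
) -> dict[str, set[str]]:
    # Keep only the listed splits in the clone.
    return {name: set(uuids) for name, uuids in splits.items() if name in splits_to_clone}

splits1 = {"train": {"a", "b"}, "val": {"c"}}
splits2 = {"train": {"d", "e"}, "val": {"f"}}

merged = merge_splits(splits1, splits2, splits_to_merge=["train"])
assert merged["train"] == splits1["train"] | splits2["train"]  # train splits unioned
assert merged["val"] == splits1["val"]  # "val" left untouched, as the merge test expects

cloned = clone_splits(splits1, splits_to_clone=["train"])
assert cloned["train"] == splits1["train"]  # train split copied verbatim
assert "val" not in cloned  # "val" absent from the clone, as the clone test expects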