@@ -767,6 +767,123 @@ def generator2() -> DatasetIterator:
 )
 
 
+@pytest.mark.dependency(name="test_dataset[BucketStorage.LOCAL]")
+def test_merge_datasets_specific_split(
+    bucket_storage: BucketStorage,
+    dataset_name: str,
+    tempdir: Path,
+):
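+    # Merge only the "train" split of dataset2 into a copy of dataset1;
+    # dataset1's "val" split should carry over to the merged dataset unchanged.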
+    dataset_name = f"{dataset_name}_{bucket_storage.value}"
+    dataset1_name = f"{dataset_name}_1"
+    dataset2_name = f"{dataset_name}_2"
+
+    def generator1() -> DatasetIterator:
+        for i in range(3):
+            img = create_image(i, tempdir)
+            yield {
+                "file": img,
+                "annotation": {
+                    "class": "person",
+                    "boundingbox": {"x": 0.1, "y": 0.1, "w": 0.1, "h": 0.1},
+                },
+            }
+
+    def generator2() -> DatasetIterator:
+        for i in range(3, 6):
+            img = create_image(i, tempdir)
+            yield {
+                "file": img,
+                "annotation": {
+                    "class": "dog",
+                    "boundingbox": {"x": 0.2, "y": 0.2, "w": 0.2, "h": 0.2},
+                },
+            }
+
+    dataset1 = create_dataset(
+        dataset1_name,
+        generator1(),
+        bucket_storage,
+        splits={"train": 0.6, "val": 0.4},
+    )
+
+    dataset2 = create_dataset(
+        dataset2_name,
+        generator2(),
+        bucket_storage,
+        splits={"train": 0.6, "val": 0.4},
+    )
+
+    merged_dataset = dataset1.merge_with(
+        dataset2,
+        inplace=False,
+        new_dataset_name=f"{dataset1_name}_{dataset2_name}_merged",
+        splits_to_merge=["train"],
+    )
+
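+    # dataset1 contributes all 3 "person" images (train and val), while
+    # dataset2 contributes only its 2-image "train" split of "dog" images.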
+    merged_stats = merged_dataset.get_statistics()
+    assert {
+        (item["count"], item["class_name"])
+        for item in merged_stats["class_distributions"][""]["boundingbox"]
+    } == {(3, "person"), (2, "dog")}
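+    # The merged "train" split is the union of both train splits; the "val"
+    # split comes from dataset1 alone, since "val" was not merged.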
+    merged_splits = merged_dataset.get_splits()
+    dataset1_splits = dataset1.get_splits()
+    dataset2_splits = dataset2.get_splits()
+    assert merged_splits is not None
+    assert dataset1_splits is not None
+    assert dataset2_splits is not None
+    assert set(merged_splits["train"]) == set(dataset1_splits["train"]) | set(
+        dataset2_splits["train"]
+    )
+    assert set(merged_splits["val"]) == set(dataset1_splits["val"])
+
+    dataset1.delete_dataset(delete_local=True, delete_remote=True)
+    dataset2.delete_dataset(delete_local=True, delete_remote=True)
+    merged_dataset.delete_dataset(delete_local=True, delete_remote=True)
+
+
+@pytest.mark.dependency(name="test_dataset[BucketStorage.LOCAL]")
+def test_clone_dataset_specific_split(
+    bucket_storage: BucketStorage,
+    dataset_name: str,
+    tempdir: Path,
+):
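+    # Clone only the "train" split; the clone should contain no "val" split.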
+    def generator() -> DatasetIterator:
+        for i in range(3):
+            img = create_image(i, tempdir)
+            yield {
+                "file": img,
+                "annotation": {
+                    "class": "person",
+                    "boundingbox": {"x": 0.1, "y": 0.1, "w": 0.1, "h": 0.1},
+                },
+            }
+
+    dataset = create_dataset(
+        dataset_name,
+        generator(),
+        bucket_storage,
+        splits={"train": 0.6, "val": 0.4},
+    )
+    cloned_dataset = dataset.clone(
+        new_dataset_name=f"{dataset_name}_cloned",
+        splits_to_clone=["train"],
+    )
+    dataset_splits = dataset.get_splits()
+    cloned_splits = cloned_dataset.get_splits()
+    assert cloned_splits is not None
+    assert dataset_splits is not None
+    assert set(cloned_splits["train"]) == set(dataset_splits["train"])
+    assert "val" not in cloned_splits
+
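+    # Only the 2-image "train" split was cloned, so the class distribution
+    # counts 2 "person" annotations.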
+    cloned_stats = cloned_dataset.get_statistics()
+    assert {
+        (item["count"], item["class_name"])
+        for item in cloned_stats["class_distributions"][""]["boundingbox"]
+    } == {(2, "person")}
+
+    cloned_dataset.delete_dataset(delete_local=True, delete_remote=True)
+
+
 @pytest.mark.dependency(name="test_dataset[BucketStorage.LOCAL]")
 def test_classes_per_task(dataset_name: str, tempdir: Path):
     def generator() -> DatasetIterator: