Skip to content

Commit 3e5a696

Browse files
authored
extend split_mappings format (#161)
1 parent 2b3ba52 commit 3e5a696

File tree

2 files changed

+34
-8
lines changed

2 files changed

+34
-8
lines changed

src/pie_datasets/core/dataset_dict.py

Lines changed: 18 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -737,27 +737,37 @@ def load_dataset(*args, **kwargs) -> Union[DatasetDict, Dataset, IterableDataset
737737

738738

739739
def concatenate_dataset_dicts(
740-
inputs: Dict[str, DatasetDict], split_mappings: Dict[str, Dict[str, str]], clear_metadata: bool
740+
inputs: Dict[str, DatasetDict],
741+
split_mappings: Dict[str, Dict[str, Union[str, List[str]]]],
742+
clear_metadata: bool,
741743
):
742744
"""Concatenate the splits of multiple dataset dicts into a single one. Dataset name will be
743745
saved in Metadata.
744746
745747
Args:
746748
inputs: A mapping from dataset names to dataset dicts that contain the splits to concatenate.
747-
split_mappings: A mapping from target split names to mappings from input dataset names to
748-
source split names.
749+
split_mappings: A mapping from target split name to mappings from input dataset name to
750+
source split name or list of names.
749751
clear_metadata: Whether to clear the metadata before concatenating.
750752
751753
Returns: A dataset dict with keys in split_names as splits and content from the merged input
752754
dataset dicts.
753755
"""
754756

755-
input_splits = {}
757+
input_splits: Dict[str, Dict[str, Union[Dataset, IterableDataset]]] = {}
756758
for target_split_name, mapping in split_mappings.items():
757-
input_splits[target_split_name] = {
758-
ds_name: inputs[ds_name][source_split_name]
759-
for ds_name, source_split_name in mapping.items()
760-
}
759+
input_splits[target_split_name] = {}
760+
for ds_name, source_split_name in mapping.items():
761+
if isinstance(source_split_name, str):
762+
input_splits[target_split_name][ds_name] = inputs[ds_name][source_split_name]
763+
elif isinstance(source_split_name, list):
764+
input_splits[target_split_name][ds_name] = concatenate_datasets(
765+
[
766+
inputs[ds_name][_source_split_name]
767+
for _source_split_name in source_split_name
768+
],
769+
clear_metadata=clear_metadata,
770+
)
761771

762772
result = DatasetDict(
763773
{

tests/unit/core/test_dataset_dict.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -724,3 +724,19 @@ def test_concatenate_dataset_dicts(tbga_extract, comagc_extract):
724724
assert all(
725725
[ds.metadata["dataset_name"] in ["tbga", "comagc"] for ds in concatenated_dataset["train"]]
726726
)
727+
728+
concatenated_dataset_with_list_in_mapping = concatenate_dataset_dicts(
729+
inputs={"tbga": tbga_extract, "comagc": comagc_extract},
730+
split_mappings={"train": {"tbga": ["train", "test"], "comagc": "train"}},
731+
clear_metadata=True,
732+
)
733+
734+
assert len(concatenated_dataset_with_list_in_mapping["train"]) == len(
735+
tbga_extract["train"]
736+
) + len(tbga_extract["test"]) + len(comagc_extract["train"])
737+
assert all(
738+
[
739+
ds.metadata["dataset_name"] in ["tbga", "comagc"]
740+
for ds in concatenated_dataset_with_list_in_mapping["train"]
741+
]
742+
)

0 commit comments

Comments (0)