Skip to content

Commit 8fc0cc9

Browse files
authored
add parameter set_batch_size_to_split_size to DatasetDict.map (#155)
- implement set_batch_size_to_split_size
- improve docstring
- add test
1 parent ef8f2d7 commit 8fc0cc9

File tree

2 files changed

+25
-0
lines changed

2 files changed

+25
-0
lines changed

src/pie_datasets/core/dataset_dict.py

+6
Original file line numberDiff line numberDiff line change
@@ -348,6 +348,7 @@ def map( # type: ignore
348348
self,
349349
function: Optional[Union[Callable, str]] = None,
350350
result_document_type: Optional[Union[str, Type[Document]]] = None,
351+
set_batch_size_to_split_size: bool = False,
351352
**kwargs,
352353
) -> "DatasetDict":
353354
"""Applies a function to all documents in the dataset.
@@ -370,6 +371,9 @@ def map( # type: ignore
370371
string that can be resolved to such a type. If not provided, it is tried to infer it from the
371372
function signature. If this is not possible, the document type of the input dataset
372373
is used.
374+
set_batch_size_to_split_size: If enabled, set the batch_size to the size of the respective split
375+
when calling map() on it. This is useful to transform whole splits when using it in
376+
combination with batched=True.
373377
**kwargs: additional keyword arguments for `datasets.Dataset.map()`
374378
"""
375379

@@ -395,6 +399,8 @@ def identity(x):
395399
for split, dataset in self.items():
396400
if isinstance(func, EnterDatasetMixin):
397401
func.enter_dataset(dataset=dataset, name=split)
402+
if set_batch_size_to_split_size:
403+
map_kwargs["batch_size"] = len(dataset)
398404
result_dict[split] = dataset.map(**map_kwargs)
399405
if isinstance(func, ExitDatasetMixin):
400406
func.exit_dataset(dataset=result_dict[split], name=split)

tests/unit/core/test_dataset_dict.py

+19
Original file line numberDiff line numberDiff line change
@@ -280,6 +280,25 @@ def exit_dataset_dict(self, dataset_dict: DatasetDict) -> None:
280280
assert doc1 == doc2
281281

282282

283+
def test_map_set_max_batch_size(dataset_dict):
284+
def join_docs(docs):
285+
return [TextBasedDocument(text=" ".join([doc.text for doc in docs]))]
286+
287+
dataset_dict_mapped = dataset_dict.map(
288+
join_docs,
289+
batched=True,
290+
set_batch_size_to_split_size=True,
291+
result_document_type=TextBasedDocument,
292+
)
293+
assert dataset_dict_mapped.document_type is TextBasedDocument
294+
for split in dataset_dict:
295+
assert len(dataset_dict_mapped[split]) == 1
296+
new_doc = dataset_dict_mapped[split][0]
297+
assert isinstance(new_doc, TextBasedDocument)
298+
original_texts = [doc.text for doc in dataset_dict[split]]
299+
assert new_doc.text == " ".join(original_texts)
300+
301+
283302
def test_select(dataset_dict):
284303
# select documents by index
285304
dataset_dict_selected = dataset_dict.select(

0 commit comments

Comments (0)