@@ -348,6 +348,7 @@ def map( # type: ignore
348
348
self ,
349
349
function : Optional [Union [Callable , str ]] = None ,
350
350
result_document_type : Optional [Union [str , Type [Document ]]] = None ,
351
+ set_batch_size_to_split_size : bool = False ,
351
352
** kwargs ,
352
353
) -> "DatasetDict" :
353
354
"""Applies a function to all documents in the dataset.
@@ -370,6 +371,9 @@ def map( # type: ignore
370
371
string that can be resolved to such a type. If not provided, an attempt is made to infer it from the
371
372
function signature. If this is not possible, the document type of the input dataset
372
373
is used.
374
+ set_batch_size_to_split_size: If enabled, set the batch_size to the size of the respective split
375
+ when calling map() on it. This is useful to transform whole splits when using it in
376
+ combination with batched=True.
373
377
**kwargs: additional keyword arguments for `datasets.Dataset.map()`
374
378
"""
375
379
@@ -395,6 +399,8 @@ def identity(x):
395
399
for split , dataset in self .items ():
396
400
if isinstance (func , EnterDatasetMixin ):
397
401
func .enter_dataset (dataset = dataset , name = split )
402
+ if set_batch_size_to_split_size :
403
+ map_kwargs ["batch_size" ] = len (dataset )
398
404
result_dict [split ] = dataset .map (** map_kwargs )
399
405
if isinstance (func , ExitDatasetMixin ):
400
406
func .exit_dataset (dataset = result_dict [split ], name = split )
0 commit comments