@@ -70,7 +70,6 @@ def __init__(
7070 self ._dataloader_params ,
7171 self ._forward_params ,
7272 self ._postprocess_params ,
73- self ._dataset_map_params ,
7473 ) = self ._sanitize_parameters (** kwargs )
7574
7675 def save_pretrained (self , save_directory : str ):
@@ -167,7 +166,7 @@ def _ensure_tensor_on_device(self, inputs, device):
167166
168167 def _sanitize_parameters (
169168 self , ** pipeline_parameters
170- ) -> Tuple [Dict [str , Any ], Dict [str , Any ], Dict [str , Any ], Dict [str , Any ], Dict [ str , Any ] ]:
169+ ) -> Tuple [Dict [str , Any ], Dict [str , Any ], Dict [str , Any ], Dict [str , Any ]]:
171170 """
172171 _sanitize_parameters will be called with any excessive named arguments from either `__init__` or `__call__`
173172 methods. It should return 4 dictionaries of the resolved parameters used by the various `preprocess`,
@@ -181,7 +180,6 @@ def _sanitize_parameters(
181180 dataloader_params = {}
182181 forward_parameters = {}
183182 postprocess_parameters : Dict [str , Any ] = {}
184- dataset_map_parameters = {}
185183
186184 # set preprocess parameters
187185 for p_name in ["document_batch_size" ]:
@@ -203,16 +201,11 @@ def _sanitize_parameters(
203201 if p_name in pipeline_parameters :
204202 postprocess_parameters [p_name ] = pipeline_parameters [p_name ]
205203
206- for p_name in ["document_batch_size" ]:
207- if p_name in pipeline_parameters :
208- dataset_map_parameters ["batch_size" ] = pipeline_parameters [p_name ]
209-
210204 return (
211205 preprocess_parameters ,
212206 dataloader_params ,
213207 forward_parameters ,
214208 postprocess_parameters ,
215- dataset_map_parameters ,
216209 )
217210
218211 def preprocess (
@@ -356,7 +349,6 @@ def __call__(
356349 dataloader_params ,
357350 forward_params ,
358351 postprocess_params ,
359- dataset_map_params ,
360352 ) = self ._sanitize_parameters (** kwargs )
361353
362354 if "TOKENIZERS_PARALLELISM" not in os .environ :
@@ -370,7 +362,6 @@ def __call__(
370362 dataloader_params = {** self ._dataloader_params , ** dataloader_params }
371363 forward_params = {** self ._forward_params , ** forward_params }
372364 postprocess_params = {** self ._postprocess_params , ** postprocess_params }
373- dataset_map_params = {** self ._dataset_map_params , ** dataset_map_params }
374365
375366 self .call_count += 1
376367 if self .call_count > 10 and self .device .type == "cuda" :
@@ -408,7 +399,6 @@ def __call__(
408399 postprocess_params = postprocess_params ,
409400 ),
410401 batched = True ,
411- ** dataset_map_params ,
412402 )
413403 finally :
414404 if was_caching_enabled :
0 commit comments