@@ -68,6 +68,7 @@ def __init__(
             self._dataloader_params,
             self._forward_params,
             self._postprocess_params,
+            self._dataset_map_params,
         ) = self._sanitize_parameters(**kwargs)
 
     def save_pretrained(self, save_directory: str):
@@ -161,7 +162,7 @@ def _ensure_tensor_on_device(self, inputs, device):
 
     def _sanitize_parameters(
         self, **pipeline_parameters
-    ) -> Tuple[Dict[str, Any], Dict[str, Any], Dict[str, Any], Dict[str, Any]]:
+    ) -> Tuple[Dict[str, Any], Dict[str, Any], Dict[str, Any], Dict[str, Any], Dict[str, Any]]:
         """
         _sanitize_parameters will be called with any excessive named arguments from either `__init__` or `__call__`
         methods. It should return 4 dictionaries of the resolved parameters used by the various `preprocess`,
@@ -175,6 +176,7 @@ def _sanitize_parameters(
         dataloader_params = {}
         forward_parameters = {}
         postprocess_parameters: Dict[str, Any] = {}
+        dataset_map_parameters = {}
 
         # set preprocess parameters
         field = pipeline_parameters.get("predict_field")
@@ -196,7 +198,17 @@ def _sanitize_parameters(
             if p_name in pipeline_parameters:
                 postprocess_parameters[p_name] = pipeline_parameters[p_name]
 
-        return preprocess_parameters, dataloader_params, forward_parameters, postprocess_parameters
+        for p_name in ["document_batch_size"]:
+            if p_name in pipeline_parameters:
+                dataset_map_parameters["batch_size"] = pipeline_parameters[p_name]
+
+        return (
+            preprocess_parameters,
+            dataloader_params,
+            forward_parameters,
+            postprocess_parameters,
+            dataset_map_parameters,
+        )
 
     def preprocess(
         self,
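
The `document_batch_size` keyword is the only key routed into the new `dataset_map_parameters` dict, where it is renamed to `batch_size` for the downstream `Dataset.map()` call. A minimal sketch of that routing, assuming `pipe` is an instance of the pipeline class shown in this diff (the instance name and the chosen value are illustrative):

    # Illustrative only: `pipe` is assumed to be a constructed pipeline instance.
    (
        preprocess_params,
        dataloader_params,
        forward_params,
        postprocess_params,
        dataset_map_params,
    ) = pipe._sanitize_parameters(document_batch_size=100)

    # The key is renamed so it can be passed straight to Dataset.map().
    assert dataset_map_params == {"batch_size": 100}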
@@ -283,27 +295,55 @@ def get_dataloader(
 
         return dataloader
 
+    def _process_documents(
+        self,
+        documents: Sequence[Document],
+        preprocess_params: Dict[str, Any],
+        dataloader_params: Dict[str, Any],
+        forward_params: Dict[str, Any],
+        postprocess_params: Dict[str, Any],
+    ) -> Sequence[Document]:
+        # This creates encodings from the documents. It modifies the documents and may produce multiple entries per
+        # document.
+        model_inputs = self.preprocess(documents, **preprocess_params)
+        # Create a dataloader from the model inputs. This uses taskmodule.collate().
+        dataloader = self.get_dataloader(model_inputs=model_inputs, **dataloader_params)
+
+        show_progress_bar = forward_params.pop("show_progress_bar", False)
+        model_outputs: List = []
+        with torch.no_grad():
+            for batch in tqdm.tqdm(dataloader, desc="inference", disable=not show_progress_bar):
+                output = self.forward(batch, **forward_params)
+                processed_output = self.taskmodule.unbatch_output(output)
+                model_outputs.extend(processed_output)
+
+        assert len(model_inputs) == len(
+            model_outputs
+        ), f"length mismatch: len(model_inputs) [{len(model_inputs)}] != len(model_outputs) [{len(model_outputs)}]"
+
+        documents = self.postprocess(
+            model_inputs=model_inputs,
+            model_outputs=model_outputs,
+            **postprocess_params,
+        )
+        return documents
+
     def __call__(
         self,
         documents: Union[Document, Sequence[Document], Dataset],
         *args,
         **kwargs,
-    ) -> Union[Document, Sequence[Document]]:
+    ) -> Union[Document, Sequence[Document], Dataset]:
         if args:
             logger.warning(f"Ignoring args : {args}")
         (
             preprocess_params,
             dataloader_params,
             forward_params,
             postprocess_params,
+            dataset_map_params,
         ) = self._sanitize_parameters(**kwargs)
 
-        in_place: bool = postprocess_params.get("inplace", True)
-        if in_place and isinstance(documents, Dataset):
-            raise InplaceNotSupportedException(
-                "Datasets can't be modified in place. Please set inplace=False."
-            )
-
         if "TOKENIZERS_PARALLELISM" not in os.environ:
             logger.info(
                 "Disabling tokenizer parallelism, we're using DataLoader multithreading already"
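
The new `_process_documents` helper bundles the previously inline pipeline body (preprocess, dataloader, forward, unbatch, postprocess) so it can be called either directly on an in-memory sequence of documents or per chunk from `Dataset.map()`. A hedged sketch of a direct call, assuming `pipe` and a list `docs` of `Document` objects exist and that `batch_size` is accepted by `get_dataloader` (not shown in this diff):

    # Illustrative only; all parameter values here are assumptions, not defaults from this diff.
    processed = pipe._process_documents(
        documents=docs,
        preprocess_params={},
        dataloader_params={"batch_size": 2},
        forward_params={"show_progress_bar": False},
        postprocess_params={"inplace": False},
    )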
@@ -315,6 +355,7 @@ def __call__(
         dataloader_params = {**self._dataloader_params, **dataloader_params}
         forward_params = {**self._forward_params, **forward_params}
         postprocess_params = {**self._postprocess_params, **postprocess_params}
+        dataset_map_params = {**self._dataset_map_params, **dataset_map_params}
 
         self.call_count += 1
         if self.call_count > 10 and self.device.type == "cuda":
@@ -328,30 +369,39 @@ def __call__(
             single_document = True
             documents = [documents]
 
-        # This creates encodings from the documents. It modifies the documents and may produce multiple entries per
-        # document.
-        model_inputs = self.preprocess(documents, **preprocess_params)
-        # Create a dataloader from the model inputs. This uses taskmodule.collate().
-        dataloader = self.get_dataloader(model_inputs=model_inputs, **dataloader_params)
-
-        show_progress_bar = forward_params.pop("show_progress_bar", False)
-        model_outputs: List = []
-        with torch.no_grad():
-            for batch in tqdm.tqdm(dataloader, desc="inference", disable=not show_progress_bar):
-                output = self.forward(batch, **forward_params)
-                processed_output = self.taskmodule.unbatch_output(output)
-                model_outputs.extend(processed_output)
-
-        assert len(model_inputs) == len(
-            model_outputs
-        ), f"length mismatch: len(model_inputs) [{len(model_inputs)}] != len(model_outputs) [{len(model_outputs)}]"
+        processed_documents: Union[Sequence[Document], Dataset]
+        if isinstance(documents, Dataset):
+            in_place: bool = postprocess_params.get("inplace", True)
+            if in_place:
+                raise InplaceNotSupportedException(
+                    "Datasets can't be modified in place. Please set inplace=False."
+                )
+            # do not show inner progress bar
+            forward_params["show_progress_bar"] = False
+
+            processed_documents = documents.map(
+                self._process_documents,
+                fn_kwargs=dict(
+                    preprocess_params=preprocess_params,
+                    dataloader_params=dataloader_params,
+                    forward_params=forward_params,
+                    postprocess_params=postprocess_params,
+                ),
+                batched=True,
+                **dataset_map_params,
+            )
+        else:
+            processed_documents = self._process_documents(
+                documents=documents,
+                preprocess_params=preprocess_params,
+                dataloader_params=dataloader_params,
+                forward_params=forward_params,
+                postprocess_params=postprocess_params,
+            )
 
-        documents = self.postprocess(
-            model_inputs=model_inputs,
-            model_outputs=model_outputs,
-            **postprocess_params,
-        )
         if single_document:
-            return documents[0]
+            # TODO: fix "type: ignore" (if processed_documents is a Dataset, mypy assumes the result is Dict[Any, Any])
+            processed_document: Document = processed_documents[0]  # type: ignore
+            return processed_document
         else:
-            return documents
+            return processed_documents
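
With this change, `__call__` dispatches on the input type: plain documents go through `_process_documents` directly, while a `Dataset` is processed via `documents.map(..., batched=True)` with the inner progress bar forced off. A hedged usage sketch, assuming `pipe` is a configured pipeline instance and `dataset` is a `Dataset` compatible with it (the batch size is illustrative):

    # Illustrative only: names other than the keyword arguments are assumptions.
    processed_dataset = pipe(
        dataset,
        inplace=False,            # required for Dataset inputs, otherwise InplaceNotSupportedException
        document_batch_size=100,  # chunk size handed to Dataset.map(batched=True, batch_size=...)
    )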