Commit 2f72673

use dataset.map in pipeline
1 parent 3264b07 commit 2f72673


src/pytorch_ie/pipeline.py

Lines changed: 83 additions & 33 deletions
@@ -68,6 +68,7 @@ def __init__(
             self._dataloader_params,
             self._forward_params,
             self._postprocess_params,
+            self._dataset_map_params,
         ) = self._sanitize_parameters(**kwargs)
 
     def save_pretrained(self, save_directory: str):
@@ -161,7 +162,7 @@ def _ensure_tensor_on_device(self, inputs, device):
 
     def _sanitize_parameters(
         self, **pipeline_parameters
-    ) -> Tuple[Dict[str, Any], Dict[str, Any], Dict[str, Any], Dict[str, Any]]:
+    ) -> Tuple[Dict[str, Any], Dict[str, Any], Dict[str, Any], Dict[str, Any], Dict[str, Any]]:
         """
         _sanitize_parameters will be called with any excessive named arguments from either `__init__` or `__call__`
         methods. It should return 4 dictionaries of the resolved parameters used by the various `preprocess`,
@@ -175,6 +176,7 @@ def _sanitize_parameters(
         dataloader_params = {}
         forward_parameters = {}
         postprocess_parameters: Dict[str, Any] = {}
+        dataset_map_parameters = {}
 
         # set preprocess parameters
         field = pipeline_parameters.get("predict_field")
@@ -196,7 +198,17 @@ def _sanitize_parameters(
             if p_name in pipeline_parameters:
                 postprocess_parameters[p_name] = pipeline_parameters[p_name]
 
-        return preprocess_parameters, dataloader_params, forward_parameters, postprocess_parameters
+        for p_name in ["document_batch_size"]:
+            if p_name in pipeline_parameters:
+                dataset_map_parameters["batch_size"] = pipeline_parameters[p_name]
+
+        return (
+            preprocess_parameters,
+            dataloader_params,
+            forward_parameters,
+            postprocess_parameters,
+            dataset_map_parameters,
+        )
 
     def preprocess(
         self,
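(Editorial illustration, not part of the commit.) The routing added above only renames one keyword: a `document_batch_size` argument passed to the pipeline ends up as `batch_size` in the parameters later forwarded to `Dataset.map`. A minimal standalone sketch of that behaviour, re-stated outside the class for illustration:

```python
from typing import Any, Dict

def route_dataset_map_params(**pipeline_parameters) -> Dict[str, Any]:
    # Mirrors the hunk above: the user-facing name is `document_batch_size`,
    # while `datasets.Dataset.map` expects `batch_size`.
    dataset_map_parameters: Dict[str, Any] = {}
    for p_name in ["document_batch_size"]:
        if p_name in pipeline_parameters:
            dataset_map_parameters["batch_size"] = pipeline_parameters[p_name]
    return dataset_map_parameters

print(route_dataset_map_params(document_batch_size=100))  # {'batch_size': 100}
print(route_dataset_map_params())                         # {}
```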
@@ -283,27 +295,55 @@ def get_dataloader(
 
         return dataloader
 
+    def _process_documents(
+        self,
+        documents: Sequence[Document],
+        preprocess_params: Dict[str, Any],
+        dataloader_params: Dict[str, Any],
+        forward_params: Dict[str, Any],
+        postprocess_params: Dict[str, Any],
+    ) -> Sequence[Document]:
+        # This creates encodings from the documents. It modifies the documents and may produce multiple entries per
+        # document.
+        model_inputs = self.preprocess(documents, **preprocess_params)
+        # Create a dataloader from the model inputs. This uses taskmodule.collate().
+        dataloader = self.get_dataloader(model_inputs=model_inputs, **dataloader_params)
+
+        show_progress_bar = forward_params.pop("show_progress_bar", False)
+        model_outputs: List = []
+        with torch.no_grad():
+            for batch in tqdm.tqdm(dataloader, desc="inference", disable=not show_progress_bar):
+                output = self.forward(batch, **forward_params)
+                processed_output = self.taskmodule.unbatch_output(output)
+                model_outputs.extend(processed_output)
+
+        assert len(model_inputs) == len(
+            model_outputs
+        ), f"length mismatch: len(model_inputs) [{len(model_inputs)}] != len(model_outputs) [{len(model_outputs)}]"
+
+        documents = self.postprocess(
+            model_inputs=model_inputs,
+            model_outputs=model_outputs,
+            **postprocess_params,
+        )
+        return documents
+
     def __call__(
         self,
         documents: Union[Document, Sequence[Document], Dataset],
         *args,
         **kwargs,
-    ) -> Union[Document, Sequence[Document]]:
+    ) -> Union[Document, Sequence[Document], Dataset]:
         if args:
             logger.warning(f"Ignoring args : {args}")
         (
             preprocess_params,
             dataloader_params,
             forward_params,
             postprocess_params,
+            dataset_map_params,
         ) = self._sanitize_parameters(**kwargs)
 
-        in_place: bool = postprocess_params.get("inplace", True)
-        if in_place and isinstance(documents, Dataset):
-            raise InplaceNotSupportedException(
-                "Datasets can't be modified in place. Please set inplace=False."
-            )
-
         if "TOKENIZERS_PARALLELISM" not in os.environ:
             logger.info(
                 "Disabling tokenizer parallelism, we're using DataLoader multithreading already"
@@ -315,6 +355,7 @@ def __call__(
         dataloader_params = {**self._dataloader_params, **dataloader_params}
         forward_params = {**self._forward_params, **forward_params}
         postprocess_params = {**self._postprocess_params, **postprocess_params}
+        dataset_map_params = {**self._dataset_map_params, **dataset_map_params}
 
         self.call_count += 1
         if self.call_count > 10 and self.device.type == "cuda":
@@ -328,30 +369,39 @@ def __call__(
             single_document = True
             documents = [documents]
 
-        # This creates encodings from the documents. It modifies the documents and may produce multiple entries per
-        # document.
-        model_inputs = self.preprocess(documents, **preprocess_params)
-        # Create a dataloader from the model inputs. This uses taskmodule.collate().
-        dataloader = self.get_dataloader(model_inputs=model_inputs, **dataloader_params)
-
-        show_progress_bar = forward_params.pop("show_progress_bar", False)
-        model_outputs: List = []
-        with torch.no_grad():
-            for batch in tqdm.tqdm(dataloader, desc="inference", disable=not show_progress_bar):
-                output = self.forward(batch, **forward_params)
-                processed_output = self.taskmodule.unbatch_output(output)
-                model_outputs.extend(processed_output)
-
-        assert len(model_inputs) == len(
-            model_outputs
-        ), f"length mismatch: len(model_inputs) [{len(model_inputs)}] != len(model_outputs) [{len(model_outputs)}]"
+        processed_documents: Union[Sequence[Document], Dataset]
+        if isinstance(documents, Dataset):
+            in_place: bool = postprocess_params.get("inplace", True)
+            if in_place:
+                raise InplaceNotSupportedException(
+                    "Datasets can't be modified in place. Please set inplace=False."
+                )
+            # do not show inner progress bar
+            forward_params["show_progress_bar"] = False
+
+            processed_documents = documents.map(
+                self._process_documents,
+                fn_kwargs=dict(
+                    preprocess_params=preprocess_params,
+                    dataloader_params=dataloader_params,
+                    forward_params=forward_params,
+                    postprocess_params=postprocess_params,
+                ),
+                batched=True,
+                **dataset_map_params,
+            )
+        else:
+            processed_documents = self._process_documents(
+                documents=documents,
+                preprocess_params=preprocess_params,
+                dataloader_params=dataloader_params,
+                forward_params=forward_params,
+                postprocess_params=postprocess_params,
+            )
 
-        documents = self.postprocess(
-            model_inputs=model_inputs,
-            model_outputs=model_outputs,
-            **postprocess_params,
-        )
         if single_document:
-            return documents[0]
+            # TODO: fix "type: ignore" (if processed_documents is a Dataset, mypy assumes the result is Dict[Any, Any])
+            processed_document: Document = processed_documents[0]  # type: ignore
+            return processed_document
         else:
-            return documents
+            return processed_documents
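(Editorial illustration, not part of the commit.) With these changes, calling the pipeline on a `Dataset` dispatches to `Dataset.map`, while plain documents keep the previous preprocess/forward/postprocess path. A hedged usage sketch; the pipeline construction, the dataset, and the `predict_field` value are assumptions and not defined by this diff:

```python
# hypothetical usage; `pipeline` is an already constructed pytorch_ie Pipeline
# and `dataset` a pytorch_ie Dataset -- neither is created by this commit
predicted = pipeline(
    dataset,
    predict_field="entities",  # assumed task field, routed to preprocess_params
    inplace=False,             # required for Dataset inputs, see the isinstance check above
    document_batch_size=100,   # forwarded to Dataset.map as batch_size
)
```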
