Skip to content

Commit 600d4b7

Browse files
authored
Pipeline: cleanup and improve docs (#443)
* add documentation * remove parameter binary_output because it was not used * remove parameter "shuffle" from dataloader_params because it would collide with fixed shuffle=False
1 parent 45f1469 commit 600d4b7

File tree

1 file changed

+42
-7
lines changed

1 file changed

+42
-7
lines changed

src/pytorch_ie/pipeline.py

Lines changed: 42 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -37,10 +37,15 @@ class Pipeline:
3737
3838
Pipeline supports running on CPU or GPU through the device argument (see below).
3939
40-
Some pipeline, like for instance :class:`~transformers.FeatureExtractionPipeline` (:obj:`'feature-extraction'` )
41-
output large tensor object as nested-lists. In order to avoid dumping such large structure as textual data we
42-
provide the :obj:`binary_output` constructor argument. If set to :obj:`True`, the output will be stored in the
43-
pickle format.
40+
Args:
41+
model (:class:`~pytorch_ie.PyTorchIEModel`):
42+
The deep learning model to use for the pipeline.
43+
taskmodule (:class:`~pytorch_ie.TaskModule`): The taskmodule to use for encoding
44+
and decoding the documents.
45+
device (:obj:`Union[int, str]`, `optional`, defaults to :obj:`"cpu"`):
46+
The device to run the pipeline on. This can be a CPU device (:obj:`"cpu"`), a GPU
47+
device (:obj:`"cuda"`) or a specific GPU device (:obj:`"cuda:X"`, where :obj:`X`
48+
is the index of the GPU).
4449
"""
4550

4651
default_input_names = None
@@ -50,15 +55,13 @@ def __init__(
5055
model: PyTorchIEModel,
5156
taskmodule: TaskModule,
5257
device: Union[int, str] = "cpu",
53-
binary_output: bool = False,
5458
**kwargs,
5559
):
5660
self.taskmodule = taskmodule
5761
device_str = (
5862
("cpu" if device < 0 else f"cuda:{device}") if isinstance(device, int) else device
5963
)
6064
self.device = torch.device(device_str)
61-
self.binary_output = binary_output
6265

6366
# Module.to() returns just self, but moved to the device. This is not correctly
6467
# reflected in typing of PyTorch.
@@ -192,7 +195,7 @@ def _sanitize_parameters(
192195
forward_parameters[p_name] = pipeline_parameters[p_name]
193196

194197
# set dataloader parameters
195-
for p_name in ["batch_size", "num_workers", "shuffle"]:
198+
for p_name in ["batch_size", "num_workers"]:
196199
if p_name in pipeline_parameters:
197200
dataloader_params[p_name] = pipeline_parameters[p_name]
198201

@@ -299,6 +302,38 @@ def __call__(
299302
*args,
300303
**kwargs,
301304
) -> Union[Document, Sequence[Document]]:
305+
"""
306+
The __call__ method is the entry point for the pipeline. It will run the pipeline workflow in the following
307+
order:
308+
309+
1. Encode the documents
310+
2. Run the model forward pass(es) on the encodings
311+
3. Combine the model outputs with the inputs encodings and integrate them back into the documents
312+
313+
Args:
314+
documents (:obj:`Union[Document, Sequence[Document]]`): The documents to process. If a single document is
315+
passed, the output will be a single document. If a list of documents is passed, the output will be a
316+
list of documents.
317+
document_batch_size (:obj:`int`, `optional`): The batch size to use for encoding the documents with the
318+
taskmodule. If not provided, the default batch size of the taskmodule will be used.
319+
show_progress_bar (:obj:`bool`, `optional`, defaults to :obj:`False`): Whether or not to show a progress bar
320+
during inference.
321+
fast_dev_run (:obj:`bool`, `optional`, defaults to :obj:`False`): Whether or not to run a fast development
322+
run. If set to :obj:`True`, only the first two model inputs will be processed.
323+
batch_size (:obj:`int`, `optional`, defaults to :obj:`1`): The batch size to use for the dataloader. If not
324+
provided, a batch size of 1 will be used.
325+
num_workers (:obj:`int`, `optional`, defaults to :obj:`8`): The number of workers to use for the dataloader.
326+
If not provided, 8 workers will be used.
327+
inplace (:obj:`bool`, `optional`, defaults to :obj:`True`): Whether or not to modify the input documents
328+
in place. Requires the input to be a mutable sequence of documents or a single document.
329+
330+
Note that all the arguments except `documents` can be set in the `__init__` method and/or overridden in the
331+
`__call__` method.
332+
333+
Returns:
334+
:obj:`Union[Document, Sequence[Document]]`: The processed documents. If a single document was passed, a
335+
single document will be returned. If a list of documents was passed, a list of documents will be returned.
336+
"""
302337
if args:
303338
logger.warning(f"Ignoring args : {args}")
304339
(

0 commit comments

Comments
 (0)