@@ -37,10 +37,15 @@ class Pipeline:
3737
3838 Pipeline supports running on CPU or GPU through the device argument (see below).
3939
40- Some pipeline, like for instance :class:`~transformers.FeatureExtractionPipeline` (:obj:`'feature-extraction'` )
41- output large tensor object as nested-lists. In order to avoid dumping such large structure as textual data we
42- provide the :obj:`binary_output` constructor argument. If set to :obj:`True`, the output will be stored in the
43- pickle format.
40+ Args:
41+ model (:class:`~pytorch_ie.PyTorchIEModel`):
42+ The deep learning model to use for the pipeline.
43+ taskmodule (:class:`~pytorch_ie.TaskModule`): The taskmodule to use for encoding
44+ and decoding the documents.
45+ device (:obj:`Union[int, str]`, `optional`, defaults to :obj:`"cpu"`):
46+ The device to run the pipeline on. This can be a CPU device (:obj:`"cpu"`), a GPU
47+ device (:obj:`"cuda"`) or a specific GPU device (:obj:`"cuda:X"`, where :obj:`X`
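48+ is the index of the GPU). An integer is also accepted: a negative value selects the CPU, any other value :obj:`N` selects :obj:`"cuda:N"`.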
4449 """
4550
4651 default_input_names = None
@@ -50,15 +55,13 @@ def __init__(
5055 model : PyTorchIEModel ,
5156 taskmodule : TaskModule ,
5257 device : Union [int , str ] = "cpu" ,
53- binary_output : bool = False ,
5458 ** kwargs ,
5559 ):
5660 self .taskmodule = taskmodule
5761 device_str = (
5862 ("cpu" if device < 0 else f"cuda:{ device } " ) if isinstance (device , int ) else device
5963 )
6064 self .device = torch .device (device_str )
61- self .binary_output = binary_output
6265
6366 # Module.to() returns just self, but moved to the device. This is not correctly
6467 # reflected in typing of PyTorch.
@@ -192,7 +195,7 @@ def _sanitize_parameters(
192195 forward_parameters [p_name ] = pipeline_parameters [p_name ]
193196
194197 # set dataloader parameters
195- for p_name in ["batch_size" , "num_workers" , "shuffle" ]:
198+ for p_name in ["batch_size" , "num_workers" ]:
196199 if p_name in pipeline_parameters :
197200 dataloader_params [p_name ] = pipeline_parameters [p_name ]
198201
@@ -299,6 +302,38 @@ def __call__(
299302 * args ,
300303 ** kwargs ,
301304 ) -> Union [Document , Sequence [Document ]]:
305+ """
306+ The __call__ method is the entry point for the pipeline. It will run the pipeline workflow in the following
307+ order:
308+
309+ 1. Encode the documents
310+ 2. Run the model forward pass(es) on the encodings
311+ 3. Combine the model outputs with the input encodings and integrate them back into the documents
312+
313+ Args:
314+ documents (:obj:`Union[Document, Sequence[Document]]`): The documents to process. If a single document is
315+ passed, the output will be a single document. If a list of documents is passed, the output will be a
316+ list of documents.
317+ document_batch_size (:obj:`int`, `optional`): The batch size to use for encoding the documents with the
318+ taskmodule. If not provided, the default batch size of the taskmodule will be used.
319+ show_progress_bar (:obj:`bool`, `optional`, defaults to :obj:`False`): Whether or not to show a progress bar
320+ during inference.
321+ fast_dev_run (:obj:`bool`, `optional`, defaults to :obj:`False`): Whether or not to run a fast development
322+ run. If set to :obj:`True`, only the first two model inputs will be processed.
323+ batch_size (:obj:`int`, `optional`, defaults to :obj:`1`): The batch size to use for the dataloader. If not
324+ provided, a batch size of 1 will be used.
325+ num_workers (:obj:`int`, `optional`, defaults to :obj:`8`): The number of workers to use for the dataloader.
326+ If not provided, 8 workers will be used.
327+ inplace (:obj:`bool`, `optional`, defaults to :obj:`True`): Whether or not to modify the input documents
328+ in place. Requires the input to be a mutable sequence of documents or a single document.
329+
330+ Note that all the arguments except `documents` can be set in the `__init__` method and/or overridden in the
331+ `__call__` method.
332+
333+ Returns:
334+ :obj:`Union[Document, Sequence[Document]]`: The processed documents. If a single document was passed, a
335+ single document will be returned. If a list of documents was passed, a list of documents will be returned.
336+ """
302337 if args :
303338 logger .warning (f"Ignoring args : { args } " )
304339 (
0 commit comments