Skip to content

Commit ec3a3ea

Browse files
committed
do not allow parameters for documents.map to simplify pipeline
1 parent 8bd8471 commit ec3a3ea

File tree

1 file changed: +1 / -11 lines changed

1 file changed: +1 / -11 lines changed

src/pytorch_ie/pipeline.py

Lines changed: 1 addition & 11 deletions
Original file line number | Diff line number | Diff line change
@@ -70,7 +70,6 @@ def __init__(
7070
self._dataloader_params,
7171
self._forward_params,
7272
self._postprocess_params,
73-
self._dataset_map_params,
7473
) = self._sanitize_parameters(**kwargs)
7574

7675
def save_pretrained(self, save_directory: str):
@@ -167,7 +166,7 @@ def _ensure_tensor_on_device(self, inputs, device):
167166

168167
def _sanitize_parameters(
169168
self, **pipeline_parameters
170-
) -> Tuple[Dict[str, Any], Dict[str, Any], Dict[str, Any], Dict[str, Any], Dict[str, Any]]:
169+
) -> Tuple[Dict[str, Any], Dict[str, Any], Dict[str, Any], Dict[str, Any]]:
171170
"""
172171
_sanitize_parameters will be called with any excessive named arguments from either `__init__` or `__call__`
173172
methods. It should return 4 dictionaries of the resolved parameters used by the various `preprocess`,
@@ -181,7 +180,6 @@ def _sanitize_parameters(
181180
dataloader_params = {}
182181
forward_parameters = {}
183182
postprocess_parameters: Dict[str, Any] = {}
184-
dataset_map_parameters = {}
185183

186184
# set preprocess parameters
187185
for p_name in ["document_batch_size"]:
@@ -203,16 +201,11 @@ def _sanitize_parameters(
203201
if p_name in pipeline_parameters:
204202
postprocess_parameters[p_name] = pipeline_parameters[p_name]
205203

206-
for p_name in ["document_batch_size"]:
207-
if p_name in pipeline_parameters:
208-
dataset_map_parameters["batch_size"] = pipeline_parameters[p_name]
209-
210204
return (
211205
preprocess_parameters,
212206
dataloader_params,
213207
forward_parameters,
214208
postprocess_parameters,
215-
dataset_map_parameters,
216209
)
217210

218211
def preprocess(
@@ -356,7 +349,6 @@ def __call__(
356349
dataloader_params,
357350
forward_params,
358351
postprocess_params,
359-
dataset_map_params,
360352
) = self._sanitize_parameters(**kwargs)
361353

362354
if "TOKENIZERS_PARALLELISM" not in os.environ:
@@ -370,7 +362,6 @@ def __call__(
370362
dataloader_params = {**self._dataloader_params, **dataloader_params}
371363
forward_params = {**self._forward_params, **forward_params}
372364
postprocess_params = {**self._postprocess_params, **postprocess_params}
373-
dataset_map_params = {**self._dataset_map_params, **dataset_map_params}
374365

375366
self.call_count += 1
376367
if self.call_count > 10 and self.device.type == "cuda":
@@ -408,7 +399,6 @@ def __call__(
408399
postprocess_params=postprocess_params,
409400
),
410401
batched=True,
411-
**dataset_map_params,
412402
)
413403
finally:
414404
if was_caching_enabled:

0 commit comments

Comments
 (0)