Skip to content

Commit

Permalink
[Enhancement] decouple batch_size to det_batch_size, rec_batch_size and kie_batch_size in MMOCRInferencer (#1801)
Browse files Browse the repository at this point in the history

* decouple batch_size to det_batch_size, rec_batch_size, kie_batch_size and chunk_size in MMOCRInferencer

* remove chunk_size parameter

* add Optional keyword in function definitions and doc strings

* add det_batch_size, rec_batch_size, kie_batch_size in user_guides

* minor formatting
  • Loading branch information
hugotong6425 authored Mar 24, 2023
1 parent 22f40b7 commit c886936
Show file tree
Hide file tree
Showing 3 changed files with 55 additions and 6 deletions.
3 changes: 3 additions & 0 deletions docs/en/user_guides/inference.md
Original file line number Diff line number Diff line change
Expand Up @@ -460,6 +460,9 @@ Here are extensive lists of parameters that you can use.
| `inputs` | str/list/tuple/np.array | **required** | It can be a path to an image/a folder, an np array or a list/tuple (with img paths or np arrays) |
| `return_datasamples` | bool | False | Whether to return results as DataSamples. If False, the results will be packed into a dict. |
| `batch_size` | int | 1 | Inference batch size. |
| `det_batch_size` | int, optional | None | Inference batch size for text detection model. Overwrite batch_size if it is not None. |
| `rec_batch_size` | int, optional | None | Inference batch size for text recognition model. Overwrite batch_size if it is not None. |
| `kie_batch_size` | int, optional | None | Inference batch size for KIE model. Overwrite batch_size if it is not None. |
| `return_vis` | bool | False | Whether to return the visualization result. |
| `print_result` | bool | False | Whether to print the inference result to the console. |
| `show` | bool | False | Whether to display the visualization results in a popup window. |
Expand Down
3 changes: 3 additions & 0 deletions docs/zh_cn/user_guides/inference.md
Original file line number Diff line number Diff line change
Expand Up @@ -457,6 +457,9 @@ outputs
| `inputs` | str/list/tuple/np.array | **必需** | 它可以是一个图片/文件夹的路径,一个 numpy 数组,或者是一个包含图片路径或 numpy 数组的列表/元组 |
| `return_datasamples` | bool | False | 是否将结果作为 DataSample 返回。如果为 False,结果将被打包成一个字典。 |
| `batch_size` | int | 1 | 推理的批大小。 |
| `det_batch_size` | int, 可选 | None | 推理的批大小 (文本检测模型)。如果不为 None,则覆盖 batch_size。 |
| `rec_batch_size` | int, 可选 | None | 推理的批大小 (文本识别模型)。如果不为 None,则覆盖 batch_size。 |
| `kie_batch_size` | int, 可选 | None | 推理的批大小 (关键信息提取模型)。如果不为 None,则覆盖 batch_size。 |
| `return_vis` | bool | False | 是否返回可视化结果。 |
| `print_result` | bool | False | 是否将推理结果打印到控制台。 |
| `show` | bool | False | 是否在弹出窗口中显示可视化结果。 |
Expand Down
55 changes: 49 additions & 6 deletions mmocr/apis/inferencers/mmocr_inferencer.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,34 +105,54 @@ def _inputs2ndarrray(self, inputs: List[InputsType]) -> List[np.ndarray]:
'supported yet.')
return new_inputs

def forward(self, inputs: InputsType, batch_size: int,
def forward(self,
inputs: InputsType,
batch_size: int = 1,
det_batch_size: Optional[int] = None,
rec_batch_size: Optional[int] = None,
kie_batch_size: Optional[int] = None,
**forward_kwargs) -> PredType:
"""Forward the inputs to the model.
Args:
inputs (InputsType): The inputs to be forwarded.
batch_size (int): Batch size. Defaults to 1.
det_batch_size (Optional[int]): Batch size for text detection
model. Overwrite batch_size if it is not None.
Defaults to None.
rec_batch_size (Optional[int]): Batch size for text recognition
model. Overwrite batch_size if it is not None.
Defaults to None.
kie_batch_size (Optional[int]): Batch size for KIE model.
Overwrite batch_size if it is not None.
Defaults to None.
Returns:
        Dict: The prediction results. Possibly with keys "det", "rec", and
        "kie".
"""
result = {}
forward_kwargs['progress_bar'] = False
if det_batch_size is None:
det_batch_size = batch_size
if rec_batch_size is None:
rec_batch_size = batch_size
if kie_batch_size is None:
kie_batch_size = batch_size
if self.mode == 'rec':
# The extra list wrapper here is for the ease of postprocessing
self.rec_inputs = inputs
predictions = self.textrec_inferencer(
self.rec_inputs,
return_datasamples=True,
batch_size=batch_size,
batch_size=rec_batch_size,
**forward_kwargs)['predictions']
result['rec'] = [[p] for p in predictions]
elif self.mode.startswith('det'): # 'det'/'det_rec'/'det_rec_kie'
result['det'] = self.textdet_inferencer(
inputs,
return_datasamples=True,
batch_size=batch_size,
batch_size=det_batch_size,
**forward_kwargs)['predictions']
if self.mode.startswith('det_rec'): # 'det_rec'/'det_rec_kie'
result['rec'] = []
Expand All @@ -149,7 +169,7 @@ def forward(self, inputs: InputsType, batch_size: int,
self.textrec_inferencer(
self.rec_inputs,
return_datasamples=True,
batch_size=batch_size,
batch_size=rec_batch_size,
**forward_kwargs)['predictions'])
if self.mode == 'det_rec_kie':
self.kie_inputs = []
Expand All @@ -172,7 +192,7 @@ def forward(self, inputs: InputsType, batch_size: int,
result['kie'] = self.kie_inferencer(
self.kie_inputs,
return_datasamples=True,
batch_size=batch_size,
batch_size=kie_batch_size,
**forward_kwargs)['predictions']
return result

Expand Down Expand Up @@ -219,6 +239,9 @@ def __call__(
self,
inputs: InputsType,
batch_size: int = 1,
det_batch_size: Optional[int] = None,
rec_batch_size: Optional[int] = None,
kie_batch_size: Optional[int] = None,
out_dir: str = 'results/',
return_vis: bool = False,
save_vis: bool = False,
Expand All @@ -231,6 +254,15 @@ def __call__(
inputs (InputsType): Inputs for the inferencer. It can be a path
to image / image directory, or an array, or a list of these.
batch_size (int): Batch size. Defaults to 1.
det_batch_size (Optional[int]): Batch size for text detection
model. Overwrite batch_size if it is not None.
Defaults to None.
rec_batch_size (Optional[int]): Batch size for text recognition
model. Overwrite batch_size if it is not None.
Defaults to None.
kie_batch_size (Optional[int]): Batch size for KIE model.
Overwrite batch_size if it is not None.
Defaults to None.
out_dir (str): Output directory of results. Defaults to 'results/'.
return_vis (bool): Whether to return the visualization result.
Defaults to False.
Expand Down Expand Up @@ -269,12 +301,23 @@ def __call__(
**kwargs)

ori_inputs = self._inputs_to_list(inputs)
if det_batch_size is None:
det_batch_size = batch_size
if rec_batch_size is None:
rec_batch_size = batch_size
if kie_batch_size is None:
kie_batch_size = batch_size

chunked_inputs = super(BaseMMOCRInferencer,
self)._get_chunk_data(ori_inputs, batch_size)
results = {'predictions': [], 'visualization': []}
for ori_input in track(chunked_inputs, description='Inference'):
preds = self.forward(ori_input, batch_size, **forward_kwargs)
preds = self.forward(
ori_input,
det_batch_size=det_batch_size,
rec_batch_size=rec_batch_size,
kie_batch_size=kie_batch_size,
**forward_kwargs)
visualization = self.visualize(
ori_input, preds, img_out_dir=img_out_dir, **visualize_kwargs)
batch_res = self.postprocess(
Expand Down

0 comments on commit c886936

Please sign in to comment.