Skip to content

Commit

Permalink
add params for ocr det (#2701)
Browse files Browse the repository at this point in the history
* add params for ocr det

* add params for ocr det

* adddet v3

* adddet v3 infer

* adddet v3 infer

* adddet ocr params
  • Loading branch information
Sunting78 authored Dec 23, 2024
1 parent 10cdbcd commit 903b522
Show file tree
Hide file tree
Showing 9 changed files with 590 additions and 45 deletions.
40 changes: 40 additions & 0 deletions paddlex/configs/modules/text_detection/PP-OCRv3_mobile_det.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
Global:
model: PP-OCRv3_mobile_det
mode: check_dataset # check_dataset/train/evaluate/predict
module: text_det
dataset_dir: "/paddle/dataset/paddlex/ocr_det/ocr_det_dataset_examples"
device: gpu:0,1,2,3
output: "output"

CheckDataset:
convert:
enable: False
src_dataset_type: null
split:
enable: False
train_percent: null
val_percent: null

Train:
epochs_iters: 100
batch_size: 4
learning_rate: 0.001
pretrain_weight_path: https://paddle-model-ecology.bj.bcebos.com/paddlex/official_pretrained_model/PP-OCRv3_mobile_det_pretrained.pdparams
resume_path: null
log_interval: 10
eval_interval: 1
save_interval: 1

Evaluate:
weight_path: "output/best_accuracy/best_accuracy.pdparams"
log_interval: 1

Export:
weight_path: https://paddle-model-ecology.bj.bcebos.com/paddlex/official_pretrained_model/PP-OCRv3_mobile_det_pretrained.pdparams

Predict:
batch_size: 1
model_dir: "output/best_accuracy/inference"
input: "https://paddle-model-ecology.bj.bcebos.com/paddlex/imgs/demo_image/general_ocr_001.png"
kernel_option:
run_mode: paddle
40 changes: 40 additions & 0 deletions paddlex/configs/modules/text_detection/PP-OCRv3_server_det.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
Global:
model: PP-OCRv3_server_det
mode: check_dataset # check_dataset/train/evaluate/predict
module: text_det
dataset_dir: "/paddle/dataset/paddlex/ocr_det/ocr_det_dataset_examples"
device: gpu:0,1,2,3
output: "output"

CheckDataset:
convert:
enable: False
src_dataset_type: null
split:
enable: False
train_percent: null
val_percent: null

Train:
epochs_iters: 100
batch_size: 4
learning_rate: 0.001
pretrain_weight_path: https://paddle-model-ecology.bj.bcebos.com/paddlex/official_pretrained_model/PP-OCRv3_server_det_pretrained.pdparams
resume_path: null
log_interval: 10
eval_interval: 1
save_interval: 1

Evaluate:
weight_path: "output/best_accuracy/best_accuracy.pdparams"
log_interval: 1

Export:
weight_path: https://paddle-model-ecology.bj.bcebos.com/paddlex/official_pretrained_model/PP-OCRv3_server_det_pretrained.pdparams

Predict:
batch_size: 1
model_dir: "output/best_accuracy/inference"
input: "https://paddle-model-ecology.bj.bcebos.com/paddlex/imgs/demo_image/general_ocr_001.png"
kernel_option:
run_mode: paddle
93 changes: 77 additions & 16 deletions paddlex/inference/models_new/text_detection/predictor.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,9 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import numpy as np
from typing import List, Union

from ....utils.func_register import FuncRegister
from ....modules.text_detection.model_list import MODELS
from ...common.batch_sampler import ImageBatchSampler
Expand All @@ -36,8 +39,27 @@ class TextDetPredictor(BasicPredictor):
_FUNC_MAP = {}
register = FuncRegister(_FUNC_MAP)

def __init__(self, *args, **kwargs):
def __init__(
self,
limit_side_len: Union[int, None] = None,
limit_type: Union[str, None] = None,
thresh: Union[float, None] = None,
box_thresh: Union[float, None] = None,
max_candidates: Union[int, None] = None,
unclip_ratio: Union[float, None] = None,
use_dilation: Union[bool, None] = None,
*args,
**kwargs
):
super().__init__(*args, **kwargs)

self.limit_side_len = limit_side_len
self.limit_type = limit_type
self.thresh = thresh
self.box_thresh = box_thresh
self.max_candidates = max_candidates
self.unclip_ratio = unclip_ratio
self.use_dilation = use_dilation
self.pre_tfs, self.infer, self.post_op = self._build()

def _build_batch_sampler(self):
Expand Down Expand Up @@ -67,14 +89,37 @@ def _build(self):
post_op = self.build_postprocess(**self.config["PostProcess"])
return pre_tfs, infer, post_op

def process(self, batch_data):
def process(
self,
batch_data: List[Union[str, np.ndarray]],
limit_side_len: Union[int, None] = None,
limit_type: Union[str, None] = None,
thresh: Union[float, None] = None,
box_thresh: Union[float, None] = None,
max_candidates: Union[int, None] = None,
unclip_ratio: Union[float, None] = None,
use_dilation: Union[bool, None] = None,
):

batch_raw_imgs = self.pre_tfs["Read"](imgs=batch_data)
batch_imgs, batch_shapes = self.pre_tfs["Resize"](imgs=batch_raw_imgs)
batch_imgs, batch_shapes = self.pre_tfs["Resize"](
imgs=batch_raw_imgs,
limit_side_len=limit_side_len or self.limit_side_len,
limit_type=limit_type or self.limit_type,
)
batch_imgs = self.pre_tfs["Normalize"](imgs=batch_imgs)
batch_imgs = self.pre_tfs["ToCHW"](imgs=batch_imgs)
x = self.pre_tfs["ToBatch"](imgs=batch_imgs)
batch_preds = self.infer(x=x)
polys, scores = self.post_op(batch_preds, batch_shapes)
polys, scores = self.post_op(
batch_preds,
batch_shapes,
thresh=thresh or self.thresh,
box_thresh=box_thresh or self.box_thresh,
max_candidates=max_candidates or self.max_candidates,
unclip_ratio=unclip_ratio or self.unclip_ratio,
use_dilation=use_dilation or self.use_dilation,
)
return {
"input_path": batch_data,
"input_img": batch_raw_imgs,
Expand All @@ -88,14 +133,29 @@ def build_readimg(self, channel_first, img_mode):
return "Read", ReadImage(format=img_mode)

@register("DetResizeForTest")
def build_resize(self, **kwargs):
def build_resize(
self,
limit_side_len: Union[int, None] = None,
limit_type: Union[str, None] = None,
**kwargs
):
# TODO: align to PaddleOCR
if self.model_name in ("PP-OCRv4_server_det", "PP-OCRv4_mobile_det"):
resize_long = kwargs.get("resize_long", 960)
return "Resize", DetResizeForTest(
limit_side_len=resize_long, limit_type="max"
)
return "Resize", DetResizeForTest(**kwargs)

if self.model_name in (
"PP-OCRv4_server_det",
"PP-OCRv4_mobile_det",
"PP-OCRv3_server_det",
"PP-OCRv3_mobile_det",
):
limit_side_len = self.limit_side_len or kwargs.get("resize_long", 960)
limit_type = self.limit_type or kwargs.get("limit_type", "max")
else:
limit_side_len = self.limit_side_len or kwargs.get("resize_long", 736)
limit_type = self.limit_type or kwargs.get("limit_type", "min")

return "Resize", DetResizeForTest(
limit_side_len=limit_side_len, limit_type=limit_type, **kwargs
)

@register("NormalizeImage")
def build_normalize(
Expand All @@ -117,11 +177,12 @@ def build_to_chw(self):
def build_postprocess(self, **kwargs):
if kwargs.get("name") == "DBPostProcess":
return DBPostProcess(
thresh=kwargs.get("thresh", 0.3),
box_thresh=kwargs.get("box_thresh", 0.7),
max_candidates=kwargs.get("max_candidates", 1000),
unclip_ratio=kwargs.get("unclip_ratio", 2.0),
use_dilation=kwargs.get("use_dilation", False),
thresh=self.thresh or kwargs.get("thresh", 0.3),
box_thresh=self.box_thresh or kwargs.get("box_thresh", 0.6),
max_candidates=self.max_candidates
or kwargs.get("max_candidates", 1000),
unclip_ratio=self.unclip_ratio or kwargs.get("unclip_ratio", 2.0),
use_dilation=self.use_dilation or kwargs.get("use_dilation", False),
score_mode=kwargs.get("score_mode", "fast"),
box_type=kwargs.get("box_type", "quad"),
)
Expand Down
Loading

0 comments on commit 903b522

Please sign in to comment.