How to perform inference with a model I trained with docTR? #568
-
I have trained a detection and recognition model on my dataset using your framework.
Line 26 in b149266:
def ocr_predictor(
    det_arch: str = 'db_resnet50',
    reco_arch: str = 'crnn_vgg16_bn',
    pretrained: bool = False,
    **kwargs: Any
) -> OCRPredictor:
    """End-to-end OCR architecture using one model for localization, and another for text recognition.
    Example::
        >>> import numpy as np
        >>> from doctr.models import ocr_predictor
        >>> model = ocr_predictor('db_resnet50', 'crnn_vgg16_bn', pretrained=True)
        >>> input_page = (255 * np.random.rand(600, 800, 3)).astype(np.uint8)
        >>> out = model([input_page])
    Args:
        arch: name of the architecture to use ('db_sar_vgg', 'db_sar_resnet', 'db_crnn_vgg', 'db_crnn_resnet')
        pretrained: If True, returns a model pre-trained on our OCR dataset
    Returns:
        OCR predictor
    """
    return _predictor(det_arch, reco_arch, pretrained, **kwargs)
What kwargs should I use? Also, even when I put my models in cache_dir, the code still goes on to download its own, and it also throws a hash mismatch error.
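The hash error described above is consistent with a download cache that validates each file against a checksum it expects for the pretrained weights. Here is a minimal sketch of that kind of check, not docTR's actual code: the helper `matches_hash_prefix` and the file name are hypothetical, and the use of a SHA-256 prefix is an assumption. Under that assumption, a custom checkpoint stored under a pretrained model's file name fails the check, which would explain both the re-download and the error.

```python
import hashlib
import tempfile
from pathlib import Path


def matches_hash_prefix(file_path: Path, hash_prefix: str) -> bool:
    """Return True if the file's SHA-256 hex digest starts with hash_prefix."""
    digest = hashlib.sha256(file_path.read_bytes()).hexdigest()
    return digest.startswith(hash_prefix)


# Demo on a throwaway file standing in for a checkpoint (hypothetical name)
with tempfile.TemporaryDirectory() as tmp_dir:
    checkpoint = Path(tmp_dir) / "my_det_checkpoint.pt"
    checkpoint.write_bytes(b"custom weights")
    expected_prefix = hashlib.sha256(b"custom weights").hexdigest()[:8]
    same_file_ok = matches_hash_prefix(checkpoint, expected_prefix)

    # Replace the file content: the digest changes, so the old prefix no longer matches
    checkpoint.write_bytes(b"different weights")
    replaced_file_ok = matches_hash_prefix(checkpoint, expected_prefix)
```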
-
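Hello @K-for-Code 👋
The factory function ocr_predictor is a bit more high-level than that, but you can easily achieve what you want :) Here is a short example of how to do this:
import os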
os.environ["USE_TORCH"] = "1"
import torch
from doctr.models.predictor import OCRPredictor
from doctr.models.detection.predictor import DetectionPredictor
from doctr.models.recognition.predictor import RecognitionPredictor
from doctr.models.preprocessor import PreProcessor
# from doctr.models.utils import load_pretrained_params
# Instantiate your model here
det_model = ...
reco_model = ...
# Load the checkpoints you produced
# load_pretrained_params(det_model, "<URL_TO_DET_CHECKPOINT>")
# load_pretrained_params(reco_model, "<URL_TO_RECO_CHECKPOINT>")
# If using PyTorch
# import torch
det_params = torch.load("path/to/your/local/det_checkpoint.pt", map_location="cpu")
reco_params = torch.load("path/to/your/local/reco_checkpoint.pt", map_location="cpu")
det_model.load_state_dict(det_params)
reco_model.load_state_dict(reco_params)
# Ask the preprocessor of each task to resize and normalize similarly to your training
# cf. https://github.com/mindee/doctr/blob/main/references/detection/train_pytorch.py#L94 & https://github.com/mindee/doctr/blob/main/references/detection/train_pytorch.py#L109
det_predictor = DetectionPredictor(PreProcessor((1024, 1024), batch_size=1, mean=(0.798, 0.785, 0.772), std=(0.264, 0.2749, 0.287)), det_model)
# cf. https://github.com/mindee/doctr/blob/main/references/recognition/train_pytorch.py#L97 & https://github.com/mindee/doctr/blob/main/references/recognition/train_pytorch.py#L111
reco_predictor = RecognitionPredictor(PreProcessor((32, 128), preserve_aspect_ratio=True, batch_size=32, mean=(0.694, 0.695, 0.693), std=(0.299, 0.296, 0.301)), reco_model)
predictor = OCRPredictor(det_predictor, reco_predictor)
Let me know if you still have questions!
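To make the `mean` and `std` arguments above more concrete, here is a minimal pure-Python sketch of per-channel normalization. It assumes that pixel values are first scaled from [0, 255] to [0, 1] and then standardized channel by channel, which is the usual convention. `normalize_pixel` is a hypothetical helper, not part of docTR, and the statistics are the detection values from the snippet above.

```python
from typing import Sequence, Tuple

# Detection statistics taken from the snippet above
DET_MEAN = (0.798, 0.785, 0.772)
DET_STD = (0.264, 0.2749, 0.287)


def normalize_pixel(
    rgb: Sequence[float], mean: Sequence[float], std: Sequence[float]
) -> Tuple[float, ...]:
    """Scale an RGB pixel from [0, 255] to [0, 1], then standardize each channel."""
    return tuple((value / 255.0 - m) / s for value, m, s in zip(rgb, mean, std))


# A white pixel sits above the dataset mean in every channel
white = normalize_pixel((255, 255, 255), DET_MEAN, DET_STD)

# A pixel equal to the dataset mean maps to (approximately) zero
mean_colored = normalize_pixel(tuple(255 * m for m in DET_MEAN), DET_MEAN, DET_STD)
```

If the statistics passed at inference differ from those used during training, the model sees inputs on a different scale than it was trained on, which is why the values should match your training script.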
-
Hi, how can I perform inference with trained recognition models using TensorFlow?
-
Hi @fg-mindee,
doc = DocumentFile.from_pdf("path/to/pdf_file.pdf")
# Detection model
det_model = db_resnet50(pretrained=False)
det_param = torch.load("./path/to/load_model.pt", map_location="cpu")
det_model.load_state_dict(det_param)
det_predictor = DetectionPredictor(PreProcessor((1024, 1024), batch_size=1, mean=(0.798, 0.785, 0.772), std=(0.264, 0.2749, 0.287)), det_model)
detection = det_predictor(doc)
# Recognition model
reco_model = crnn_vgg16_bn(pretrained=False)
reco_param = torch.load("./path/to/load_model.pt", map_location="cpu")
reco_model.load_state_dict(reco_param)
reco_predictor = RecognitionPredictor(PreProcessor((32, 128), preserve_aspect_ratio=True, batch_size=32, mean=(0.694, 0.695, 0.693), std=(0.299, 0.296, 0.301)), reco_model)
recognition = reco_predictor(detection)
and of course I got an error 😅 because I don't know what I am supposed to do after getting the output from the detection stage.
Here's the error:
Thank you.
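One way to picture the missing step between the two stages: a recognition model works on cropped word images, not on raw detection output, so each detected box first has to be turned into a crop of the page. The sketch below assumes boxes are given as relative coordinates `(xmin, ymin, xmax, ymax)` in [0, 1] and represents the page as a nested list of pixel rows. Both the box format and the helper names are assumptions for illustration, not docTR's API. The `OCRPredictor` shown in the earlier answer combines the detection and recognition stages.

```python
from typing import List, Sequence, Tuple

Image = List[List[int]]  # rows of pixel values (grayscale for brevity)


def to_pixel_box(box: Sequence[float], height: int, width: int) -> Tuple[int, int, int, int]:
    """Convert a relative (xmin, ymin, xmax, ymax) box to absolute pixel indices."""
    xmin, ymin, xmax, ymax = box[:4]
    return (
        int(round(xmin * width)),
        int(round(ymin * height)),
        int(round(xmax * width)),
        int(round(ymax * height)),
    )


def crop(page: Image, box: Sequence[float]) -> Image:
    """Cut the region covered by a relative box out of the page."""
    height, width = len(page), len(page[0])
    x0, y0, x1, y1 = to_pixel_box(box, height, width)
    return [row[x0:x1] for row in page[y0:y1]]


# 10x10 dummy page where each pixel value encodes its position (10 * row + column)
page = [[10 * r + c for c in range(10)] for r in range(10)]
# One hypothetical detected box in relative coordinates
boxes = [(0.2, 0.1, 0.6, 0.3)]
# These crops are what a recognition stage would take as input
crops = [crop(page, b) for b in boxes]
```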
-
Hello @fg-mindee,
Thanks,