From 9045dcfc9c5c837b06fcda8e802f7cf1d95bd18c Mon Sep 17 00:00:00 2001 From: Felix Dittrich Date: Thu, 29 Aug 2024 04:47:33 +0200 Subject: [PATCH] [orientation] Enable usage of custom trained orientation models (#1708) --- .../using_doctr/custom_models_training.rst | 73 ++++++++++++++++++- doctr/datasets/vocabs.py | 4 +- .../classification/mobilenet/pytorch.py | 2 + doctr/models/classification/zoo.py | 26 ++++--- doctr/models/factory/hub.py | 4 +- .../pytorch/test_models_classification_pt.py | 18 +++++ tests/pytorch/test_models_zoo_pt.py | 38 ++++++++++ .../test_models_classification_tf.py | 18 +++++ tests/tensorflow/test_models_zoo_tf.py | 38 ++++++++++ 9 files changed, 207 insertions(+), 14 deletions(-) diff --git a/docs/source/using_doctr/custom_models_training.rst b/docs/source/using_doctr/custom_models_training.rst index 6214dae2dc..ecf88d8116 100644 --- a/docs/source/using_doctr/custom_models_training.rst +++ b/docs/source/using_doctr/custom_models_training.rst @@ -1,7 +1,7 @@ Train your own model ==================== -If the pretrained models don't meet your specific needs, you have the option to train your own model using the doctr library. +If the pretrained models don't meet your specific needs, you have the option to train your own model using the docTR library. For details on the training process and the necessary data and data format, refer to the following links: - `detection `_ @@ -203,3 +203,74 @@ Load a model with customized Preprocessor: ) predictor = OCRPredictor(det_predictor, reco_predictor) + +Custom orientation classification models +---------------------------------------- + +If you work with rotated documents and make use of the orientation classification feature by passing one of the following arguments: + +* `assume_straight_pages=False` +* `detect_orientation=True` +* `straigten_pages=True` + +You can train your own orientation classification model using the docTR library. For details on the training process and the necessary data and data format, refer to the following link: + +- `orientation `_ + +**NOTE**: Currently we support only `mobilenet_v3_small` models for crop and page orientation classification. + +Loading your custom trained orientation classification model +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +.. tabs:: + + .. tab:: TensorFlow + + .. code:: python3 + + from doctr.io import DocumentFile + from doctr.models import ocr_predictor, mobilenet_v3_small_page_orientation, mobilenet_v3_small_crop_orientation + from doctr.models.classification.zoo import crop_orientation_predictor, page_orientation_predictor + + custom_page_orientation_model = mobilenet_v3_small_page_orientation(pretrained=False) + custom_page_orientation_model.load_weights("/weights") + custom_crop_orientation_model = mobilenet_v3_small_crop_orientation(pretrained=False) + custom_crop_orientation_model.load_weights("/weights") + + predictor = ocr_predictor( + pretrained=True, + assume_straight_pages=False, + straighten_pages=True, + detect_orientation=True, + ) + + # Overwrite the default orientation models + predictor.crop_orientation_predictor = crop_orientation_predictor(custom_crop_orientation_model) + predictor.page_orientation_predictor = page_orientation_predictor(custom_page_orientation_model) + + .. tab:: PyTorch + + .. code:: python3 + + import torch + from doctr.io import DocumentFile + from doctr.models import ocr_predictor, mobilenet_v3_small_page_orientation, mobilenet_v3_small_crop_orientation + from doctr.models.classification.zoo import crop_orientation_predictor, page_orientation_predictor + + custom_page_orientation_model = mobilenet_v3_small_page_orientation(pretrained=False) + page_params = torch.load('', map_location="cpu") + custom_page_orientation_model.load_state_dict(page_params) + custom_crop_orientation_model = mobilenet_v3_small_crop_orientation(pretrained=False) + crop_params = torch.load('', map_location="cpu") + custom_crop_orientation_model.load_state_dict(crop_params) + + predictor = ocr_predictor( + pretrained=True, + assume_straight_pages=False, + straighten_pages=True, + detect_orientation=True, + ) + + # Overwrite the default orientation models + predictor.crop_orientation_predictor = crop_orientation_predictor(custom_crop_orientation_model) + predictor.page_orientation_predictor = page_orientation_predictor(custom_page_orientation_model) diff --git a/doctr/datasets/vocabs.py b/doctr/datasets/vocabs.py index 91c5af7950..94942d58e3 100644 --- a/doctr/datasets/vocabs.py +++ b/doctr/datasets/vocabs.py @@ -60,7 +60,9 @@ VOCABS["hebrew"] = VOCABS["english"] + "אבגדהוזחטיכלמנסעפצקרשת" + "₪" VOCABS["hindi"] = VOCABS["hindi_letters"] + VOCABS["hindi_digits"] + VOCABS["hindi_punctuation"] VOCABS["bangla"] = VOCABS["bangla_letters"] + VOCABS["bangla_digits"] -VOCABS["ukrainian"] = VOCABS["generic_cyrillic_letters"] + VOCABS["digits"] + VOCABS["punctuation"] + VOCABS["currency"] + "ґіїєҐІЇЄ₴" +VOCABS["ukrainian"] = ( + VOCABS["generic_cyrillic_letters"] + VOCABS["digits"] + VOCABS["punctuation"] + VOCABS["currency"] + "ґіїєҐІЇЄ₴" +) VOCABS["multilingual"] = "".join( dict.fromkeys( VOCABS["french"] diff --git a/doctr/models/classification/mobilenet/pytorch.py b/doctr/models/classification/mobilenet/pytorch.py index 615664854d..18470fdf11 100644 --- a/doctr/models/classification/mobilenet/pytorch.py +++ b/doctr/models/classification/mobilenet/pytorch.py @@ -9,12 +9,14 @@ from typing import Any, Dict, List, Optional from torchvision.models import mobilenetv3 +from torchvision.models.mobilenetv3 import MobileNetV3 from doctr.datasets import VOCABS from ...utils import load_pretrained_params __all__ = [ + "MobileNetV3", "mobilenet_v3_small", "mobilenet_v3_small_r", "mobilenet_v3_large", diff --git a/doctr/models/classification/zoo.py b/doctr/models/classification/zoo.py index 9368bb225d..fccd5b5979 100644 --- a/doctr/models/classification/zoo.py +++ b/doctr/models/classification/zoo.py @@ -34,15 +34,21 @@ ORIENTATION_ARCHS: List[str] = ["mobilenet_v3_small_crop_orientation", "mobilenet_v3_small_page_orientation"] -def _orientation_predictor(arch: str, pretrained: bool, **kwargs: Any) -> OrientationPredictor: - if arch not in ORIENTATION_ARCHS: - raise ValueError(f"unknown architecture '{arch}'") +def _orientation_predictor(arch: Any, pretrained: bool, model_type: str, **kwargs: Any) -> OrientationPredictor: + if isinstance(arch, str): + if arch not in ORIENTATION_ARCHS: + raise ValueError(f"unknown architecture '{arch}'") + + # Load directly classifier from backbone + _model = classification.__dict__[arch](pretrained=pretrained) + else: + if not isinstance(arch, classification.MobileNetV3): + raise ValueError(f"unknown architecture: {type(arch)}") + _model = arch - # Load directly classifier from backbone - _model = classification.__dict__[arch](pretrained=pretrained) kwargs["mean"] = kwargs.get("mean", _model.cfg["mean"]) kwargs["std"] = kwargs.get("std", _model.cfg["std"]) - kwargs["batch_size"] = kwargs.get("batch_size", 128 if "crop" in arch else 4) + kwargs["batch_size"] = kwargs.get("batch_size", 128 if model_type == "crop" else 4) input_shape = _model.cfg["input_shape"][:-1] if is_tf_available() else _model.cfg["input_shape"][1:] predictor = OrientationPredictor( PreProcessor(input_shape, preserve_aspect_ratio=True, symmetric_pad=True, **kwargs), _model @@ -51,7 +57,7 @@ def _orientation_predictor(arch: str, pretrained: bool, **kwargs: Any) -> Orient def crop_orientation_predictor( - arch: str = "mobilenet_v3_small_crop_orientation", pretrained: bool = False, **kwargs: Any + arch: Any = "mobilenet_v3_small_crop_orientation", pretrained: bool = False, **kwargs: Any ) -> OrientationPredictor: """Crop orientation classification architecture. @@ -71,11 +77,11 @@ def crop_orientation_predictor( ------- OrientationPredictor """ - return _orientation_predictor(arch, pretrained, **kwargs) + return _orientation_predictor(arch, pretrained, model_type="crop", **kwargs) def page_orientation_predictor( - arch: str = "mobilenet_v3_small_page_orientation", pretrained: bool = False, **kwargs: Any + arch: Any = "mobilenet_v3_small_page_orientation", pretrained: bool = False, **kwargs: Any ) -> OrientationPredictor: """Page orientation classification architecture. @@ -95,4 +101,4 @@ def page_orientation_predictor( ------- OrientationPredictor """ - return _orientation_predictor(arch, pretrained, **kwargs) + return _orientation_predictor(arch, pretrained, model_type="page", **kwargs) diff --git a/doctr/models/factory/hub.py b/doctr/models/factory/hub.py index a6c3f89322..41cd91579a 100644 --- a/doctr/models/factory/hub.py +++ b/doctr/models/factory/hub.py @@ -33,7 +33,7 @@ AVAILABLE_ARCHS = { - "classification": models.classification.zoo.ARCHS, + "classification": models.classification.zoo.ARCHS + models.classification.zoo.ORIENTATION_ARCHS, "detection": models.detection.zoo.ARCHS, "recognition": models.recognition.zoo.ARCHS, } @@ -174,7 +174,7 @@ def push_to_hf_hub(model: Any, model_name: str, task: str, **kwargs) -> None: # local_cache_dir = os.path.join(os.path.expanduser("~"), ".cache", "huggingface", "hub", model_name) repo_url = HfApi().create_repo(model_name, token=get_token(), exist_ok=False) - repo = Repository(local_dir=local_cache_dir, clone_from=repo_url, use_auth_token=True) + repo = Repository(local_dir=local_cache_dir, clone_from=repo_url) with repo.commit(commit_message): _save_model_and_config_for_hf_hub(model, repo.local_dir, arch=arch, task=task) diff --git a/tests/pytorch/test_models_classification_pt.py b/tests/pytorch/test_models_classification_pt.py index d2dbe5087a..f35a1ac9de 100644 --- a/tests/pytorch/test_models_classification_pt.py +++ b/tests/pytorch/test_models_classification_pt.py @@ -134,6 +134,15 @@ def test_crop_orientation_model(mock_text_box): assert classifier([text_box_0, text_box_270, text_box_180, text_box_90])[1] == [0, -90, 180, 90] assert all(isinstance(pred, float) for pred in classifier([text_box_0, text_box_270, text_box_180, text_box_90])[2]) + # Test custom model loading + classifier = classification.crop_orientation_predictor( + classification.mobilenet_v3_small_crop_orientation(pretrained=True) + ) + assert isinstance(classifier, OrientationPredictor) + + with pytest.raises(ValueError): + _ = classification.crop_orientation_predictor(classification.textnet_tiny(pretrained=True)) + def test_page_orientation_model(mock_payslip): text_box_0 = cv2.imread(mock_payslip) @@ -147,6 +156,15 @@ def test_page_orientation_model(mock_payslip): assert classifier([text_box_0, text_box_270, text_box_180, text_box_90])[1] == [0, -90, 180, 90] assert all(isinstance(pred, float) for pred in classifier([text_box_0, text_box_270, text_box_180, text_box_90])[2]) + # Test custom model loading + classifier = classification.page_orientation_predictor( + classification.mobilenet_v3_small_page_orientation(pretrained=True) + ) + assert isinstance(classifier, OrientationPredictor) + + with pytest.raises(ValueError): + _ = classification.page_orientation_predictor(classification.textnet_tiny(pretrained=True)) + @pytest.mark.parametrize( "arch_name, input_shape, output_size", diff --git a/tests/pytorch/test_models_zoo_pt.py b/tests/pytorch/test_models_zoo_pt.py index 0cac9724ee..9be66edd7b 100644 --- a/tests/pytorch/test_models_zoo_pt.py +++ b/tests/pytorch/test_models_zoo_pt.py @@ -7,6 +7,8 @@ from doctr.io import Document, DocumentFile from doctr.io.elements import KIEDocument from doctr.models import detection, recognition +from doctr.models.classification import mobilenet_v3_small_crop_orientation, mobilenet_v3_small_page_orientation +from doctr.models.classification.zoo import crop_orientation_predictor, page_orientation_predictor from doctr.models.detection.predictor import DetectionPredictor from doctr.models.detection.zoo import detection_predictor from doctr.models.kie_predictor import KIEPredictor @@ -85,6 +87,24 @@ def test_ocrpredictor(mock_pdf, mock_vocab, assume_straight_pages, straighten_pa orientation = 0 assert out.pages[0].orientation["value"] == orientation + # Test with custom orientation models + custom_crop_orientation_model = mobilenet_v3_small_crop_orientation(pretrained=True) + custom_page_orientation_model = mobilenet_v3_small_page_orientation(pretrained=True) + + if assume_straight_pages: + if predictor.detect_orientation or predictor.straighten_pages: + # Overwrite the default orientation models + predictor.crop_orientation_predictor = crop_orientation_predictor(custom_crop_orientation_model) + predictor.page_orientation_predictor = page_orientation_predictor(custom_page_orientation_model) + else: + # Overwrite the default orientation models + predictor.crop_orientation_predictor = crop_orientation_predictor(custom_crop_orientation_model) + predictor.page_orientation_predictor = page_orientation_predictor(custom_page_orientation_model) + + out = predictor(doc) + orientation = 0 + assert out.pages[0].orientation["value"] == orientation + def test_trained_ocr_predictor(mock_payslip): doc = DocumentFile.from_images(mock_payslip) @@ -209,6 +229,24 @@ def test_kiepredictor(mock_pdf, mock_vocab, assume_straight_pages, straighten_pa orientation = 0 assert out.pages[0].orientation["value"] == orientation + # Test with custom orientation models + custom_crop_orientation_model = mobilenet_v3_small_crop_orientation(pretrained=True) + custom_page_orientation_model = mobilenet_v3_small_page_orientation(pretrained=True) + + if assume_straight_pages: + if predictor.detect_orientation or predictor.straighten_pages: + # Overwrite the default orientation models + predictor.crop_orientation_predictor = crop_orientation_predictor(custom_crop_orientation_model) + predictor.page_orientation_predictor = page_orientation_predictor(custom_page_orientation_model) + else: + # Overwrite the default orientation models + predictor.crop_orientation_predictor = crop_orientation_predictor(custom_crop_orientation_model) + predictor.page_orientation_predictor = page_orientation_predictor(custom_page_orientation_model) + + out = predictor(doc) + orientation = 0 + assert out.pages[0].orientation["value"] == orientation + def test_trained_kie_predictor(mock_payslip): doc = DocumentFile.from_images(mock_payslip) diff --git a/tests/tensorflow/test_models_classification_tf.py b/tests/tensorflow/test_models_classification_tf.py index 8b2c720328..77eb8253ca 100644 --- a/tests/tensorflow/test_models_classification_tf.py +++ b/tests/tensorflow/test_models_classification_tf.py @@ -113,6 +113,15 @@ def test_crop_orientation_model(mock_text_box): assert classifier([text_box_0, text_box_270, text_box_180, text_box_90])[1] == [0, -90, 180, 90] assert all(isinstance(pred, float) for pred in classifier([text_box_0, text_box_270, text_box_180, text_box_90])[2]) + # Test custom model loading + classifier = classification.crop_orientation_predictor( + classification.mobilenet_v3_small_crop_orientation(pretrained=True) + ) + assert isinstance(classifier, OrientationPredictor) + + with pytest.raises(ValueError): + _ = classification.crop_orientation_predictor(classification.textnet_tiny(pretrained=True)) + def test_page_orientation_model(mock_payslip): text_box_0 = cv2.imread(mock_payslip) @@ -126,6 +135,15 @@ def test_page_orientation_model(mock_payslip): assert classifier([text_box_0, text_box_270, text_box_180, text_box_90])[1] == [0, -90, 180, 90] assert all(isinstance(pred, float) for pred in classifier([text_box_0, text_box_270, text_box_180, text_box_90])[2]) + # Test custom model loading + classifier = classification.page_orientation_predictor( + classification.mobilenet_v3_small_page_orientation(pretrained=True) + ) + assert isinstance(classifier, OrientationPredictor) + + with pytest.raises(ValueError): + _ = classification.page_orientation_predictor(classification.textnet_tiny(pretrained=True)) + # temporarily fix to avoid killing the CI (tf2onnx v1.14 memory leak issue) # ref.: https://github.com/mindee/doctr/pull/1201 diff --git a/tests/tensorflow/test_models_zoo_tf.py b/tests/tensorflow/test_models_zoo_tf.py index f20cb21f5c..4b7e606563 100644 --- a/tests/tensorflow/test_models_zoo_tf.py +++ b/tests/tensorflow/test_models_zoo_tf.py @@ -6,6 +6,8 @@ from doctr.io import Document, DocumentFile from doctr.io.elements import KIEDocument from doctr.models import detection, recognition +from doctr.models.classification import mobilenet_v3_small_crop_orientation, mobilenet_v3_small_page_orientation +from doctr.models.classification.zoo import crop_orientation_predictor, page_orientation_predictor from doctr.models.detection.predictor import DetectionPredictor from doctr.models.detection.zoo import detection_predictor from doctr.models.kie_predictor import KIEPredictor @@ -84,6 +86,24 @@ def test_ocrpredictor(mock_pdf, mock_vocab, assume_straight_pages, straighten_pa language = "unknown" assert out.pages[0].language["value"] == language + # Test with custom orientation models + custom_crop_orientation_model = mobilenet_v3_small_crop_orientation(pretrained=True) + custom_page_orientation_model = mobilenet_v3_small_page_orientation(pretrained=True) + + if assume_straight_pages: + if predictor.detect_orientation or predictor.straighten_pages: + # Overwrite the default orientation models + predictor.crop_orientation_predictor = crop_orientation_predictor(custom_crop_orientation_model) + predictor.page_orientation_predictor = page_orientation_predictor(custom_page_orientation_model) + else: + # Overwrite the default orientation models + predictor.crop_orientation_predictor = crop_orientation_predictor(custom_crop_orientation_model) + predictor.page_orientation_predictor = page_orientation_predictor(custom_page_orientation_model) + + out = predictor(doc) + orientation = 0 + assert out.pages[0].orientation["value"] == orientation + def test_trained_ocr_predictor(mock_payslip): doc = DocumentFile.from_images(mock_payslip) @@ -207,6 +227,24 @@ def test_kiepredictor(mock_pdf, mock_vocab, assume_straight_pages, straighten_pa language = "unknown" assert out.pages[0].language["value"] == language + # Test with custom orientation models + custom_crop_orientation_model = mobilenet_v3_small_crop_orientation(pretrained=True) + custom_page_orientation_model = mobilenet_v3_small_page_orientation(pretrained=True) + + if assume_straight_pages: + if predictor.detect_orientation or predictor.straighten_pages: + # Overwrite the default orientation models + predictor.crop_orientation_predictor = crop_orientation_predictor(custom_crop_orientation_model) + predictor.page_orientation_predictor = page_orientation_predictor(custom_page_orientation_model) + else: + # Overwrite the default orientation models + predictor.crop_orientation_predictor = crop_orientation_predictor(custom_crop_orientation_model) + predictor.page_orientation_predictor = page_orientation_predictor(custom_page_orientation_model) + + out = predictor(doc) + orientation = 0 + assert out.pages[0].orientation["value"] == orientation + def test_trained_kie_predictor(mock_payslip): doc = DocumentFile.from_images(mock_payslip)