Skip to content

Commit 1b4b687

Browse files
felixdittrich92fg-mindeecharlesmindee
authored
[onnx] classification models export (mindee#830)
* backup * onnx classification * fix: Fixed some ResNet architecture imprecisions (mindee#828) * feat: Added new resnets * feat: Added ResNet101 * fix: Fixed ResNet31 & ResNet34 wide * feat: Added new pretrained resnets * style: Fixed isort * fix: Fixed ResNet architectures * refactor: Refactored LinkNet * feat: Added more LinkNets * fix: Fixed MAGResNet * docs: Updated documentation * refactor: Removed ResNet101 * fix: Fixed warning * fix: Fixed a few bugs * test: Updated unittests * docs: Fixed docstrings * update with new models * feat: replace bce by focal loss in linknet loss (mindee#824) * feat: replace bce by focal loss in linknet loss * fix: requested changes * fix: mask reduction * fix: mask reduction * fix: loss reduction * fix: final adjustements * fix: final changes * Revert "feat: replace bce by focal loss in linknet loss (mindee#824)" This reverts commit 6511183. * Revert "fix: Fixed some ResNet architecture imprecisions (mindee#828)" This reverts commit 72e5e0d. * happy codacy * sapply suggestions * fix-setup * remove onnx from test req * move onnx deps ftm to torch * up * up * revert requirements * fix * update docstring * up Co-authored-by: F-G Fernandez <[email protected]> Co-authored-by: Charles Gaillard <[email protected]>
1 parent 243c354 commit 1b4b687

File tree

5 files changed

+83
-1
lines changed

5 files changed

+83
-1
lines changed

doctr/models/utils/pytorch.py

+31-1
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111

1212
from doctr.utils.data import download_from_url
1313

14-
__all__ = ['load_pretrained_params', 'conv_sequence_pt']
14+
__all__ = ['load_pretrained_params', 'conv_sequence_pt', 'export_classification_model_to_onnx']
1515

1616

1717
def load_pretrained_params(
@@ -87,3 +87,33 @@ def conv_sequence_pt(
8787
conv_seq.append(nn.ReLU(inplace=True))
8888

8989
return conv_seq
90+
91+
92+
def export_classification_model_to_onnx(model: nn.Module, exp_name: str, dummy_input: torch.Tensor) -> str:
93+
"""Export classification model to ONNX format.
94+
95+
>>> import torch
96+
>>> from doctr.models.classification import resnet18
97+
>>> from doctr.models.utils import export_classification_model_to_onnx
98+
>>> model = resnet18(pretrained=True)
99+
>>> export_classification_model_to_onnx(model, "my_model", dummy_input=torch.randn(1, 3, 32, 32))
100+
101+
Args:
102+
model: the PyTorch model to be exported
103+
exp_name: the name for the exported model
104+
dummy_input: the dummy input to the model
105+
106+
Returns:
107+
the path to the exported model
108+
"""
109+
torch.onnx.export(
110+
model,
111+
dummy_input,
112+
f"{exp_name}.onnx",
113+
input_names=['input'],
114+
output_names=['logits'],
115+
dynamic_axes={'input': {0: 'batch_size'}, 'logits': {0: 'batch_size'}},
116+
export_params=True, opset_version=13, verbose=False
117+
)
118+
logging.info(f"Model exported to {exp_name}.onnx")
119+
return f"{exp_name}.onnx"

references/classification/train_pytorch.py

+10
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
from doctr import transforms as T
2626
from doctr.datasets import VOCABS, CharacterGenerator
2727
from doctr.models import classification
28+
from doctr.models.utils import export_classification_model_to_onnx
2829
from utils import plot_recorder, plot_samples
2930

3031

@@ -334,6 +335,13 @@ def main(args):
334335
if args.wb:
335336
run.finish()
336337

338+
if args.export_onnx:
339+
print("Exporting model to ONNX...")
340+
dummy_batch = next(iter(val_loader))
341+
dummy_input = dummy_batch[0].cuda() if torch.cuda.is_available() else dummy_batch[0]
342+
model_path = export_classification_model_to_onnx(model, exp_name, dummy_input)
343+
print(f"Exported model saved in {model_path}")
344+
337345

338346
def parse_args():
339347
import argparse
@@ -378,6 +386,8 @@ def parse_args():
378386
help='Log to Weights & Biases')
379387
parser.add_argument('--pretrained', dest='pretrained', action='store_true',
380388
help='Load pretrained parameters before starting the training')
389+
parser.add_argument('--export-onnx', dest='export_onnx', action='store_true',
390+
help='Export the model to ONNX')
381391
parser.add_argument('--sched', type=str, default='cosine', help='scheduler to use')
382392
parser.add_argument("--amp", dest="amp", help="Use Automatic Mixed Precision", action="store_true")
383393
parser.add_argument('--find-lr', action='store_true', help='Gridsearch the optimal LR')

setup.py

+2
Original file line numberDiff line numberDiff line change
@@ -65,6 +65,7 @@
6565
"pytest>=5.3.2",
6666
"coverage>=4.5.4",
6767
"hdf5storage>=0.1.18",
68+
"onnxruntime>=1.11.0",
6869
"requests>=2.20.0",
6970
"requirements-parser==0.2.0",
7071
# Quality
@@ -137,6 +138,7 @@ def deps_list(*pkgs):
137138
"coverage",
138139
"requests",
139140
"hdf5storage",
141+
"onnxruntime",
140142
"requirements-parser",
141143
)
142144

tests/pytorch/test_models_classification_pt.py

+39
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,15 @@
1+
import os
2+
import tempfile
3+
14
import cv2
25
import numpy as np
6+
import onnxruntime
37
import pytest
48
import torch
59

610
from doctr.models import classification
711
from doctr.models.classification.predictor import CropOrientationPredictor
12+
from doctr.models.utils import export_classification_model_to_onnx
813

914

1015
def _test_classification(model, input_shape, output_size, batch_size=2):
@@ -98,3 +103,37 @@ def test_crop_orientation_model(mock_text_box):
98103
text_box_270 = np.rot90(text_box_0, 3)
99104
classifier = classification.crop_orientation_predictor("mobilenet_v3_small_orientation", pretrained=True)
100105
assert classifier([text_box_0, text_box_90, text_box_180, text_box_270]) == [0, 1, 2, 3]
106+
107+
108+
@pytest.mark.parametrize(
109+
"arch_name, input_shape, output_size",
110+
[
111+
["vgg16_bn_r", (3, 32, 32), (126,)],
112+
["resnet18", (3, 32, 32), (126,)],
113+
["resnet31", (3, 32, 32), (126,)],
114+
["resnet34", (3, 32, 32), (126,)],
115+
["resnet34_wide", (3, 32, 32), (126,)],
116+
["resnet50", (3, 32, 32), (126,)],
117+
["magc_resnet31", (3, 32, 32), (126,)],
118+
["mobilenet_v3_small", (3, 32, 32), (126,)],
119+
["mobilenet_v3_large", (3, 32, 32), (126,)],
120+
["mobilenet_v3_small_orientation", (3, 128, 128), (4,)],
121+
],
122+
)
123+
def test_models_onnx_export(arch_name, input_shape, output_size):
124+
# Model
125+
batch_size = 2
126+
model = classification.__dict__[arch_name](pretrained=True).eval()
127+
dummy_input = torch.rand((batch_size, *input_shape), dtype=torch.float32)
128+
with tempfile.TemporaryDirectory() as tmpdir:
129+
# Export
130+
model_path = export_classification_model_to_onnx(model,
131+
exp_name=os.path.join(tmpdir, "model"),
132+
dummy_input=dummy_input)
133+
assert os.path.exists(model_path)
134+
# Inference
135+
ort_session = onnxruntime.InferenceSession(os.path.join(tmpdir, "model.onnx"),
136+
providers=["CPUExecutionProvider"])
137+
ort_outs = ort_session.run(['logits'], {'input': dummy_input.numpy()})
138+
assert isinstance(ort_outs, list) and len(ort_outs) == 1
139+
assert ort_outs[0].shape == (batch_size, *output_size)

tests/requirements.txt

+1
Original file line numberDiff line numberDiff line change
@@ -2,4 +2,5 @@ pytest>=5.3.2
22
requests>=2.20.0
33
hdf5storage>=0.1.18
44
coverage>=4.5.4
5+
onnxruntime>=1.11.0
56
requirements-parser==0.2.0

0 commit comments

Comments
 (0)