Skip to content
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.

Commit 9b31588

Browse files
authored Mar 8, 2022
feat: Added loading method for PyTorch artefact detection models from HF Hub (mindee#836)
* refactor: Refactored FasterRCNN
* feat: Added factory method from_hub
* chore: Updated requirements
* test: Added unittest
* chore: Updated mypy config
* test: Updated unittests
* test: Fixed unittest
* test: Fixed unittest
* feat: Added cfg to model
* test: Fixed unittest
1 parent c9806fa commit 9b31588

File tree

10 files changed

+83
-10
lines changed

10 files changed

+83
-10
lines changed
 
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
from doctr.file_utils import is_torch_available
2+
3+
if is_torch_available():
4+
from .pytorch import *
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,50 @@
1+
# Copyright (C) 2022, Mindee.
2+
3+
# This program is licensed under the Apache License version 2.
4+
# See LICENSE or go to <https://www.apache.org/licenses/LICENSE-2.0.txt> for full license details.
5+
6+
import json
7+
from typing import Any
8+
9+
import torch
10+
from huggingface_hub import hf_hub_download
11+
12+
from doctr.models import obj_detection
13+
14+
__all__ = ['from_hub']
15+
16+
17+
def from_hub(repo_id: str, **kwargs: Any) -> torch.nn.Module:
    """Instantiate & load a pretrained model from HF hub.

    Example::
        >>> import torch
        >>> from doctr.models.obj_detection import from_hub
        >>> model = from_hub("mindee/fasterrcnn_mobilenet_v3_large_fpn").eval()
        >>> input_tensor = torch.rand((1, 3, 1024, 1024), dtype=torch.float32)
        >>> with torch.no_grad(): out = model(input_tensor)

    Args:
        repo_id: HuggingFace model hub repo
        kwargs: kwargs of `hf_hub_download`

    Returns:
        Model loaded with the checkpoint. The parsed ``config.json`` is attached
        to the returned model as ``model.cfg``.

    Raises:
        KeyError: if the config is missing one of the keys read below
            (``arch``, ``mean``, ``std``, ``input_shape``, ``classes``) or if
            ``arch`` does not name a public callable of ``obj_detection``.
    """

    # Get the config
    with open(hf_hub_download(repo_id, filename='config.json', **kwargs), 'rb') as f:
        cfg = json.load(f)

    # The architecture name comes from a remote file: only accept a public,
    # callable attribute of `obj_detection` instead of blindly indexing the
    # module namespace (which also holds submodules, dunders and private helpers).
    arch = cfg['arch']
    builder = obj_detection.__dict__.get(arch) if isinstance(arch, str) else None
    if builder is None or arch.startswith('_') or not callable(builder):
        raise KeyError(f"unsupported object detection architecture: {arch!r}")

    model = builder(
        pretrained=False,
        image_mean=cfg['mean'],
        image_std=cfg['std'],
        # Last entry of the configured input shape is used as the size cap
        max_size=cfg['input_shape'][-1],
        num_classes=len(cfg['classes']),
    )

    # Load the checkpoint
    # NOTE(security): `torch.load` unpickles the downloaded file, which can execute
    # arbitrary code. Only call this function with repositories you trust.
    state_dict = torch.load(hf_hub_download(repo_id, filename='pytorch_model.bin', **kwargs), map_location='cpu')
    model.load_state_dict(state_dict)
    # Keep the hub configuration available on the model
    model.cfg = cfg

    return model

‎doctr/models/obj_detection/faster_rcnn/pytorch.py

+3-5
Original file line numberDiff line numberDiff line change
@@ -17,9 +17,7 @@
1717
'input_shape': (3, 1024, 1024),
1818
'mean': (0.485, 0.456, 0.406),
1919
'std': (0.229, 0.224, 0.225),
20-
'anchor_sizes': [32, 64, 128, 256, 512],
21-
'anchor_aspect_ratios': (0.5, 1., 2.),
22-
'num_classes': 5,
20+
'classes': ["background", "qr_code", "bar_code", "logo", "photo"],
2321
'url': 'https://github.com/mindee/doctr/releases/download/v0.4.1/fasterrcnn_mobilenet_v3_large_fpn-d5b2490d.pt',
2422
},
2523
}
@@ -31,11 +29,11 @@ def _fasterrcnn(arch: str, pretrained: bool, **kwargs: Any) -> FasterRCNN:
3129
"image_mean": default_cfgs[arch]['mean'],
3230
"image_std": default_cfgs[arch]['std'],
3331
"box_detections_per_img": 150,
34-
"box_score_thresh": 0.15,
32+
"box_score_thresh": 0.5,
3533
"box_positive_fraction": 0.35,
3634
"box_nms_thresh": 0.2,
3735
"rpn_nms_thresh": 0.2,
38-
"num_classes": default_cfgs[arch]['num_classes'],
36+
"num_classes": len(default_cfgs[arch]['classes']),
3937
}
4038

4139
# Build the model

‎mypy.ini

+4
Original file line numberDiff line numberDiff line change
@@ -75,3 +75,7 @@ ignore_missing_imports = True
7575
[mypy-h5py.*]
7676

7777
ignore_missing_imports = True
78+
79+
[mypy-huggingface_hub.*]
80+
81+
ignore_missing_imports = True

‎requirements-pt.txt

+1
Original file line numberDiff line numberDiff line change
@@ -14,3 +14,4 @@ torchvision>=0.9.0
1414
Pillow>=8.3.2
1515
tqdm>=4.30.0
1616
rapidfuzz>=1.6.0
17+
huggingface-hub>=0.4.0

‎requirements.txt

+1
Original file line numberDiff line numberDiff line change
@@ -15,3 +15,4 @@ tqdm>=4.30.0
1515
tensorflow-addons>=0.13.0
1616
rapidfuzz>=1.6.0
1717
keras<2.7.0
18+
huggingface-hub>=0.4.0

‎setup.py

+2
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,7 @@
6060
"tensorflow-addons>=0.13.0",
6161
"rapidfuzz>=1.6.0",
6262
"keras<2.7.0",
63+
"huggingface-hub>=0.4.0",
6364
# Testing
6465
"pytest>=5.3.2",
6566
"coverage>=4.5.4",
@@ -104,6 +105,7 @@ def deps_list(*pkgs):
104105
deps["Pillow"],
105106
deps["tqdm"],
106107
deps["rapidfuzz"],
108+
deps["huggingface-hub"],
107109
]
108110

109111
extras = {}

‎tests/common/test_headers.py

+10-4
Original file line numberDiff line numberDiff line change
@@ -7,18 +7,24 @@ def test_headers():
77
shebang = ["#!usr/bin/python\n"]
88
blank_line = "\n"
99

10-
_copyright_str = f"-{datetime.now().year}" if datetime.now().year > 2021 else ""
11-
copyright_notice = [f"# Copyright (C) 2021{_copyright_str}, Mindee.\n"]
10+
starting_year = 2021
11+
current_year = datetime.now().year
12+
year_str = [current_year] + [f"{starting_year}-{current_year}" for year in range(starting_year, current_year)]
13+
if starting_year == current_year:
14+
year_str = year_str[:1]
15+
16+
copyright_notices = [[f"# Copyright (C) {_str}, Mindee.\n"] for _str in year_str]
1217
license_notice = [
1318
"# This program is licensed under the Apache License version 2.\n",
1419
"# See LICENSE or go to <https://www.apache.org/licenses/LICENSE-2.0.txt> for full license details.\n"
1520
]
1621

1722
# Define all header options
18-
headers = [
23+
headers = [[
1924
shebang + [blank_line] + copyright_notice + [blank_line] + license_notice,
2025
copyright_notice + [blank_line] + license_notice
21-
]
26+
] for copyright_notice in copyright_notices]
27+
headers = [_header for year_header in headers for _header in year_header]
2228

2329
excluded_files = ["version.py", "__init__.py"]
2430
invalid_files = []

‎tests/pytorch/test_models_obj_detection_pt.py

+7
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,9 @@
11
import pytest
22
import torch
3+
from torchvision.models.detection import FasterRCNN
34

45
from doctr.models import obj_detection
6+
from doctr.models.obj_detection.factory import from_hub
57

68

79
@pytest.mark.parametrize(
@@ -32,3 +34,8 @@ def test_detection_models(arch_name, input_shape, pretrained):
3234
target = [{k: v.cuda() for k, v in t.items()} for t in target]
3335
out = model(input_tensor, target)
3436
assert isinstance(out, dict) and all(isinstance(v, torch.Tensor) for v in out.values())
37+
38+
39+
def test_obj_det_from_hub():
    """Check that `from_hub` returns a torchvision ``FasterRCNN`` for the reference repo."""
    hub_model = from_hub("mindee/fasterrcnn_mobilenet_v3_large_fpn")
    # Switch to inference mode, as the original chained call did
    hub_model = hub_model.eval()
    assert isinstance(hub_model, FasterRCNN)

‎tests/tensorflow/test_transforms_tf.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -451,7 +451,7 @@ def test_random_shadow(input_dtype, input_shape):
451451
assert transformed.shape == input_shape
452452
assert transformed.dtype == input_dtype
453453
# The shadow will darken the picture
454-
assert tf.math.reduce_mean(input_t) > tf.math.reduce_mean(transformed)
454+
assert tf.math.reduce_mean(input_t) >= tf.math.reduce_mean(transformed)
455455
assert tf.math.reduce_all(transformed >= 0)
456456
if input_dtype == tf.uint8:
457457
assert tf.reduce_all(transformed <= 255)

0 commit comments

Comments
 (0)
Please sign in to comment.