We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
1 parent c9806fa commit 9b31588Copy full SHA for 9b31588
doctr/models/obj_detection/factory/__init__.py
@@ -0,0 +1,4 @@
1
+from doctr.file_utils import is_torch_available
2
+
3
+if is_torch_available():
4
+ from .pytorch import *
doctr/models/obj_detection/factory/pytorch.py
@@ -0,0 +1,50 @@
+# Copyright (C) 2022, Mindee.
+# This program is licensed under the Apache License version 2.
+# See LICENSE or go to <https://www.apache.org/licenses/LICENSE-2.0.txt> for full license details.
5
6
+import json
7
+from typing import Any
8
9
+import torch
10
+from huggingface_hub import hf_hub_download
11
12
+from doctr.models import obj_detection
13
14
+__all__ = ['from_hub']
15
16
17
+def from_hub(repo_id: str, **kwargs: Any) -> torch.nn.Module:
18
+ """Instantiate & load a pretrained model from HF hub.
19
20
+ Example::
21
+ >>> from doctr.models.obj_detection import from_hub
22
+ >>> model = from_hub("mindee/fasterrcnn_mobilenet_v3_large_fpn").eval()
23
+ >>> input_tensor = torch.rand((1, 3, 1024, 1024), dtype=torch.float32)
24
+ >>> with torch.no_grad(): out = model(input_tensor)
25
26
+ Args:
27
+ repo_id: HuggingFace model hub repo
28
+ kwargs: kwargs of `hf_hub_download`
29
+ Returns:
30
+ Model loaded with the checkpoint
31
+ """
32
33
+ # Get the config
34
+ with open(hf_hub_download(repo_id, filename='config.json', **kwargs), 'rb') as f:
35
+ cfg = json.load(f)
36
37
+ model = obj_detection.__dict__[cfg['arch']](
38
+ pretrained=False,
39
+ image_mean=cfg['mean'],
40
+ image_std=cfg['std'],
41
+ max_size=cfg['input_shape'][-1],
42
+ num_classes=len(cfg['classes']),
43
+ )
44
45
+ # Load the checkpoint
46
+ state_dict = torch.load(hf_hub_download(repo_id, filename='pytorch_model.bin', **kwargs), map_location='cpu')
47
+ model.load_state_dict(state_dict)
48
+ model.cfg = cfg
49
50
+ return model
doctr/models/obj_detection/faster_rcnn/pytorch.py
@@ -17,9 +17,7 @@
'input_shape': (3, 1024, 1024),
'mean': (0.485, 0.456, 0.406),
'std': (0.229, 0.224, 0.225),
- 'anchor_sizes': [32, 64, 128, 256, 512],
- 'anchor_aspect_ratios': (0.5, 1., 2.),
- 'num_classes': 5,
+ 'classes': ["background", "qr_code", "bar_code", "logo", "photo"],
'url': 'https://github.com/mindee/doctr/releases/download/v0.4.1/fasterrcnn_mobilenet_v3_large_fpn-d5b2490d.pt',
},
}
@@ -31,11 +29,11 @@ def _fasterrcnn(arch: str, pretrained: bool, **kwargs: Any) -> FasterRCNN:
"image_mean": default_cfgs[arch]['mean'],
"image_std": default_cfgs[arch]['std'],
"box_detections_per_img": 150,
- "box_score_thresh": 0.15,
+ "box_score_thresh": 0.5,
"box_positive_fraction": 0.35,
"box_nms_thresh": 0.2,
"rpn_nms_thresh": 0.2,
- "num_classes": default_cfgs[arch]['num_classes'],
+ "num_classes": len(default_cfgs[arch]['classes']),
# Build the model
mypy.ini
@@ -75,3 +75,7 @@ ignore_missing_imports = True
75
[mypy-h5py.*]
76
77
ignore_missing_imports = True
78
79
+[mypy-huggingface_hub.*]
80
81
+ignore_missing_imports = True
requirements-pt.txt
@@ -14,3 +14,4 @@ torchvision>=0.9.0
Pillow>=8.3.2
tqdm>=4.30.0
rapidfuzz>=1.6.0
+huggingface-hub>=0.4.0
requirements.txt
@@ -15,3 +15,4 @@ tqdm>=4.30.0
tensorflow-addons>=0.13.0
keras<2.7.0
setup.py
@@ -60,6 +60,7 @@
60
"tensorflow-addons>=0.13.0",
61
"rapidfuzz>=1.6.0",
62
"keras<2.7.0",
63
+ "huggingface-hub>=0.4.0",
64
# Testing
65
"pytest>=5.3.2",
66
"coverage>=4.5.4",
@@ -104,6 +105,7 @@ def deps_list(*pkgs):
104
105
deps["Pillow"],
106
deps["tqdm"],
107
deps["rapidfuzz"],
108
+ deps["huggingface-hub"],
109
]
110
111
extras = {}
tests/common/test_headers.py
@@ -7,18 +7,24 @@ def test_headers():
shebang = ["#!usr/bin/python\n"]
blank_line = "\n"
- _copyright_str = f"-{datetime.now().year}" if datetime.now().year > 2021 else ""
- copyright_notice = [f"# Copyright (C) 2021{_copyright_str}, Mindee.\n"]
+ starting_year = 2021
+ current_year = datetime.now().year
+ year_str = [current_year] + [f"{starting_year}-{current_year}" for year in range(starting_year, current_year)]
+ if starting_year == current_year:
+ year_str = year_str[:1]
+ copyright_notices = [[f"# Copyright (C) {_str}, Mindee.\n"] for _str in year_str]
license_notice = [
"# This program is licensed under the Apache License version 2.\n",
"# See LICENSE or go to <https://www.apache.org/licenses/LICENSE-2.0.txt> for full license details.\n"
# Define all header options
- headers = [
+ headers = [[
shebang + [blank_line] + copyright_notice + [blank_line] + license_notice,
copyright_notice + [blank_line] + license_notice
- ]
+ ] for copyright_notice in copyright_notices]
+ headers = [_header for year_header in headers for _header in year_header]
excluded_files = ["version.py", "__init__.py"]
invalid_files = []
tests/pytorch/test_models_obj_detection_pt.py
@@ -1,7 +1,9 @@
import pytest
import torch
+from torchvision.models.detection import FasterRCNN
from doctr.models import obj_detection
+from doctr.models.obj_detection.factory import from_hub
@pytest.mark.parametrize(
@@ -32,3 +34,8 @@ def test_detection_models(arch_name, input_shape, pretrained):
target = [{k: v.cuda() for k, v in t.items()} for t in target]
out = model(input_tensor, target)
assert isinstance(out, dict) and all(isinstance(v, torch.Tensor) for v in out.values())
+def test_obj_det_from_hub():
+ model = from_hub("mindee/fasterrcnn_mobilenet_v3_large_fpn").eval()
+ assert isinstance(model, FasterRCNN)
tests/tensorflow/test_transforms_tf.py
@@ -451,7 +451,7 @@ def test_random_shadow(input_dtype, input_shape):
451
assert transformed.shape == input_shape
452
assert transformed.dtype == input_dtype
453
# The shadow will darken the picture
454
- assert tf.math.reduce_mean(input_t) > tf.math.reduce_mean(transformed)
+ assert tf.math.reduce_mean(input_t) >= tf.math.reduce_mean(transformed)
455
assert tf.math.reduce_all(transformed >= 0)
456
if input_dtype == tf.uint8:
457
assert tf.reduce_all(transformed <= 255)
0 commit comments