This repository was archived by the owner on Dec 2, 2022. It is now read-only.

Commit 3a31dc8

add more generic image reading (#28)
* add more generic image reading
* update readme
1 parent 02d500d commit 3a31dc8

7 files changed, +114 -55 lines changed


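To make the change concrete, here is a minimal usage sketch (the sample file path and the constructor options are illustrative, following the Craft class docstring and the tests touched below): detect_text now accepts a file path, raw encoded bytes, or a numpy array through its image argument instead of only an image_path string.

import cv2

from craft_text_detector import Craft

# constructor options shown are illustrative; only output_dir matters for the export step
craft = Craft(output_dir="output/", crop_type="poly", cuda=False)

# 1) file path (previous behaviour, now passed via the image argument)
result = craft.detect_text(image="figures/idcard.png")

# 2) numpy array, e.g. a frame already decoded with OpenCV
frame = cv2.imread("figures/idcard.png")
result = craft.detect_text(image=frame)

# 3) raw encoded bytes, decoded internally by the new read_image helper
with open("figures/idcard.png", "rb") as f:
    result = craft.detect_text(image=f.read())

print(len(result["boxes"]), result["times"])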
README.md

-2

@@ -89,7 +89,6 @@ prediction_result = get_prediction(
 
 # export detected text regions
 exported_file_paths = export_detected_regions(
-    image_path=image_path,
     image=image,
     regions=prediction_result["boxes"],
     output_dir=output_dir,
@@ -98,7 +97,6 @@ exported_file_paths = export_detected_regions(
 
 # export heatmap, detection points, box visualization
 export_extra_results(
-    image_path=image_path,
     image=image,
     regions=prediction_result["boxes"],
     heatmaps=prediction_result["heatmaps"],

craft_text_detector/__init__.py

+25 -15

@@ -1,12 +1,14 @@
 from __future__ import absolute_import
 
+import os
+
 import craft_text_detector.craft_utils as craft_utils
 import craft_text_detector.file_utils as file_utils
 import craft_text_detector.image_utils as image_utils
 import craft_text_detector.predict as predict
 import craft_text_detector.torch_utils as torch_utils
 
-__version__ = "0.3.3"
+__version__ = "0.3.4"
 
 
 __all__ = [
@@ -45,7 +47,6 @@ def __init__(
     ):
         """
        Arguments:
-            image_path: path to the image to be processed
             output_dir: path to the results to be exported
             rectify: rectify detected polygon by affine transform
             export_extra: export heatmap, detection points, box visualization
@@ -102,21 +103,26 @@ def unload_refinenet_model(self):
         self.refine_net = None
         empty_cuda_cache()
 
-    def detect_text(self, image_path):
+    def detect_text(self, image, image_path=None):
         """
         Arguments:
-            image_path: path to the image to be processed
+            image: path to the image to be processed or numpy array or PIL image
+
         Output:
-            {"masks": lists of predicted masks 2d as bool array,
-            "boxes": list of coords of points of predicted boxes,
-            "boxes_as_ratios": list of coords of points of predicted boxes as ratios of image size,
-            "polys_as_ratios": list of coords of points of predicted polys as ratios of image size,
-            "heatmaps": visualization of the detected characters/links,
-            "text_crop_paths": list of paths of the exported text boxes/polys,
-            "times": elapsed times of the sub modules, in seconds}
+            {
+                "masks": lists of predicted masks 2d as bool array,
+                "boxes": list of coords of points of predicted boxes,
+                "boxes_as_ratios": list of coords of points of predicted boxes as ratios of image size,
+                "polys_as_ratios": list of coords of points of predicted polys as ratios of image size,
+                "heatmaps": visualization of the detected characters/links,
+                "text_crop_paths": list of paths of the exported text boxes/polys,
+                "times": elapsed times of the sub modules, in seconds
+            }
         """
-        # load image
-        image = read_image(image_path)
+
+        if image_path is not None:
+            print("Argument 'image_path' is deprecated, use 'image' instead.")
+            image = image_path
 
         # perform prediction
         prediction_result = get_prediction(
@@ -142,10 +148,14 @@ def detect_text(self, image_path):
         prediction_result["text_crop_paths"] = []
         if self.output_dir is not None:
             # export detected text regions
+            if type(image) == str:
+                file_name, file_ext = os.path.splitext(os.path.basename(image))
+            else:
+                file_name = "image"
             exported_file_paths = export_detected_regions(
-                image_path=image_path,
                 image=image,
                 regions=regions,
+                file_name=file_name,
                 output_dir=self.output_dir,
                 rectify=self.rectify,
             )
@@ -154,10 +164,10 @@ def detect_text(self, image_path):
             # export heatmap, detection points, box visualization
             if self.export_extra:
                 export_extra_results(
-                    image_path=image_path,
                     image=image,
                     regions=regions,
                     heatmaps=prediction_result["heatmaps"],
+                    file_name=file_name,
                     output_dir=self.output_dir,
                 )

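A brief, hypothetical sketch of the file naming behaviour introduced here (paths and constructor options are illustrative): when image is a path string, exported crops go into a folder derived from the file name; for array or bytes input the generic stem "image" is used.

import cv2

from craft_text_detector import Craft

craft = Craft(output_dir="output/", export_extra=False, crop_type="box", cuda=False)

# path input -> crops written under output/idcard_crops/
craft.detect_text(image="figures/idcard.png")

# array input -> crops written under output/image_crops/
craft.detect_text(image=cv2.imread("figures/idcard.png"))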
craft_text_detector/file_utils.py

+23 -18

@@ -5,6 +5,8 @@
 import gdown
 import numpy as np
 
+from craft_text_detector.image_utils import read_image
+
 
 def download(url: str, save_path: str):
     """
@@ -158,22 +160,27 @@ def export_detected_region(image, poly, file_path, rectify=True):
 
 
 def export_detected_regions(
-    image_path, image, regions, output_dir: str = "output/", rectify: bool = False
+    image,
+    regions,
+    file_name: str = "image",
+    output_dir: str = "output/",
+    rectify: bool = False,
 ):
     """
     Arguments:
-        image_path: path to original image
-        image: full/original image
+        image: path to the image to be processed or numpy array or PIL image
         regions: list of bboxes or polys
+        file_name (str): export image file name
         output_dir: folder to be exported
         rectify: rectify detected polygon by affine transform
     """
+
+    # read/convert image
+    image = read_image(image)
+
     # deepcopy image so that original is not altered
     image = copy.deepcopy(image)
 
-    # get file name
-    file_name, file_ext = os.path.splitext(os.path.basename(image_path))
-
     # create crops dir
     crops_dir = os.path.join(output_dir, file_name + "_crops")
     create_dir(crops_dir)
@@ -194,34 +201,32 @@ def export_detected_regions(
 
 
 def export_extra_results(
-    image_path,
     image,
     regions,
     heatmaps,
+    file_name: str = "image",
     output_dir="output/",
     verticals=None,
     texts=None,
 ):
-    """ save text detection result one by one
+    """save text detection result one by one
     Args:
-        image_path (str): image file name
-        image (array): raw image context
+        image: path to the image to be processed or numpy array or PIL image
+        file_name (str): export image file name
         boxes (array): array of result file
             Shape: [num_detections, 4] for BB output / [num_detections, 4]
            for QUAD output
     Return:
        None
    """
-    image = np.array(image)
-
-    # make result file list
-    filename, file_ext = os.path.splitext(os.path.basename(image_path))
+    # read/convert image
+    image = read_image(image)
 
     # result directory
-    res_file = os.path.join(output_dir, filename + "_text_detection.txt")
-    res_img_file = os.path.join(output_dir, filename + "_text_detection.png")
-    text_heatmap_file = os.path.join(output_dir, filename + "_text_score_heatmap.png")
-    link_heatmap_file = os.path.join(output_dir, filename + "_link_score_heatmap.png")
+    res_file = os.path.join(output_dir, file_name + "_text_detection.txt")
+    res_img_file = os.path.join(output_dir, file_name + "_text_detection.png")
+    text_heatmap_file = os.path.join(output_dir, file_name + "_text_score_heatmap.png")
+    link_heatmap_file = os.path.join(output_dir, file_name + "_link_score_heatmap.png")
 
     # create output dir
     create_dir(output_dir)

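As a rough end-to-end sketch of the updated helpers (threshold values, file names and paths are illustrative; get_prediction is called as in tests/test_helpers.py), the new file_name argument replaces the naming that was previously derived from image_path:

from craft_text_detector import (
    export_detected_regions,
    export_extra_results,
    get_prediction,
    load_craftnet_model,
)

# run detection on a sample image (a path, bytes or a numpy array all work now)
craft_net = load_craftnet_model(cuda=False)
prediction_result = get_prediction(
    image="figures/idcard.png",
    craft_net=craft_net,
    refine_net=None,
    text_threshold=0.7,
    link_threshold=0.4,
    low_text=0.4,
    cuda=False,
    long_size=720,
)

# crops are written under <output_dir>/<file_name>_crops/
export_detected_regions(
    image="figures/idcard.png",
    regions=prediction_result["boxes"],
    file_name="idcard",
    output_dir="output/",
    rectify=True,
)

# heatmaps and box visualizations reuse the same file_name stem
export_extra_results(
    image="figures/idcard.png",
    regions=prediction_result["boxes"],
    heatmaps=prediction_result["heatmaps"],
    file_name="idcard",
    output_dir="output/",
)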
craft_text_detector/image_utils.py

+18 -8

@@ -7,14 +7,24 @@
 import numpy as np
 
 
-def read_image(img_file):
-    img = cv2.imread(img_file)
-    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
-    # following two cases are not explained in the original repo
-    if img.shape[0] == 2:
-        img = img[0]
-    if img.shape[2] == 4:
-        img = img[:, :, :3]
+def read_image(image):
+    if type(image) == str:
+        img = cv2.imread(image)
+        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
+
+    elif type(image) == bytes:
+        nparr = np.frombuffer(image, np.uint8)
+        img = cv2.imdecode(nparr, cv2.IMREAD_COLOR)
+        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
+
+    elif type(image) == np.ndarray:
+        if len(image.shape) == 2:  # grayscale
+            img = cv2.cvtColor(image, cv2.COLOR_GRAY2BGR)
+        elif len(image.shape) == 3 and image.shape[2] == 3:  # BGRscale
+            img = image
+        elif len(image.shape) == 3 and image.shape[2] == 4:  # RGBAscale
+            img = image[:, :, :3]
+            img = cv2.cvtColor(img, cv2.COLOR_RGB2BGR)
 
     return img

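A quick sketch of the input types the rewritten read_image handles (the sample path is a placeholder). Note that the docstrings elsewhere in this commit also mention PIL images, while the branches above cover str, bytes and numpy arrays:

import cv2
import numpy as np

from craft_text_detector import read_image

# file path on disk
img_from_path = read_image("figures/idcard.png")

# raw encoded bytes, e.g. received over the network
with open("figures/idcard.png", "rb") as f:
    img_from_bytes = read_image(f.read())

# already-decoded numpy array (grayscale, 3-channel and 4-channel are handled)
img_from_array = read_image(cv2.imread("figures/idcard.png"))

assert isinstance(img_from_path, np.ndarray)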
craft_text_detector/predict.py

+4 -1

@@ -22,7 +22,7 @@ def get_prediction(
 ):
     """
     Arguments:
-        image: image to be processed
+        image: path to the image to be processed or numpy array or PIL image
         output_dir: path to the results to be exported
         craft_net: craft net model
         refine_net: refine net model
@@ -43,6 +43,9 @@
     """
     t0 = time.time()
 
+    # read/convert image
+    image = image_utils.read_image(image)
+
     # resize
     img_resized, target_ratio, size_heatmap = image_utils.resize_aspect_ratio(
         image, long_size, interpolation=cv2.INTER_LINEAR

tests/test_craft.py

+5 -4

@@ -1,4 +1,5 @@
 import unittest
+
 from craft_text_detector import Craft
 
 
@@ -75,7 +76,7 @@ def test_detect_text(self):
             crop_type="poly",
         )
         # detect text
-        prediction_result = craft.detect_text(image_path=self.image_path)
+        prediction_result = craft.detect_text(image=self.image_path)
 
         self.assertEqual(len(prediction_result["boxes"]), 52)
         self.assertEqual(len(prediction_result["boxes"][0]), 4)
@@ -96,7 +97,7 @@ def test_detect_text(self):
             crop_type="poly",
         )
         # detect text
-        prediction_result = craft.detect_text(image_path=self.image_path)
+        prediction_result = craft.detect_text(image=self.image_path)
 
         self.assertEqual(len(prediction_result["boxes"]), 19)
         self.assertEqual(len(prediction_result["boxes"][0]), 4)
@@ -117,7 +118,7 @@ def test_detect_text(self):
             crop_type="box",
         )
         # detect text
-        prediction_result = craft.detect_text(image_path=self.image_path)
+        prediction_result = craft.detect_text(image=self.image_path)
 
         self.assertEqual(len(prediction_result["boxes"]), 52)
         self.assertEqual(len(prediction_result["boxes"][0]), 4)
@@ -138,7 +139,7 @@ def test_detect_text(self):
             crop_type="box",
         )
         # detect text
-        prediction_result = craft.detect_text(image_path=self.image_path)
+        prediction_result = craft.detect_text(image=self.image_path)
 
         self.assertEqual(len(prediction_result["boxes"]), 19)
         self.assertEqual(len(prediction_result["boxes"][0]), 4)

tests/test_helpers.py

+39 -7

@@ -3,12 +3,12 @@
 from tempfile import TemporaryDirectory
 
 from craft_text_detector import (
-    read_image,
-    load_craftnet_model,
-    load_refinenet_model,
-    get_prediction,
     export_detected_regions,
     export_extra_results,
+    get_prediction,
+    load_craftnet_model,
+    load_refinenet_model,
+    read_image,
 )
 
 
@@ -20,18 +20,17 @@ def test_load_craftnet_model(self):
         self.assertTrue(craft_net)
 
         with TemporaryDirectory() as dir_name:
-            weight_path = Path(dir_name, 'weights.pth')
+            weight_path = Path(dir_name, "weights.pth")
             self.assertFalse(weight_path.is_file())
             load_craftnet_model(cuda=False, weight_path=weight_path)
             self.assertTrue(weight_path.is_file())
 
-
     def test_load_refinenet_model(self):
         refine_net = load_refinenet_model(cuda=False)
         self.assertTrue(refine_net)
 
         with TemporaryDirectory() as dir_name:
-            weight_path = Path(dir_name, 'weights.pth')
+            weight_path = Path(dir_name, "weights.pth")
             self.assertFalse(weight_path.is_file())
             load_refinenet_model(cuda=False, weight_path=weight_path)
             self.assertTrue(weight_path.is_file())
@@ -73,6 +72,39 @@ def test_get_prediction(self):
             prediction_result["heatmaps"]["text_score_heatmap"].shape, (240, 368, 3)
         )
 
+    def test_get_prediction_without_read_image(self):
+        # set image filepath
+        image = self.image_path
+
+        # load models
+        craft_net = load_craftnet_model()
+        refine_net = None
+
+        # perform prediction
+        text_threshold = 0.9
+        link_threshold = 0.2
+        low_text = 0.2
+        cuda = False
+        prediction_result = get_prediction(
+            image=image,
+            craft_net=craft_net,
+            refine_net=refine_net,
+            text_threshold=text_threshold,
+            link_threshold=link_threshold,
+            low_text=low_text,
+            cuda=cuda,
+            long_size=720,
+        )
+
+        self.assertEqual(len(prediction_result["boxes"]), 35)
+        self.assertEqual(len(prediction_result["boxes"][0]), 4)
+        self.assertEqual(len(prediction_result["boxes"][0][0]), 2)
+        self.assertEqual(int(prediction_result["boxes"][0][0][0]), 111)
+        self.assertEqual(len(prediction_result["polys"]), 35)
+        self.assertEqual(
+            prediction_result["heatmaps"]["text_score_heatmap"].shape, (240, 368, 3)
+        )
+
 
 if __name__ == "__main__":
     unittest.main()
