1
1
from __future__ import absolute_import
2
2
3
+ import os
4
+
3
5
import craft_text_detector .craft_utils as craft_utils
4
6
import craft_text_detector .file_utils as file_utils
5
7
import craft_text_detector .image_utils as image_utils
6
8
import craft_text_detector .predict as predict
7
9
import craft_text_detector .torch_utils as torch_utils
8
10
9
- __version__ = "0.3.3 "
11
+ __version__ = "0.3.4 "
10
12
11
13
12
14
__all__ = [
@@ -45,7 +47,6 @@ def __init__(
45
47
):
46
48
"""
47
49
Arguments:
48
- image_path: path to the image to be processed
49
50
output_dir: path to the results to be exported
50
51
rectify: rectify detected polygon by affine transform
51
52
export_extra: export heatmap, detection points, box visualization
@@ -102,21 +103,26 @@ def unload_refinenet_model(self):
102
103
self .refine_net = None
103
104
empty_cuda_cache ()
104
105
105
- def detect_text (self , image_path ):
106
+ def detect_text (self , image , image_path = None ):
106
107
"""
107
108
Arguments:
108
- image_path: path to the image to be processed
109
+ image: path to the image to be processed or numpy array or PIL image
110
+
109
111
Output:
110
- {"masks": lists of predicted masks 2d as bool array,
111
- "boxes": list of coords of points of predicted boxes,
112
- "boxes_as_ratios": list of coords of points of predicted boxes as ratios of image size,
113
- "polys_as_ratios": list of coords of points of predicted polys as ratios of image size,
114
- "heatmaps": visualization of the detected characters/links,
115
- "text_crop_paths": list of paths of the exported text boxes/polys,
116
- "times": elapsed times of the sub modules, in seconds}
112
+ {
113
+ "masks": lists of predicted masks 2d as bool array,
114
+ "boxes": list of coords of points of predicted boxes,
115
+ "boxes_as_ratios": list of coords of points of predicted boxes as ratios of image size,
116
+ "polys_as_ratios": list of coords of points of predicted polys as ratios of image size,
117
+ "heatmaps": visualization of the detected characters/links,
118
+ "text_crop_paths": list of paths of the exported text boxes/polys,
119
+ "times": elapsed times of the sub modules, in seconds
120
+ }
117
121
"""
118
- # load image
119
- image = read_image (image_path )
122
+
123
+ if image_path is not None :
124
+ print ("Argument 'image_path' is deprecated, use 'image' instead." )
125
+ image = image_path
120
126
121
127
# perform prediction
122
128
prediction_result = get_prediction (
@@ -142,10 +148,14 @@ def detect_text(self, image_path):
142
148
prediction_result ["text_crop_paths" ] = []
143
149
if self .output_dir is not None :
144
150
# export detected text regions
151
+ if type (image ) == str :
152
+ file_name , file_ext = os .path .splitext (os .path .basename (image ))
153
+ else :
154
+ file_name = "image"
145
155
exported_file_paths = export_detected_regions (
146
- image_path = image_path ,
147
156
image = image ,
148
157
regions = regions ,
158
+ file_name = file_name ,
149
159
output_dir = self .output_dir ,
150
160
rectify = self .rectify ,
151
161
)
@@ -154,10 +164,10 @@ def detect_text(self, image_path):
154
164
# export heatmap, detection points, box visualization
155
165
if self .export_extra :
156
166
export_extra_results (
157
- image_path = image_path ,
158
167
image = image ,
159
168
regions = regions ,
160
169
heatmaps = prediction_result ["heatmaps" ],
170
+ file_name = file_name ,
161
171
output_dir = self .output_dir ,
162
172
)
163
173
0 commit comments