7
7
from PIL import Image
8
8
9
9
class Card ():
10
+ """
11
+ Annotation must be created by labelme: https://github.com/wkentaro/labelme
12
+ Structure of data folders:
13
+ root_dir_name
14
+ --abcxyz.jpg (or any image extension such as png, jpeg)
15
+ --abcxyz.json
16
+
17
+ Required arguments:
18
+ - root_dir (str): path to data folder
19
+ - annotation_path (str): path to annotation file (.json)
20
+ """
10
21
def __init__ (self ,
11
22
root_dir : str = None ,
12
23
annotation_path : str = None ,
@@ -26,7 +37,7 @@ def __init__(self,
26
37
with open (annotation_path ) as f :
27
38
annotation = json .load (f )
28
39
shapes = annotation ["shapes" ]
29
- corners = ["top_left" , "top_right" , "bottom_left " , "bottom_right " ]
40
+ corners = ["top_left" , "top_right" , "bottom_right " , "bottom_left " ]
30
41
for polygon in shapes :
31
42
if polygon ["group_id" ] is None :
32
43
polygon ["group_id" ] = 0
@@ -44,56 +55,122 @@ def __init__(self,
44
55
pass
45
56
self .center = [(self .top_left [0 ]+ self .bottom_right [0 ])/ 2.0 ,
46
57
(self .top_left [1 ]+ self .bottom_right [1 ])/ 2.0 ]
47
- self .keypoints
48
- self .img_path = os .path .join (self .root_dir , annotation ["imagePath" ])
49
- self .image = self .load (self .img_path )
58
+ self .skeleton = [[0 ,1 ], [1 ,2 ], [2 ,3 ], [3 ,0 ], [0 ,4 ], [1 ,4 ], [2 ,4 ], [3 ,4 ]]
59
+ self .keypoints = [tuple (self .top_left ),
60
+ tuple (self .top_right ),
61
+ tuple (self .bottom_right ),
62
+ tuple (self .bottom_left ),
63
+ tuple (self .center )]
64
+
65
+ extension = annotation ["imagePath" ].split ('.' )[1 ]
66
+ if extension != "jpg" :
67
+ annotation ["imagePath" ] = annotation ["imagePath" ].replace (extension , "jpg" )
68
+ self .image_path = os .path .join (self .root_dir , annotation ["imagePath" ])
69
+ self .image_name = annotation ["imagePath" ]
70
+ self .image = self .load (self .image_path )
50
71
self .cropped_image = self .image .crop (tuple (self .bbox ))
51
72
52
def load(self, image_path: str = None):
    """
    Open the file at image_path with PIL and hand it back in RGB mode.

    An image that is already RGB is returned as opened; any other mode
    is converted first.
    """
    picture = Image.open(image_path)
    return picture if picture.mode == "RGB" else picture.convert('RGB')
57
81
58
- def transform (self ):
59
-
60
- pass
82
def convert_to_opencv(self):
    """
    Return self.image as a NumPy array with its channels reordered from RGB to BGR.
    """
    rgb_pixels = np.array(self.image)
    # Reversing the last axis swaps the channel order; copy() turns the
    # reversed view into an independent array.
    return rgb_pixels[:, :, ::-1].copy()
61
86
62
87
def visualize(self,
              bbox: bool = True,
              keypoints: bool = True,
              skeleton: bool = False):
    """
    Draw the annotation on a BGR copy of the image and show it with matplotlib.

    Arguments:
    - bbox (bool): draw the rectangle spanned by self.x1y1 and self.x2y2 (green)
    - keypoints (bool): draw one circle per entry of self.keypoints, coloured
      in order green, blue, red, yellow, pink
    - skeleton (bool): draw a navy line for every index pair in self.skeleton
    """
    image = self.convert_to_opencv()
    green_bgr = (0, 255, 0)
    blue_bgr = (255, 0, 0)
    red_bgr = (0, 0, 255)
    yellow_bgr = (0, 255, 255)
    pink_bgr = (204, 0, 204)
    navy_bgr = (128, 0, 0)

    # Convert every keypoint to an integer pixel tuple once; both the
    # keypoint and the skeleton drawing below reuse these.
    points = [tuple(map(int, point)) for point in self.keypoints]

    if bbox:
        start_point = tuple(map(int, self.x1y1))
        end_point = tuple(map(int, self.x2y2))
        cv2.rectangle(image, start_point, end_point, green_bgr, 2)

    if keypoints:
        colors = [green_bgr, blue_bgr, red_bgr, yellow_bgr, pink_bgr]
        radius = 7
        thickness = 20
        # Pair each keypoint with its colour instead of indexing a fixed range.
        for point, color in zip(points, colors):
            cv2.circle(image, point, radius, color, thickness)

    if skeleton:
        for start_index, end_index in self.skeleton:
            cv2.line(image, points[start_index], points[end_index], navy_bgr, 3)

    # The drawing was done on a BGR array; flip the channels back for matplotlib.
    plt.imshow(image[:, :, ::-1])
    plt.show()
128
+
129
def augment(self,
            background_dir: str = "./data/background",
            max_card_width: float = 1000.0,
            max_image_width: float = 2000.0,
            angles: list = None,
            save_image: bool = False):
    """
    Paste rotated copies of the cropped card onto every background image.

    One image is produced per (background file, angle) pair and collected
    in self.augmented_images, which is reset at the start of every call.

    Arguments:
    - background_dir (str): folder holding the background images
    - max_card_width (float): width the cropped card is resized to
      (height is scaled by the same factor)
    - max_image_width (float): width every background is resized to
      (height is scaled by the same factor)
    - angles (list): rotation angles passed to PIL's rotate();
      defaults to [-10, 0]
    - save_image (bool): also write every result into "data/augmented_images"

    Returns:
    - self.image when background_dir does not exist, otherwise None
    """
    # Reset up front so the attribute exists even on the early return below.
    self.augmented_images = []
    if not os.path.exists(background_dir):
        return self.image

    save_dir = "data/augmented_images"
    if save_image:
        os.makedirs(save_dir, exist_ok=True)
    if angles is None:
        angles = [i * 10 for i in range(-1, 1)]

    card = self.cropped_image
    card_w, card_h = card.size
    card_scale = max_card_width / card_w
    card = card.resize((int(max_card_width), int(card_h * card_scale)))

    # The rotated card and its paste mask depend only on the angle, so build
    # them once instead of once per background. The mask starts fully opaque;
    # after rotation with expand=True only the card area stays opaque.
    opaque = Image.new('L', card.size, 255)
    rotated_cards = [(angle,
                      card.rotate(angle, expand=True),
                      opaque.rotate(angle, expand=True))
                     for angle in angles]

    paste_offset = (400, 200)
    # Sorted so the order of self.augmented_images is reproducible.
    for file_name in sorted(os.listdir(background_dir)):
        background_path = os.path.join(background_dir, file_name)
        if not os.path.isfile(background_path):
            # Sub-directories cannot be opened as images.
            continue
        base = self.load(background_path)
        bg_w, bg_h = base.size
        bg_scale = max_image_width / bg_w
        base = base.resize((int(max_image_width), int(bg_h * bg_scale)))
        # splitext keeps inner dots, so "bg.v1.jpg" and "bg.v2.jpg" do not
        # collapse to the same output name.
        stem = os.path.splitext(file_name)[0]

        for angle, front, mask in rotated_cards:
            # paste() modifies its target, so every angle gets a fresh copy.
            canvas = base.copy()
            canvas.paste(front, paste_offset, mask)
            self.augmented_images.append(canvas)
            if save_image:
                saved_path = os.path.join(
                    save_dir, stem + '_' + str(angle) + '_' + self.image_name)
                canvas.save(saved_path)
91
167
92
168
if __name__ == "__main__":
    # Demo: load one annotated card, paste it onto the backgrounds and
    # report how many augmented images were produced.
    root_dir = "data/cards"
    annotation_path = "data/cards/16.json"
    card = Card(root_dir=root_dir,
                annotation_path=annotation_path,
                group_id=0)
    card.augment(save_image=False)
    # augment() can return before assigning augmented_images (missing
    # background folder), so fall back to an empty list instead of crashing.
    print(len(getattr(card, "augmented_images", [])))
0 commit comments