1
+ import cv2
2
+ import os
3
+ import json
4
+ import matplotlib .pyplot as plt
5
+ import numpy as np
6
+
7
+ from PIL import Image
8
+
9
+ class Card ():
10
+ def __init__ (self ,
11
+ root_dir : str = None ,
12
+ annotation_path : str = None ,
13
+ group_id : int = 0 ):
14
+ assert root_dir is not None
15
+ assert annotation_path is not None
16
+
17
+ self .root_dir = root_dir
18
+ self .annotation_path = annotation_path
19
+ self .group_id = group_id
20
+ self .top_left = list (np .zeros (2 ))
21
+ self .top_right = list (np .zeros (2 ))
22
+ self .bottom_right = list (np .zeros (2 ))
23
+ self .bottom_left = list (np .zeros (2 ))
24
+ self .bbox = list (np .zeros (4 ))
25
+
26
+ with open (annotation_path ) as f :
27
+ annotation = json .load (f )
28
+ shapes = annotation ["shapes" ]
29
+ corners = ["top_left" , "top_right" , "bottom_left" , "bottom_right" ]
30
+ for polygon in shapes :
31
+ if polygon ["group_id" ] is None :
32
+ polygon ["group_id" ] = 0
33
+ if int (polygon ["group_id" ]) == group_id :
34
+ label = polygon ["label" ]
35
+ if label in corners :
36
+ setattr (self , label , polygon ["points" ][0 ])
37
+ elif label == "card" :
38
+ self .x1y1 = polygon ["points" ][0 ]
39
+ self .x2y2 = polygon ["points" ][1 ]
40
+ self .bbox = self .x1y1 + self .x2y2 # bounding box is x1x2y1y2 format
41
+ else :
42
+ pass
43
+ else :
44
+ pass
45
+ self .center = [(self .top_left [0 ]+ self .bottom_right [0 ])/ 2.0 ,
46
+ (self .top_left [1 ]+ self .bottom_right [1 ])/ 2.0 ]
47
+ self .keypoints
48
+ self .img_path = os .path .join (self .root_dir , annotation ["imagePath" ])
49
+ self .image = self .load (self .img_path )
50
+ self .cropped_image = self .image .crop (tuple (self .bbox ))
51
+
52
+ def load (self , img_path : str = None ):
53
+ img = Image .open (img_path )
54
+ if img .mode != "RGB" :
55
+ img = img .convert ('RGB' )
56
+ return img
57
+
58
+ def transform (self ):
59
+
60
+ pass
61
+
62
+ def visualize (self ,
63
+ keypoints : bool = False ,
64
+ bbox : bool = False ):
65
+ if (keypoints is False ) and (bbox is False ):
66
+ print ("Nothing to show!" )
67
+ print ("Set the argument 'keypoints' or 'bbox' is True to visualize the sample" )
68
+ image = cv2 .imread (self .img_path )
69
+ green_bgr = (0 , 255 , 0 )
70
+ blue_bgr = (255 , 0 , 0 )
71
+ red_bgr = (0 , 0 , 255 )
72
+ yellow_bgr = (0 , 255 , 255 )
73
+ pink_bgr = (204 , 0 , 204 )
74
+ if bbox :
75
+ start_point = tuple (map (int , self .x1y1 ))
76
+ end_point = tuple (map (int , self .x2y2 ))
77
+ thickness = 2
78
+ cv2 .rectangle (image , start_point , end_point , green_bgr , thickness )
79
+ if keypoints :
80
+ points = [self .top_left , self .top_right , self .bottom_right , self .bottom_left , self .center ]
81
+ colors = [green_bgr , blue_bgr , red_bgr , yellow_bgr , pink_bgr ]
82
+ for i in range (5 ):
83
+ point = tuple (map (int , points [i ]))
84
+ radius = 3
85
+ thickness = 10
86
+ color = colors [i ]
87
+ image = cv2 .circle (image , point , radius , color , thickness )
88
+ plt .imshow (image [:,:,::- 1 ])
89
+ plt .show ()
90
+
91
+
92
+ if __name__ == "__main__" :
93
+ root_dir = "data/cccd_kpts"
94
+ annotation_path = "data/cccd_kpts/5.json"
95
+ card1 = Card (root_dir = root_dir ,
96
+ annotation_path = annotation_path ,
97
+ group_id = 0 )
98
+ card1 .visualize (bbox = True , keypoints = True )
99
+ # print(card1.top_left, card1.top_right)
0 commit comments