Skip to content

Commit fb7b4ec

Browse files
committed
add Card class
1 parent 0694be7 commit fb7b4ec

20 files changed

+1374
-31
lines changed

.gitignore

+1-2
Original file line numberDiff line numberDiff line change
@@ -132,5 +132,4 @@ dmypy.json
132132
*png
133133
*ipynb
134134
*ocr_results*
135-
cccd*
136-
back*
135+
*zip

data/cards/1.json

+74
Large diffs are not rendered by default.

data/cards/10.json

+74
Large diffs are not rendered by default.

data/cards/11.json

+74
Large diffs are not rendered by default.

data/cards/12.json

+74
Large diffs are not rendered by default.

data/cards/13.json

+74
Large diffs are not rendered by default.

data/cards/14.json

+138
Large diffs are not rendered by default.

data/cards/15.json

+74
Large diffs are not rendered by default.

data/cards/16.json

+74
Large diffs are not rendered by default.

data/cards/17.json

+74
Large diffs are not rendered by default.

data/cards/3.json

+74
Large diffs are not rendered by default.

data/cards/4.json

+90
Large diffs are not rendered by default.

data/cards/5.json

+74
Large diffs are not rendered by default.

data/cards/6.json

+74
Large diffs are not rendered by default.

data/cards/7.json

+74
Large diffs are not rendered by default.

data/cards/8.json

+74
Large diffs are not rendered by default.

data/cards/9.json

+74
Large diffs are not rendered by default.

pick_preprocessing.py

+2-6
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,8 @@
1-
import ssl
2-
ssl._create_default_https_context = ssl._create_unverified_context
3-
41
import os
52
import pandas as pd
63

74
from utils.ocr_utils import get_ocr_results
85

96
if __name__ == "__main__":
    # Smoke test: run OCR on a single sample card and print the raw result.
    print(get_ocr_results("data/cards/1.jpg"))

utils/aumentation.py

+7
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
import os
2+
import random
3+
4+
import PIL
5+
from PIL import Image
6+
from utils.card import Card
7+

utils/card.py

+100-23
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,17 @@
77
from PIL import Image
88

99
class Card():
10+
"""
11+
Annotation must be created by labelme: https://github.com/wkentaro/labelme
12+
Structure of data folders:
13+
root_dir_name
14+
--abcxyz.jpg (or any image extension such as png, jpeg)
15+
--abcxyz.json
16+
17+
Required arguments:
18+
- root_dir (str): path to data folder
19+
- annotation_path (str): path to annotation file (.json)
20+
"""
1021
def __init__(self,
1122
root_dir: str = None,
1223
annotation_path: str = None,
@@ -26,7 +37,7 @@ def __init__(self,
2637
with open(annotation_path) as f:
2738
annotation = json.load(f)
2839
shapes = annotation["shapes"]
29-
corners = ["top_left", "top_right", "bottom_left", "bottom_right"]
40+
corners = ["top_left", "top_right", "bottom_right", "bottom_left"]
3041
for polygon in shapes:
3142
if polygon["group_id"] is None:
3243
polygon["group_id"] = 0
@@ -44,56 +55,122 @@ def __init__(self,
4455
pass
4556
self.center = [(self.top_left[0]+self.bottom_right[0])/2.0,
4657
(self.top_left[1]+self.bottom_right[1])/2.0]
47-
self.keypoints
48-
self.img_path = os.path.join(self.root_dir, annotation["imagePath"])
49-
self.image = self.load(self.img_path)
58+
self.skeleton = [[0,1], [1,2], [2,3], [3,0], [0,4], [1,4], [2,4], [3,4]]
59+
self.keypoints = [tuple(self.top_left),
60+
tuple(self.top_right),
61+
tuple(self.bottom_right),
62+
tuple(self.bottom_left),
63+
tuple(self.center)]
64+
65+
extension = annotation["imagePath"].split('.')[1]
66+
if extension != "jpg":
67+
annotation["imagePath"] = annotation["imagePath"].replace(extension, "jpg")
68+
self.image_path = os.path.join(self.root_dir, annotation["imagePath"])
69+
self.image_name = annotation["imagePath"]
70+
self.image = self.load(self.image_path)
5071
self.cropped_image = self.image.crop(tuple(self.bbox))
5172

52-
def load(self, img_path: str = None):
53-
img = Image.open(img_path)
73+
def load(self, image_path: str = None):
    """
    Load an image with PIL and return it in RGB mode.

    Arguments:
    - image_path (str): path to the image file

    Returns:
    - PIL image in "RGB" mode whose pixel data is already read into memory
    """
    # Image.open is lazy and keeps the underlying file open until the pixel
    # data is read. The context manager closes the file deterministically,
    # while convert() returns a loaded copy (also when the image is already
    # RGB), so the returned image stays usable after the file is closed.
    with Image.open(image_path) as img:
        return img.convert('RGB')
5781

58-
def transform(self):
59-
60-
pass
82+
def convert_to_opencv(self):
    """
    Return self.image as a new array with the channel axis reversed
    (RGB order becomes BGR order, the layout OpenCV drawing calls expect).
    """
    rgb_pixels = np.array(self.image)
    # Reverse the last axis, then copy so the result owns contiguous memory
    # instead of being a negative-stride view.
    bgr_pixels = rgb_pixels[:, :, ::-1].copy()
    return bgr_pixels
6186

6287
def visualize(self,
              bbox: bool = True,
              keypoints: bool = True,
              skeleton: bool = False):
    """
    Draw the selected annotations on the BGR array returned by
    convert_to_opencv() and show the result with matplotlib.

    Arguments:
    - bbox (bool): draw the bounding box spanning self.x1y1 to self.x2y2
    - keypoints (bool): draw the five entries of self.keypoints
    - skeleton (bool): draw a line for every index pair in self.skeleton
    """
    green_bgr = (0, 255, 0)
    blue_bgr = (255, 0, 0)
    red_bgr = (0, 0, 255)
    yellow_bgr = (0, 255, 255)
    pink_bgr = (204, 0, 204)
    navy_bgr = (128, 0, 0)

    canvas = self.convert_to_opencv()
    points = self.keypoints

    if bbox:
        box_start = tuple(map(int, self.x1y1))
        box_end = tuple(map(int, self.x2y2))
        cv2.rectangle(canvas, box_start, box_end, green_bgr, 2)

    if keypoints:
        # One colour per keypoint, in the order the keypoints are stored.
        palette = [green_bgr, blue_bgr, red_bgr, yellow_bgr, pink_bgr]
        for index, color in enumerate(palette):
            center = tuple(map(int, points[index]))
            canvas = cv2.circle(canvas, center, 7, color, 20)

    if skeleton:
        for first, second in self.skeleton:
            line_start = tuple(map(int, points[first]))
            line_end = tuple(map(int, points[second]))
            cv2.line(canvas, line_start, line_end, navy_bgr, 3)

    # matplotlib expects RGB, so flip the channel axis back before showing.
    plt.imshow(canvas[:, :, ::-1])
    plt.show()
128+
129+
def augment(self,
            background_dir: str = "./data/background",
            max_card_width: float = 1000.0,
            max_image_width: float = 2000.0,
            angles: list = None,
            save_image: bool = False,
            paste_position: tuple = (400, 200)):
    """
    Paste rotated copies of the cropped card onto every background image.

    The results are collected in self.augmented_images (reset on every
    call), one image per (background file, angle) pair.

    Arguments:
    - background_dir (str): folder containing the background images
    - max_card_width (float): width the cropped card is resized to
      (height follows the aspect ratio)
    - max_image_width (float): width every background is resized to
      (height follows the aspect ratio)
    - angles (list): rotation angles in degrees; when None the default
      `[i*10 for i in range(-1, 1)]`, i.e. [-10, 0], is used
    - save_image (bool): also write each result to "data/augmented_images"
    - paste_position (tuple): (x, y) offset on the background where the
      upper-left corner of the rotated card is pasted

    Returns:
    - self.image when background_dir does not exist, otherwise None
    """
    # Always reset the result list so callers can read it even when the
    # method returns early.
    self.augmented_images = []
    if not os.path.exists(background_dir):
        return self.image

    save_dir = "data/augmented_images"
    if save_image:
        os.makedirs(save_dir, exist_ok=True)
    if angles is None:
        angles = [i * 10 for i in range(-1, 1)]

    card = self.cropped_image
    card_w, card_h = card.size
    card_scale = max_card_width / card_w
    card = card.resize((int(max_card_width), int(card_h * card_scale)))

    # The rotated card and its mask depend only on the angle, so build them
    # once instead of once per background file. The mask starts fully opaque
    # and is rotated the same way, so only the card area is pasted.
    opaque = Image.new('L', card.size, 255)
    rotated_cards = [
        (angle,
         card.rotate(angle, expand=True),
         opaque.rotate(angle, expand=True))
        for angle in angles
    ]

    # Sorted for a deterministic order of self.augmented_images.
    for file_name in sorted(os.listdir(background_dir)):
        background_path = os.path.join(background_dir, file_name)
        if not os.path.isfile(background_path):
            # Sub-directories cannot be opened as images.
            continue

        base = self.load(background_path)
        bg_w, bg_h = base.size
        bg_scale = max_image_width / bg_w
        base = base.resize((int(max_image_width), int(bg_h * bg_scale)))

        # splitext keeps dots inside the file name, unlike split('.')[0].
        stem = os.path.splitext(file_name)[0]
        for angle, front, mask in rotated_cards:
            # paste() works in place, so every angle gets its own copy.
            background = base.copy()
            background.paste(front, tuple(paste_position), mask)
            self.augmented_images.append(background)
            if save_image:
                saved_path = os.path.join(
                    save_dir, f"{stem}_{angle}_{self.image_name}")
                background.save(saved_path)
91167

92168
if __name__ == "__main__":
    # Demo: load one annotated card, build its augmented images in memory
    # (nothing is written to disk) and report how many were produced.
    card = Card(root_dir="data/cards",
                annotation_path="data/cards/16.json",
                group_id=0)
    card.augment(save_image=False)
    print(len(card.augmented_images))

0 commit comments

Comments
 (0)