-
Notifications
You must be signed in to change notification settings - Fork 8
/
segmentor.py
executable file
·87 lines (61 loc) · 2.32 KB
/
segmentor.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
#
# Semantic Segmentor
# Class that runs semantic image segmentation on camera frames.
# Used by drive.py. For more information see README.md.
# ------------------------------------
# Neil Nie & Michael Meng
# (c) Yongyang Nie 2018
#
import models.enet_naive_upsampling.model as enet
import utils as utils
import configs as configs
import cv2
import numpy as np
class Segmentor:
    """Wrap an ENet model for per-frame semantic segmentation.

    Attributes:
        model: network from ``enet.build`` built for ``len(utils.labels)``
            classes at ``configs.img_height`` x ``configs.img_width``, with
            weights loaded from ``configs.infer_model_path``.
        backgrounds: list of solid-colour uint8 images, one per entry in
            ``utils.labels``, used to paint the class mask.
    """

    def __init__(self, type=None):
        """Build the network, load its weights and precompute label colours.

        Args:
            type: unused. It is kept (and now optional) so existing callers
                keep working; renaming it would break keyword callers even
                though it shadows the builtin.
        """
        self.model = enet.build(len(utils.labels), configs.img_height, configs.img_width)
        self.model.load_weights(configs.infer_model_path)
        self.backgrounds = self.load_color_backgrounds()

    @staticmethod
    def load_color_backgrounds():
        """Return one solid-colour image per entry in ``utils.labels``.

        Each image is ``configs.img_height`` x ``configs.img_width`` x 3,
        dtype uint8. Channels 0, 1 and 2 are filled from ``label[7][0..2]``.
        Field 7 is presumably the label's display colour — TODO confirm
        against the label definition in ``utils``.
        """
        backgrounds = []
        for label in utils.labels:
            color = label[7]
            bg = np.zeros((configs.img_height, configs.img_width, 3), dtype=np.uint8)
            for channel in range(3):
                bg[:, :, channel].fill(color[channel])
            backgrounds.append(bg)
        return backgrounds

    def semantic_segmentation(self, image, visualize=False):
        """Run the segmentation network on one frame.

        Args:
            image: input frame, an HxWx3 array. It is resized to the
                network input size before inference. Presumably a BGR
                uint8 frame from cv2 — TODO confirm against drive.py.
            visualize: if True, blend the colour-coded class mask onto the
                frame. Otherwise the frame is blended with itself.

        Returns:
            (output, img_pred):
                output: the unmodified network prediction for the frame,
                    with one channel per class along the last axis.
                img_pred: a 640x480 overlay image for display.
        """
        # Resize to the size the model and self.backgrounds were built with,
        # instead of a hard-coded 640x360 that had to be kept in sync with
        # configs by hand. cv2.resize takes dsize as (width, height).
        image = cv2.resize(image, (configs.img_width, configs.img_height))
        output = self.model.predict(np.array([image]))[0]
        if visualize:
            im_mask = self.convert_class_to_rgb(output)
        else:
            # NOTE(review): blending the frame with itself at 0.8 + 0.8
            # brightens it (with saturation). Kept as-is because display
            # code may rely on it.
            im_mask = image
        img_pred = cv2.addWeighted(im_mask, 0.8, image, 0.8, 0)
        img_pred = cv2.resize(img_pred, (640, 480))
        return output, img_pred

    def convert_class_to_rgb(self, image_labels, threshold=0.05):
        """Paint a per-class score map as a colour image.

        A pixel whose score for class ``i`` is strictly greater than
        ``threshold`` receives the colour of ``self.backgrounds[i]``. When
        several classes pass at the same pixel, their colours are added and
        clipped at 255. This is the same saturating behaviour as
        cv2.addWeighted(a, 1.0, b, 1.0, 0) on uint8 images.

        Args:
            image_labels: HxWxC score array, one channel per entry in
                ``self.backgrounds``. It is NOT modified.
            threshold: minimum score (exclusive) for a class to be painted.

        Returns:
            HxWx3 uint8 colour image.
        """
        height, width = image_labels.shape[:2]
        # Use a wide accumulator so the per-class sum cannot wrap before
        # the final clip to 255.
        total = np.zeros((height, width, 3), dtype=np.uint32)
        for i, bg in enumerate(self.backgrounds):
            # Compare into a fresh boolean mask. The previous version
            # thresholded a view in place, which overwrote the caller's
            # array (the `output` returned by semantic_segmentation).
            mask = image_labels[:, :, i] > threshold
            total[mask] += bg[mask]
        return np.minimum(total, 255).astype(np.uint8)