-
Notifications
You must be signed in to change notification settings - Fork 29
/
gen_mask.py
293 lines (232 loc) · 10.3 KB
/
gen_mask.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
#!/usr/bin/env python
# encoding: utf-8
'''
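Generate segmentation masks with a trained model.
Tasks: Read frames from a webcam, predict a mask for each frame with a trained
segmentation network, and display each frame with its mask overlaid in red.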
@author: Cheng-Lin Li a.k.a. Clark
@copyright: 2018 Cheng-Lin Li@Insight AI. All rights reserved.
@license: Licensed under the Apache License v2.0. http://www.apache.org/licenses/
@contact: [email protected]
Tasks:
The program implementation will classify input image by a trained model and generate mask image as
image segmentation results.
Data:
Currently focus on person category data.
Reference:
https://github.com/jrosebr1/imutils/blob/master/imutils/video/webcamvideostream.py
'''
from threading import Thread
import argparse
import logging
from os.path import join
import numpy as np
from skimage import measure, filters
import scipy.ndimage.morphology
from utils.model_helper import create_model
# from data_helper import *
from utils.load_2D_data import generate_test_image
from utils.custom_data_aug import image_resize2square
# from test import threshold_mask
from datetime import datetime
import cv2
# Module-level settings.
# NOTE(review): FILE_MIDDLE_NAME, IMAGE_FOLDER, MASK_FOLDER, ARGS and NET_INPUT are
# not referenced by the code shown in this file; they look like leftovers from a
# file-based pipeline -- confirm nothing imports them before removing.
FILE_MIDDLE_NAME = 'train'
IMAGE_FOLDER = 'imgs'
MASK_FOLDER = 'masks'
RESOLUTION = 512 # Side length (pixels) of the square model input; used for net_input_shape and passed to image_resize2square.
ARGS = None
NET_INPUT = None
class FPS:
    '''
    Measure approximate frames per second between start() and stop().

    Usage: fps = FPS().start(); call fps.update() once per processed frame;
    call fps.stop(); then read fps.elapsed() and fps.fps().
    '''
    def __init__(self):
        # Wall-clock timestamps of the measured interval and the number of
        # frames counted within it.
        self._start = None
        self._end = None
        self._numFrames = 0

    def start(self):
        '''Start the timer and return self so construction can be chained.'''
        self._start = datetime.now()
        return self

    def stop(self):
        '''Stop the timer.'''
        self._end = datetime.now()

    def update(self):
        '''Count one more processed frame.'''
        self._numFrames += 1

    def elapsed(self):
        '''
        Return the seconds between start() and stop().

        If stop() has not been called yet, measure up to the current time
        instead of failing, so progress can be reported mid-run.
        '''
        end = self._end if self._end is not None else datetime.now()
        return (end - self._start).total_seconds()

    def fps(self):
        '''
        Return frames counted divided by elapsed seconds.

        Returns 0.0 when no measurable time has passed (e.g. stop() called
        immediately after start()), rather than raising ZeroDivisionError.
        '''
        seconds = self.elapsed()
        if seconds <= 0:
            return 0.0
        return self._numFrames / seconds
class WebcamVideoStream:
    '''
    Read frames from a cv2.VideoCapture on a background daemon thread, so the
    caller can fetch the most recently captured frame without blocking on I/O.
    '''
    def __init__(self, src=0):
        # Open the capture source and read one frame immediately so that
        # read() has something to return before the worker thread runs.
        self.stream = cv2.VideoCapture(src)
        (self.grabbed, self.frame) = self.stream.read()
        # Polled by the worker thread; set by stop().
        self.stopped = False
        # Worker thread handle, assigned by start(); stop() waits on it.
        self._thread = None

    def start(self):
        '''Start the background reader thread and return self for chaining.'''
        self._thread = Thread(target=self.update, args=())
        self._thread.daemon = True
        self._thread.start()
        return self

    def update(self):
        '''Worker loop: keep overwriting the latest frame until stopped.'''
        while not self.stopped:
            (self.grabbed, self.frame) = self.stream.read()

    def read(self):
        '''Return the most recently read frame (None if the last read failed).'''
        return self.frame

    def stop(self, timeout=1.0):
        '''
        Stop the reader thread and release the capture device.

        :param timeout: seconds to wait for the worker thread to exit.
        The capture is released only after the worker has exited; if it is
        still blocked inside read() when the timeout expires, the release is
        skipped rather than racing with that read.
        '''
        self.stopped = True
        if self._thread is not None:
            self._thread.join(timeout)
            if self._thread.is_alive():
                return
        self.stream.release()

    def release(self):
        '''Alias for stop(), mirroring cv2.VideoCapture.release().'''
        self.stop()
def threshold_mask(raw_output, threshold):
    '''
    Binarize a model output and keep only its dominant connected component(s).

    Steps:
      1. If threshold is 0, choose one with Otsu's method (falling back to 0.5
         when Otsu raises ValueError, e.g. on single-valued input in older
         scikit-image versions).
      2. Mark every element strictly greater than the threshold as foreground.
      3. Label connected foreground components with full connectivity
         (diagonal neighbours are connected, matching skimage.measure.label's
         default connectivity=ndim).
      4. Keep the largest component, plus the second largest unless the
         largest is more than 5 times its size.
      5. Fill holes in the kept region.

    :param raw_output: array of model scores of any dimensionality. In this
        file it is called with output_array[:, :, :, 0], i.e. (batch, H, W).
        It is NOT modified.
    :param threshold: cut-off value; 0 means "choose automatically".
    :return: np.uint8 array with raw_output's shape, containing only 0 and 1.
    '''
    if threshold == 0:
        try:
            threshold = filters.threshold_otsu(raw_output)
        except ValueError:
            threshold = 0.5
    logging.info('\tThreshold: %s', threshold)

    # Work on a boolean view instead of overwriting the caller's array.
    foreground = np.asarray(raw_output) > threshold

    structure = scipy.ndimage.generate_binary_structure(foreground.ndim, foreground.ndim)
    all_labels, num_labels = scipy.ndimage.label(foreground, structure=structure)

    thresholded_mask = np.zeros(foreground.shape, dtype=bool)
    if num_labels:
        # areas[k] is the pixel count of label k + 1 (label 0 is background).
        areas = np.bincount(all_labels.ravel())[1:]
        # Largest first; a stable sort breaks ties by the lower label.
        order = np.argsort(-areas, kind='stable')
        thresholded_mask[all_labels == order[0] + 1] = True
        # Also keep the runner-up unless the largest clearly dominates it.
        if num_labels >= 2 and areas[order[0]] / areas[order[1]] <= 5:
            thresholded_mask[all_labels == order[1] + 1] = True

    return scipy.ndimage.binary_fill_holes(thresholded_mask).astype(np.uint8)
def apply_mask(image, mask):
    """
    Tint the masked pixels of image red, in place, and return image.

    :param image: colour image whose last axis holds 3 channels; with OpenCV's
        BGR ordering, (0, 0, 255) is pure red.
    :param mask: single-channel 8-bit mask; non-zero pixels get the tint.
    :return: the same image object, with red added (saturating) where mask is set.
    """
    overlay = np.zeros(image.shape, image.dtype)
    overlay[...] = (0, 0, 255)
    tinted = cv2.bitwise_and(overlay, overlay, mask=mask)
    # dst=image makes addWeighted write the sum back into the input buffer.
    cv2.addWeighted(tinted, 1, image, 1, 0, image)
    return image
class segmentation_model:
    '''
    Wrap a trained segmentation network for inference on single images.
    '''
    def __init__(self, args, net_input_shape):
        '''
        Build the evaluation model and load trained weights into it.

        :param args: parsed CLI namespace; must provide weights_path plus
            whatever create_model reads (e.g. net, num_class).
        :param net_input_shape: model input shape, e.g. (RESOLUTION, RESOLUTION, 1).
        '''
        self.net_input_shape = net_input_shape
        # Inference model with the decoder branch disabled.
        _, eval_model, _ = create_model(args, net_input_shape, enable_decoder=False)
        # by_name=True matches saved weights to layers by layer name, so weights
        # for layers absent from this decoder-less model are not required.
        eval_model.load_weights(args.weights_path, by_name=True)
        self.model = eval_model

    def detect(self, img_list, verbose=False, threshold=0):
        '''
        Predict a binary mask for each image.

        :param img_list: iterable of images accepted by generate_test_image.
        :param verbose: truthy to show the Keras progress bar for each prediction.
        :param threshold: passed to threshold_mask; 0 selects it automatically.
        :return: list with one dict per image, each {'masks': 2-D uint8 mask}.
        '''
        result = []
        for img_data in img_list:
            test_batches = generate_test_image(img_data,
                                               self.net_input_shape,
                                               batchSize=1,
                                               numSlices=1,
                                               subSampAmt=0,
                                               stride=1)
            output_array = self.model.predict_generator(test_batches,
                                                        steps=1,
                                                        max_queue_size=1,
                                                        workers=4,
                                                        use_multiprocessing=False,
                                                        verbose=1 if verbose else 0)
            # Assumes a 4-D (batch, H, W, channels) output; channel 0 is the score map.
            output = output_array[:, :, :, 0]
            output_bin = threshold_mask(output, threshold)
            # A fresh dict per image: sharing one dict made every entry alias
            # the last image's mask.
            result.append({'masks': output_bin[0, :, :]})
        return result
if __name__ == '__main__':
    '''
    Segment live webcam frames and display them with the predicted mask overlaid.
    Example command:
    $python3 gen_mask.py --net segcapsr3 --weights_path ../data/saved_models/segcapsr3/dice16-255.hdf5
    Press q or ESC in the "Frame" window to stop.
    '''
    MAX_FRAMES = 10000  # stop automatically after this many processed frames

    parser = argparse.ArgumentParser(description = 'Mask image by segmentation algorithm')
    parser.add_argument('--net', type = str.lower, default = 'segcapsr3',
                        choices = ['segcapsr3', 'segcapsr1', 'capsbasic', 'unet', 'tiramisu'],
                        help = 'Choose your network.')
    parser.add_argument('--weights_path', type = str, required = True,
                        help = '/path/to/trained_model.hdf5 from root. Set to "" for none.')
    parser.add_argument('--num_class', type = int, default = 2,
                        help = 'Number of classes to segment. Default is 2. If number of classes > 2, '
                               'the loss function will be softmax entropy and only apply on SegCapsR3. '
                               '** Current version only support binary classification tasks.')
    parser.add_argument('--which_gpus', type = str, default = '0',
                        help='Enter "-2" for CPU only, "-1" for all GPUs available, '
                             'or a comma separated list of GPU id numbers ex: "0,1,4".')
    parser.add_argument('--gpus', type = int, default = -1,
                        help = 'Number of GPUs you have available for training. '
                               'If entering specific GPU ids under the --which_gpus arg or if using CPU, '
                               'then this number will be inferred, else this argument must be included.')
    args = parser.parse_args()
    net_input_shape = (RESOLUTION, RESOLUTION, 1)
    model = segmentation_model(args, net_input_shape)

    # Threaded capture so the model loop is not blocked waiting on the camera.
    print("[INFO] sampling THREADED frames from webcam...")
    vs = WebcamVideoStream(src=0).start()
    fps = FPS().start()
    try:
        for _ in range(MAX_FRAMES):
            frame = vs.read()
            if frame is None:
                # The capture failed (no camera, or it was disconnected).
                print("[WARN] no frame received from webcam; stopping.")
                break
            frame = image_resize2square(frame, RESOLUTION)
            results = model.detect([frame], verbose=0)
            frame = apply_mask(frame, results[0]['masks'])
            cv2.imshow("Frame", frame)
            # Poll the keyboard once per frame: a second waitKey call would
            # consume (and lose) a different key event.
            key = cv2.waitKey(1) & 0xFF
            if key in (ord('q'), 27):  # 27 is ESC
                break
            fps.update()
    finally:
        # Always stop the capture thread and close windows, even on errors.
        fps.stop()
        vs.stop()
        cv2.destroyAllWindows()

    elapsed = fps.elapsed()
    print("[INFO] elapsed time: {:.2f}".format(elapsed))
    print("[INFO] approx. FPS: {:.2f}".format(fps.fps() if elapsed > 0 else 0.0))