
Commit cffa6ab

committed Oct 24, 2018
add_tfrecord_decode_and_model_inference_demo
Signed-off-by: robin <zhangbinatp@gmail.com>
1 parent 61deb73 commit cffa6ab

9 files changed: +1644 -13 lines
 

DeepLab_inference_Demo.py (+223)
@@ -0,0 +1,223 @@
import os
from io import BytesIO
import tarfile
import tempfile
from six.moves import urllib

from matplotlib import gridspec
from matplotlib import pyplot as plt
import numpy as np
from PIL import Image
import cv2

import tensorflow as tf
import time
from datetime import datetime

slim = tf.contrib.slim


class DeepLabModel(object):
  """Class to load a DeepLab model and run inference."""

  INPUT_TENSOR_NAME = 'ImageTensor:0'
  OUTPUT_TENSOR_NAME = 'SemanticPredictions:0'
  INPUT_SIZE = 513
  FROZEN_GRAPH_NAME = 'frozen_inference_graph.pb'

  def __init__(self, model_dir):
    """Creates and loads a pretrained DeepLab model."""
    self.graph = tf.Graph()

    graph_def = None
    with tf.gfile.GFile(os.path.join(model_dir, self.FROZEN_GRAPH_NAME), "rb") as f:
      graph_def = tf.GraphDef()
      graph_def.ParseFromString(f.read())

    if graph_def is None:
      raise RuntimeError('Cannot find inference graph in the model directory.')

    with self.graph.as_default():
      tf.import_graph_def(graph_def, name='')

    self.sess = tf.Session(graph=self.graph)

    # Print every op name so the input/output tensor names can be verified.
    ops = self.sess.graph.get_operations()
    for op in ops:
      print(op.name)

    # Dump the imported graph so it can be inspected in TensorBoard.
    writer = tf.summary.FileWriter("./logs", graph=self.graph)
    writer.close()

  def run(self, image):
    """Runs inference on a single image.

    Args:
      image: A PIL.Image object, raw input image.

    Returns:
      resized_image: RGB image resized from original input image.
      seg_map: Segmentation map of `resized_image`.
    """
    width, height = image.size
    resize_ratio = 1.0 * self.INPUT_SIZE / max(width, height)
    target_size = (int(resize_ratio * width), int(resize_ratio * height))
    print("origin_size:", image.size, " target_size:", target_size)
    resized_image = image.convert('RGB').resize(target_size, Image.ANTIALIAS)

    start_time = time.time()
    batch_seg_map = self.sess.run(
        self.OUTPUT_TENSOR_NAME,
        feed_dict={self.INPUT_TENSOR_NAME: [np.asarray(resized_image)]})
    duration = time.time() - start_time

    print('%s: duration = %.3f s' % (datetime.now(), duration))

    seg_map = batch_seg_map[0]
    return resized_image, seg_map
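
# Note: tarfile, tempfile, and urllib are imported above but never used.
# The upstream DeepLab demo uses them to fetch the pretrained tarball; below
# is a minimal sketch of that step (the URL follows the DeepLab model zoo
# naming and should be checked against deeplab/g3doc/model_zoo.md).
_MODEL_URL = ('http://download.tensorflow.org/models/'
              'deeplabv3_pascal_trainval_2018_01_04.tar.gz')

def maybe_download_model(download_dir=None):
  download_dir = download_dir or tempfile.mkdtemp()
  tarball_path = os.path.join(download_dir, os.path.basename(_MODEL_URL))
  if not os.path.exists(tarball_path):
    urllib.request.urlretrieve(_MODEL_URL, tarball_path)
  with tarfile.open(tarball_path) as tar:
    tar.extractall(download_dir)  # contains frozen_inference_graph.pb
  return download_dir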

def create_pascal_label_colormap():
  """Creates a label colormap used in PASCAL VOC segmentation benchmark.

  Returns:
    A Colormap for visualizing segmentation results.
  """
  colormap = np.zeros((256, 3), dtype=int)
  ind = np.arange(256, dtype=int)

  for shift in reversed(range(8)):
    for channel in range(3):
      colormap[:, channel] |= ((ind >> channel) & 1) << shift
    ind >>= 3

  return colormap
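
# The bit-shuffle above spreads the three lowest bits of each label index
# across the high bits of the R, G and B channels, so consecutive labels get
# visually distinct colors. For example:
#   create_pascal_label_colormap()[:4]
#   -> [[  0,   0,   0],    # label 0: background
#       [128,   0,   0],    # label 1: aeroplane
#       [  0, 128,   0],    # label 2: bicycle
#       [128, 128,   0]]    # label 3: bird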

def label_to_color_image(label):
  """Adds color defined by the dataset colormap to the label.

  Args:
    label: A 2D array with integer type, storing the segmentation label.

  Returns:
    result: A 2D array with floating type. The element of the array
      is the color indexed by the corresponding element in the input label
      to the PASCAL color map.

  Raises:
    ValueError: If label is not of rank 2 or its value is larger than color
      map maximum entry.
  """
  if label.ndim != 2:
    raise ValueError('Expect 2-D input label')

  colormap = create_pascal_label_colormap()

  if np.max(label) >= len(colormap):
    raise ValueError('label value too large.')

  return colormap[label]

def vis_segmentation(image, seg_map):
  """Visualizes input image, segmentation map and overlay view."""
  plt.figure(figsize=(15, 5))
  grid_spec = gridspec.GridSpec(1, 4, width_ratios=[6, 6, 6, 1])

  plt.subplot(grid_spec[0])
  plt.imshow(image)
  plt.axis('off')
  plt.title('input image')

  plt.subplot(grid_spec[1])
  seg_image = label_to_color_image(seg_map).astype(np.uint8)
  plt.imshow(seg_image)
  plt.axis('off')
  plt.title('segmentation map')

  plt.subplot(grid_spec[2])
  plt.imshow(image)
  plt.imshow(seg_image, alpha=0.7)
  plt.axis('off')
  plt.title('segmentation overlay')

  unique_labels = np.unique(seg_map)
  ax = plt.subplot(grid_spec[3])
  plt.imshow(
      FULL_COLOR_MAP[unique_labels].astype(np.uint8), interpolation='nearest')
  ax.yaxis.tick_right()
  plt.yticks(range(len(unique_labels)), LABEL_NAMES[unique_labels])
  plt.xticks([], [])
  ax.tick_params(width=0.0)
  plt.grid('off')
  plt.show()

def run_visualization(deeplab, image_dir):
  """Runs DeepLab inference and visualizes the result."""

  image_files = tf.gfile.Glob(image_dir + "*.jpg")
  print(image_files)

  for file in image_files:
    with tf.gfile.FastGFile(file, 'rb') as f:
      original_im = Image.open(BytesIO(f.read()))

    resized_im, seg_map = deeplab.run(original_im)

    # vis_segmentation(resized_im, seg_map)

    image_raw = cv2.imread(file)
    image_resize = cv2.resize(image_raw, resized_im.size)
    cv2.imshow('image_raw', image_resize)

    colored_label = label_to_color_image(seg_map)
    colored_label = cv2.cvtColor(colored_label.astype(np.uint8), cv2.COLOR_RGB2BGR)
    cv2.imshow("colored_label", colored_label)

    alpha = 0.4
    img_add = cv2.addWeighted(image_resize, alpha, colored_label, 1 - alpha, 0)
    cv2.imshow("colored_overlap", img_add)
    cv2.waitKey(0)

MODEL_DIR = "/home/robin/Dataset/models/semantic_segmentation/deeplabv3_pascal_trainval_2018_01_04/deeplabv3_pascal_trainval"
# MODEL_DIR = "/home/robin/Dataset/models/semantic_segmentation/deeplabv3_mnv2_pascal_train_aug_2018_01_29/deeplabv3_mnv2_pascal_train_aug"

flags = tf.app.flags

# Dataset settings.
flags.DEFINE_string('dataset', 'ade20k',
                    'Name of the segmentation dataset.')

tf.app.flags.DEFINE_string(
    'test_path', 'images_demo/', 'Test image path.')

flags.DEFINE_string('train_split', 'train',
                    'Which split of the dataset to be used for training')

flags.DEFINE_string('model_dir', MODEL_DIR, 'Where the model resides.')

FLAGS = flags.FLAGS

LABEL_NAMES = np.asarray([
    'background', 'aeroplane', 'bicycle', 'bird', 'boat', 'bottle', 'bus',
    'car', 'cat', 'chair', 'cow', 'diningtable', 'dog', 'horse', 'motorbike',
    'person', 'pottedplant', 'sheep', 'sofa', 'train', 'tv'
])

FULL_LABEL_MAP = np.arange(len(LABEL_NAMES)).reshape(len(LABEL_NAMES), 1)
FULL_COLOR_MAP = label_to_color_image(FULL_LABEL_MAP)


if __name__ == "__main__":
  MODEL = DeepLabModel(FLAGS.model_dir)
  print('model loaded successfully!')

  run_visualization(MODEL, FLAGS.test_path)
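
A typical invocation, assuming the pretrained weights above are unpacked under `MODEL_DIR` and a few `*.jpg` files sit in `images_demo/`, would be `python DeepLab_inference_Demo.py --model_dir=/path/to/deeplabv3_pascal_trainval --test_path=images_demo/`; note the `dataset` and `train_split` flags are defined but never read by this script.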

core/feature_extractor_demo.py (+89)
@@ -0,0 +1,89 @@
"""Demo for deeplab.core.feature_extractor."""
import numpy as np
import six
import tensorflow as tf
import time
from datetime import datetime

import sys
import os
# Make the TF models research directory importable.
TF_API = "/home/robin/eclipse-workspace-python/TF_models/models/research"
sys.path.append(os.path.split(TF_API)[0])
sys.path.append(TF_API)

from deeplab.core import feature_extractor

slim = tf.contrib.slim


flags = tf.app.flags

# For `xception_65`, use atrous_rates = [12, 24, 36] if output_stride = 8, or
# rates = [6, 12, 18] if output_stride = 16. For `mobilenet_v2`, use None. Note
# one could use different atrous_rates/output_stride during training/evaluation.
flags.DEFINE_multi_integer('atrous_rates', None,
                           'Atrous rates for atrous spatial pyramid pooling.')
flags.DEFINE_integer('output_stride', 8,
                     'The ratio of input to output spatial resolution.')

# Defaults to None. Set multi_grid = [1, 2, 4] when using provided
# 'resnet_v1_{50,101}_beta' checkpoints.
flags.DEFINE_multi_integer('multi_grid', None,
                           'Employ a hierarchy of atrous rates for ResNet.')

# When using 'mobilenet_v2', we set atrous_rates = decoder_output_stride = None.
# When using 'xception_65' or 'resnet_v1' model variants, we set
# atrous_rates = [6, 12, 18] (output stride 16) and decoder_output_stride = 4.
# See core/feature_extractor.py for supported model variants.
flags.DEFINE_string('model_variant', 'xception_65', 'DeepLab model variant.')

flags.DEFINE_float('depth_multiplier', 1.0,
                   'Multiplier for the depth (number of channels) for all '
                   'convolution ops used in MobileNet.')


FLAGS = flags.FLAGS


if __name__ == '__main__':
  images = tf.random_normal([1, 513, 513, 3])

  features, end_points = feature_extractor.extract_features(
      images,
      output_stride=FLAGS.output_stride,
      multi_grid=FLAGS.multi_grid,
      model_variant=FLAGS.model_variant,
      depth_multiplier=FLAGS.depth_multiplier,
      weight_decay=0.0001,
      reuse=None,
      is_training=False,
      fine_tune_batch_norm=False)

  print(features, end_points)

  writer = tf.summary.FileWriter("./logs", graph=tf.get_default_graph())

  print("Layers")
  for k, v in end_points.items():
    print('name = {}, shape = {}'.format(v.name, v.get_shape()))

  print("Parameters")
  for v in slim.get_model_variables():
    print('name = {}, shape = {}'.format(v.name, v.get_shape()))
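
Of the flags above, `atrous_rates` is defined but never passed to `extract_features`; the ASPP module that consumes it is assembled in `deeplab/model.py`, not in the feature extractor. To probe the MobileNet backbone instead of Xception, an invocation along the lines of `python core/feature_extractor_demo.py --model_variant=mobilenet_v2 --output_stride=16 --depth_multiplier=0.5` should work; multi-integer flags such as `--multi_grid` are repeated once per value.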

core/resnet_v1_beta.py (+517; large diff not rendered)

core/resnet_v1_beta_demo.py (+77)
@@ -0,0 +1,77 @@
# Copyright 2018 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

"""Demo for deeplab.core.resnet_v1_beta."""
import numpy as np
import six
import tensorflow as tf
import time
from datetime import datetime

import sys
import os
# Make the TF models research directory importable.
TF_API = "/home/robin/eclipse-workspace-python/TF_models/models/research"
sys.path.append(os.path.split(TF_API)[0])
sys.path.append(TF_API)

from deeplab.core import resnet_v1_beta
from tensorflow.contrib.slim.nets import resnet_utils

slim = tf.contrib.slim


if __name__ == '__main__':
  inputs = tf.random_normal([1, 224, 224, 3])

  with slim.arg_scope(resnet_utils.resnet_arg_scope()):
    net, end_points = resnet_v1_beta.resnet_v1_101(inputs,
                                                   num_classes=100,
                                                   is_training=False,
                                                   global_pool=True,
                                                   output_stride=None,
                                                   multi_grid=None,
                                                   reuse=None,
                                                   scope='resnet_v1_101')

  writer = tf.summary.FileWriter("./logs", graph=tf.get_default_graph())

  print("Layers")
  for k, v in end_points.items():
    print('name = {}, shape = {}'.format(v.name, v.get_shape()))

  # print("Parameters")
  # for v in slim.get_model_variables():
  #   print('name = {}, shape = {}'.format(v.name, v.get_shape()))

  init = tf.global_variables_initializer()
  with tf.Session() as sess:
    sess.run(init)

    for i in range(10):
      start_time = time.time()
      _ = sess.run(net)
      duration = time.time() - start_time
      print('%s: step %d, duration = %.3f' % (datetime.now(), i, duration))
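
The first `sess.run` call above absorbs one-off graph optimization and memory-allocation costs, so its duration is not representative. A small variation on the loop, sketched with the same `net` and `sess`, reports a steadier number:

durations = []
for i in range(10):
  start_time = time.time()
  _ = sess.run(net)
  durations.append(time.time() - start_time)
# Average over everything except the first (warm-up) run.
print('mean duration = %.3f s' % (sum(durations[1:]) / len(durations[1:])))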

core/xception_demo.py (+79)
@@ -0,0 +1,79 @@
# Copyright 2018 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

"""Demo for deeplab.core.xception."""
import numpy as np
import six
import tensorflow as tf
import time
from datetime import datetime

import sys
import os
# Make the TF models research directory importable.
TF_API = "/home/robin/eclipse-workspace-python/TF_models/models/research"
sys.path.append(os.path.split(TF_API)[0])
sys.path.append(TF_API)

from deeplab.core import xception
from tensorflow.contrib.slim.nets import resnet_utils

slim = tf.contrib.slim


if __name__ == '__main__':
  inputs = tf.random_normal([1, 224, 224, 3])

  with slim.arg_scope(xception.xception_arg_scope()):
    net, end_points = xception.xception_65(inputs,
                                           num_classes=100,
                                           is_training=False,
                                           global_pool=True,
                                           keep_prob=0.5,
                                           output_stride=None,
                                           regularize_depthwise=False,
                                           multi_grid=[12, 16, 18],
                                           reuse=None,
                                           scope='xception_65')

  writer = tf.summary.FileWriter("./logs", graph=tf.get_default_graph())

  print("Layers")
  for k, v in end_points.items():
    print('name = {}, shape = {}'.format(v.name, v.get_shape()))

  # print("Parameters")
  # for v in slim.get_model_variables():
  #   print('name = {}, shape = {}'.format(v.name, v.get_shape()))

  init = tf.global_variables_initializer()
  with tf.Session() as sess:
    sess.run(init)

    for i in range(10):
      start_time = time.time()
      _ = sess.run(net)
      duration = time.time() - start_time
      print('%s: step %d, duration = %.3f' % (datetime.now(), i, duration))

dataset/segmentation_vis_demo.py (+23 -13)
@@ -33,12 +33,12 @@
 FLAGS = tf.app.flags.FLAGS
 
 tf.app.flags.DEFINE_string('original_color_folder',
-                           './battery_word_seg/JPEGImages',
+                           "/home/robin/Dataset/VOC/VOC2012_VOCtrainval/VOC2012/JPEGImages",
                            'Original ground truth annotations.')
 
-tf.app.flags.DEFINE_string('original_gt_folder',
-                           './battery_word_seg/SegmentationClassRaw',
-                           'Original ground truth annotations.')
+tf.app.flags.DEFINE_string('semantic_segmentation_folder',
+                           "/home/robin/Dataset/VOC/VOC2012_VOCtrainval/VOC2012/SegmentationClassRaw",
+                           'Folder containing semantic segmentation annotations.')
 
 tf.app.flags.DEFINE_string('segmentation_format', 'png', 'Segmentation format.')
 
@@ -74,7 +74,7 @@ def _save_annotation(annotation, filename):
 
 def vis_segmentation(image, seg_map):
   """Visualizes input image, segmentation map and overlay view."""
-  plt.figure(figsize=(15, 5))
+  plt.figure(figsize=(16, 8))
   grid_spec = gridspec.GridSpec(1, 4, width_ratios=[6, 6, 6, 1])
 
   plt.subplot(grid_spec[0])
@@ -104,27 +104,37 @@ def main(unused_argv):
 
   if(FLAGS.convert):
 
-    annotations = glob.glob(os.path.join(FLAGS.original_gt_folder,
+    annotations = glob.glob(os.path.join(FLAGS.semantic_segmentation_folder,
                                          '*.' + FLAGS.segmentation_format))
 
 
 
 
     for annotation in annotations:
       print(annotation)
-      #filename = os.path.join(FLAGS.segmentation_output_dir,os.path.basename(annotation)[:-4]+".jpg")
+
 
       ori_filename = os.path.join(FLAGS.original_color_folder,os.path.basename(annotation)[:-4]+".jpg")
       print(ori_filename)
       # ori_im =Image.open(ori_filename)
-      orignal_im = cv2.imread(ori_filename)
-      image_RGB = cv2.cvtColor(orignal_im,cv2.COLOR_BGR2RGB)
+      color_im = cv2.imread(ori_filename)
+      rgb_image = cv2.cvtColor(color_im,cv2.COLOR_BGR2RGB)
+      print(rgb_image.shape)
 
-      mask_im = cv2.imread(annotation)
-      print(mask_im.shape)
-      mask_RGB = cv2.cvtColor(mask_im,cv2.COLOR_BGR2RGB)
+      seg_im = cv2.imread(annotation,0)
+      print(seg_im.shape)
+
+      #dst = src1 * alpha + src2 * beta + gamma;
+      #alpha,beta,gamma
+      # alpha = 0.3
+      # beta = 1-alpha
+      # gamma = 0
+      # img_add = cv2.addWeighted(rgb_image, alpha, seg_im, beta, gamma)
+      # cv2.imshow("image_add",img_add)
+      # cv2.waitKey(0)
+
 
-      vis_segmentation(orignal_im,mask_RGB)
+      vis_segmentation(rgb_image,seg_im*125)
 
 
 if __name__ == '__main__':
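
The commented-out `cv2.addWeighted` block above was presumably abandoned because `addWeighted` requires both inputs to share shape and dtype, while `seg_im` is single-channel and `rgb_image` has three. One way to blend them, a sketch reusing the variables above (the scaling is purely for visibility and wraps for large label ids):

seg_color = cv2.applyColorMap(seg_im * 12, cv2.COLORMAP_JET)  # 3-channel BGR
img_add = cv2.addWeighted(rgb_image, 0.7,
                          cv2.cvtColor(seg_color, cv2.COLOR_BGR2RGB), 0.3, 0)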

dataset/tfrecord_mask_demo.py (+255)
@@ -0,0 +1,255 @@
#coding=utf-8
import tensorflow as tf
import sys
import cv2
import os
import random
import matplotlib.pyplot as plt
from PIL import Image
import numpy as np
import csv

# Make the TF models research directory importable.
TF_API = "/home/robin/eclipse-workspace-python/TF_models/models/research"
sys.path.append(os.path.split(TF_API)[0])
sys.path.append(TF_API)

from object_detection.utils import visualization_utils as vis_util
slim = tf.contrib.slim


NUM_CLASSES = 20
SPLITS_TO_SIZES = {
    'train': 5011,
    'test': 4952,
}
ITEMS_TO_DESCRIPTIONS = {
    'image': 'A color image of varying height and width.',
    'shape': 'Shape of the image',
    'object/bbox': 'A list of bounding boxes, one per each object.',
    'object/label': 'A list of labels, one per each object.',
}

labels_to_class = ['none', 'aeroplane', 'bicycle', 'bird', 'boat', 'bottle', 'bus',
                   'car', 'cat', 'chair', 'cow', 'diningtable', 'dog', 'horse',
                   'motorbike', 'person', 'pottedplant', 'sheep', 'sofa',
                   'train', 'tvmonitor']

FILE_PATTERN = 'voc_2007_%s_*.tfrecord'

def _get_output_filename(dataset_dir, split_name):
  """Creates the output filename.
  Args:
    dataset_dir: The dataset directory where the dataset is stored.
    split_name: The name of the train/test split.
  Returns:
    An absolute file path.
  """
  return '%s/%s*.tfrecord' % (dataset_dir, split_name)

def bboxes_draw_on_img(img, classes, bboxes, colors, thickness=2):
  shape = img.shape
  for i in range(bboxes.shape[0]):
    bbox = bboxes[i]
    # Draw bounding box...
    p1 = (int(bbox[0] * shape[0]), int(bbox[1] * shape[1]))
    p2 = (int(bbox[2] * shape[0]), int(bbox[3] * shape[1]))
    cv2.rectangle(img, p1[::-1], p2[::-1], colors, thickness)
    # Draw text...
    s = '%s' % (labels_to_class[classes[i]])
    p1 = (p1[0] + 15, p1[1] + 5)
    cv2.putText(img, s, p1[::-1], cv2.FONT_HERSHEY_DUPLEX, 0.4, colors, 1)

def get_split(split_name, dataset_dir, file_pattern=None, reader=None):
  """Gets a dataset tuple with instructions.
  Args:
    split_name: A train/test split name.
    dataset_dir: The base directory of the dataset sources.
    file_pattern: The file pattern to use when matching the dataset sources.
      It is assumed that the pattern contains a '%s' string so that the split
      name can be inserted.
    reader: The TensorFlow reader type.
  Returns:
    A `Dataset` namedtuple.
  Raises:
    ValueError: if `split_name` is not a valid train/test split.
  """
  if split_name not in SPLITS_TO_SIZES:
    raise ValueError('split name %s was not recognized.' % split_name)

  if not file_pattern:
    file_pattern = FILE_PATTERN
  file_pattern = os.path.join(dataset_dir, file_pattern % split_name)

  # Allowing None in the signature so that dataset_factory can use the default.
  if reader is None:
    reader = tf.TFRecordReader
  # # Filename pattern:
  # if file_pattern is None:
  #   file_pattern = _get_output_filename('tfrecords','voc_2007_train')  # adjust to your filename
  # print(file_pattern)

  # Adapter 1: deserialize each Example back into the features it was stored
  # with. Handled by TensorFlow.
  keys_to_features = {
      'image/encoded': tf.FixedLenFeature((), tf.string, default_value=''),
      'image/format': tf.FixedLenFeature((), tf.string, default_value='jpg'),
      'image/height': tf.FixedLenFeature([1], tf.int64),
      'image/width': tf.FixedLenFeature([1], tf.int64),
      'image/channels': tf.FixedLenFeature([1], tf.int64),
      'image/shape': tf.FixedLenFeature([3], tf.int64),
      'image/object/bbox/xmin': tf.VarLenFeature(dtype=tf.float32),
      'image/object/bbox/ymin': tf.VarLenFeature(dtype=tf.float32),
      'image/object/bbox/xmax': tf.VarLenFeature(dtype=tf.float32),
      'image/object/bbox/ymax': tf.VarLenFeature(dtype=tf.float32),
      'image/object/bbox/label': tf.VarLenFeature(dtype=tf.int64),
      'image/object/bbox/difficult': tf.VarLenFeature(dtype=tf.int64),
      'image/object/bbox/truncated': tf.VarLenFeature(dtype=tf.int64),
  }

  # Adapter 2: assemble the deserialized features into higher-level items.
  # Handled by slim.
  items_to_handlers = {
      'image': slim.tfexample_decoder.Image('image/encoded', 'image/format'),
      'shape': slim.tfexample_decoder.Tensor('image/shape'),
      'object/bbox': slim.tfexample_decoder.BoundingBox(
          ['ymin', 'xmin', 'ymax', 'xmax'], 'image/object/bbox/'),
      'object/label': slim.tfexample_decoder.Tensor('image/object/bbox/label'),
      'object/difficult': slim.tfexample_decoder.Tensor('image/object/bbox/difficult'),
      'object/truncated': slim.tfexample_decoder.Tensor('image/object/bbox/truncated'),
  }
  # The decoder.
  decoder = slim.tfexample_decoder.TFExampleDecoder(keys_to_features, items_to_handlers)

  # The Dataset object carries the dataset's metadata: file locations,
  # decoding method, and so on.
  dataset = slim.dataset.Dataset(
      data_sources=file_pattern,
      reader=reader,
      num_samples=SPLITS_TO_SIZES['test'],
      decoder=decoder,
      items_to_descriptions=ITEMS_TO_DESCRIPTIONS,
      num_classes=NUM_CLASSES)
  return dataset

# Read a tfrecords file.
def decode_from_tfrecords(filename, num_epoch=None):
  # Training data may be split across many files, so the first argument is a
  # list of filenames.
  filename_queue = tf.train.string_input_producer([filename], num_epochs=num_epoch)
  reader = tf.TFRecordReader()
  _, serialized = reader.read(filename_queue)
  example = tf.parse_single_example(serialized, features={
      'image/height': tf.FixedLenFeature([], tf.int64),
      'image/width': tf.FixedLenFeature([], tf.int64),
      'image/encoded': tf.FixedLenFeature([], tf.string),
      'image/object/class/label': tf.FixedLenFeature([], tf.int64)
  })
  label = tf.cast(example['image/object/class/label'], tf.int32)
  image = tf.decode_raw(example['image/encoded'], tf.uint8)
  image = tf.reshape(image, tf.stack([
      tf.cast(example['image/height'], tf.int32),
      tf.cast(example['image/width'], tf.int32),
      3]))

  print('decode_from_tfrecords: ', image)
  print('decode_from_tfrecords: ', label)
  return image, label

def plt_bboxes(img, classes, scores, bboxes, figsize=(10, 10), linewidth=1.5):
  """Visualize bounding boxes. Largely inspired by SSD-MXNET!"""
  fig = plt.figure(figsize=figsize)
  plt.imshow(img)
  height = img.shape[0]
  width = img.shape[1]
  colors = dict()
  for i in range(classes.shape[0]):
    cls_id = int(classes[i])
    if cls_id >= 0:
      score = scores[i]
      if cls_id not in colors:
        colors[cls_id] = (random.random(), random.random(), random.random())
      ymin = int(bboxes[i, 0] * height)
      xmin = int(bboxes[i, 1] * width)
      ymax = int(bboxes[i, 2] * height)
      xmax = int(bboxes[i, 3] * width)
      # crop_img = img[xmin:(xmax - xmin),xmax:(ymax - ymin)]
      # misc.imsave('1.jpg', crop_img)
      rect = plt.Rectangle((xmin, ymin), xmax - xmin,
                           ymax - ymin, fill=False,
                           edgecolor=colors[cls_id],
                           linewidth=linewidth)
      plt.gca().add_patch(rect)
      class_name = labels_to_class[cls_id]
      plt.gca().text(xmin, ymin - 2,
                     '{:s} | {:.3f}'.format(class_name, score),
                     bbox=dict(facecolor=colors[cls_id], alpha=0.5),
                     fontsize=12, color='white')
  plt.show()

def write_file(file_name_string, seg):
  with open(file_name_string, 'wb') as csvfile:
    spamwriter = csv.writer(csvfile, dialect='excel')
    for i in range(seg.shape[0]):
      spamwriter.writerow(seg[i][:])

def test():
  reconstructed_images = []
  record_iterator = tf.python_io.tf_record_iterator(path=
      '/home/robin/Dataset/VOC/VOC2012_VOCtrainval/sematic_segmentation_tfrecord/val-00001-of-00002.tfrecord')
  init = tf.global_variables_initializer()
  with tf.Session() as sess:
    sess.run(init)
    coord = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(coord=coord)

    for string_iterator in record_iterator:
      plt.figure(figsize=(12, 12))
      example = tf.train.Example()
      example.ParseFromString(string_iterator)
      height = example.features.feature['image/height'].int64_list.value[0]
      width = example.features.feature['image/width'].int64_list.value[0]
      png_string = example.features.feature['image/encoded'].bytes_list.value[0]
      #label = example.features.feature['image/object/class/label'].int64_list.value[0]
      #xmin = example.features.feature['image/object/bbox/xmin'].float_list.value[0]
      #xmax = example.features.feature['image/object/bbox/xmax'].float_list.value[0]
      #ymin = example.features.feature['image/object/bbox/ymin'].float_list.value[0]
      #ymax = example.features.feature['image/object/bbox/ymax'].float_list.value[0]

      encoded_mask_string = example.features.feature['image/segmentation/class/encoded'].bytes_list.value[0]

      plt.subplot(131)
      mask_decode_png = tf.image.decode_png(encoded_mask_string, channels=1)
      fix_mask = tf.cast(tf.greater(mask_decode_png, 0), tf.uint8)

      redecode_mask_img = sess.run(mask_decode_png)
      # write_file("mask.csv", redecode_mask_img)
      print(redecode_mask_img.shape)
      redecode_mask = redecode_mask_img * 255
      mask_img = np.squeeze(redecode_mask, axis=2)
      plt.imshow(mask_img)
      plt.title('segmentation map')
      #im = Image.fromarray(mask_img)
      #im.save("pets.png")

      plt.subplot(132)
      decoded_img = tf.image.decode_jpeg(png_string, channels=3)
      reconstructed_img = sess.run(decoded_img)
      print(reconstructed_img.shape)
      plt.imshow(reconstructed_img)
      plt.title('input image')

      plt.subplot(133)
      vis_util.draw_mask_on_image_array(
          image=reconstructed_img,
          mask=np.squeeze(sess.run(fix_mask), axis=2),
          alpha=0.8)
      plt.imshow(reconstructed_img)
      plt.title('segmentation overlay')

      plt.show()

    coord.request_stop()
    coord.join(threads)

if __name__ == '__main__':
  test()
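
For reference, the segmentation branch of `test()` above only needs two fields per record: `image/encoded` holding JPEG bytes and `image/segmentation/class/encoded` holding PNG mask bytes, plus the height/width ints. A minimal writer sketch with the same field names (the helper name and usage are illustrative, not part of the original file):

def write_seg_example(writer, jpeg_bytes, png_mask_bytes, height, width):
  # Field names mirror the ones parsed in test().
  example = tf.train.Example(features=tf.train.Features(feature={
      'image/encoded': tf.train.Feature(
          bytes_list=tf.train.BytesList(value=[jpeg_bytes])),
      'image/segmentation/class/encoded': tf.train.Feature(
          bytes_list=tf.train.BytesList(value=[png_mask_bytes])),
      'image/height': tf.train.Feature(
          int64_list=tf.train.Int64List(value=[height])),
      'image/width': tf.train.Feature(
          int64_list=tf.train.Int64List(value=[width])),
  }))
  writer.write(example.SerializeToString())

# e.g. with tf.python_io.TFRecordWriter('val-00000.tfrecord') as w:
#        write_seg_example(w, jpeg_bytes, png_bytes, height, width)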

(file name not rendered in this diff view) (+223)
@@ -0,0 +1,223 @@
# Copyright 2018 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Demo that decodes and visualizes DeepLab training batches.

Derived from the DeepLab training script; see model.py for more details.
"""

import six
import tensorflow as tf

import sys
import os
import cv2
import numpy as np
import csv
from matplotlib import pyplot as plt
# Make the TF models research directory importable.
TF_API = "/home/robin/eclipse-workspace-python/TF_models/models/research"
sys.path.append(os.path.split(TF_API)[0])
sys.path.append(TF_API)

from deeplab import common
from deeplab import model
from deeplab.datasets import segmentation_dataset
from deeplab.utils import input_generator
from deeplab.utils import train_utils
from deployment import model_deploy
from deeplab.utils import get_dataset_colormap


slim = tf.contrib.slim

prefetch_queue = slim.prefetch_queue

flags = tf.app.flags


# Settings for multi-GPUs/multi-replicas training.

flags.DEFINE_integer('num_clones', 1, 'Number of clones to deploy.')

flags.DEFINE_boolean('clone_on_cpu', False, 'Use CPUs to deploy clones.')

flags.DEFINE_integer('num_replicas', 1, 'Number of worker replicas.')

flags.DEFINE_integer('startup_delay_steps', 15,
                     'Number of training steps between replicas startup.')

flags.DEFINE_integer('num_ps_tasks', 0,
                     'The number of parameter servers. If the value is 0, then '
                     'the parameters are handled locally by the worker.')

flags.DEFINE_string('master', '', 'BNS name of the tensorflow server')

flags.DEFINE_integer('task', 0, 'The task ID.')

# Settings for logging.


# When fine_tune_batch_norm=True, use at least batch size larger than 12
# (batch size more than 16 is better). Otherwise, one could use smaller batch
# size and set fine_tune_batch_norm=False.
flags.DEFINE_integer('train_batch_size', 2,
                     'The number of images in each batch during training.')

# For weight_decay, use 0.00004 for MobileNet-V2 or Xception model variants.
# Use 0.0001 for ResNet model variants.
flags.DEFINE_float('weight_decay', 0.00004,
                   'The value of the weight decay for training.')

flags.DEFINE_multi_integer('train_crop_size', [513, 513],
                           'Image crop size [height, width] during training.')

flags.DEFINE_float('last_layer_gradient_multiplier', 1.0,
                   'The gradient multiplier for last layers, which is used to '
                   'boost the gradient of last layers if the value > 1.')

flags.DEFINE_boolean('upsample_logits', True,
                     'Upsample logits during training.')

# Settings for fine-tuning the network.

flags.DEFINE_float('min_scale_factor', 0.5,
                   'Minimum scale factor for data augmentation.')

flags.DEFINE_float('max_scale_factor', 2.,
                   'Maximum scale factor for data augmentation.')

flags.DEFINE_float('scale_factor_step_size', 0.25,
                   'Scale factor step size for data augmentation.')

# For `xception_65`, use atrous_rates = [12, 24, 36] if output_stride = 8, or
# rates = [6, 12, 18] if output_stride = 16. For `mobilenet_v2`, use None. Note
# one could use different atrous_rates/output_stride during training/evaluation.
flags.DEFINE_multi_integer('atrous_rates', None,
                           'Atrous rates for atrous spatial pyramid pooling.')

flags.DEFINE_integer('output_stride', 16,
                     'The ratio of input to output spatial resolution.')

# Dataset settings.
flags.DEFINE_string('dataset', 'pascal_voc_seg',
                    'Name of the segmentation dataset.')

flags.DEFINE_string('train_split', 'train',
                    'Which split of the dataset to be used for training')

flags.DEFINE_string('dataset_dir', "/home/robin/Dataset/VOC/VOC2012_VOCtrainval/sematic_segmentation_tfrecord", 'Where the dataset resides.')

FLAGS = flags.FLAGS


def write_file(file_name_string, seg):
  with open(file_name_string, 'wb') as csvfile:
    spamwriter = csv.writer(csvfile, dialect='excel')
    for i in range(seg.shape[0]):
      spamwriter.writerow(seg[i][:])


def main(unused_argv):
  tf.logging.set_verbosity(tf.logging.INFO)
  # Set up deployment (i.e., multi-GPUs and/or multi-replicas).
  config = model_deploy.DeploymentConfig(
      num_clones=FLAGS.num_clones,
      clone_on_cpu=FLAGS.clone_on_cpu,
      replica_id=FLAGS.task,
      num_replicas=FLAGS.num_replicas,
      num_ps_tasks=FLAGS.num_ps_tasks)

  # Split the batch across GPUs.
  assert FLAGS.train_batch_size % config.num_clones == 0, (
      'Training batch size not divisible by number of clones (GPUs).')

  clone_batch_size = FLAGS.train_batch_size // config.num_clones

  # Get dataset-dependent information.
  dataset = segmentation_dataset.get_dataset(
      FLAGS.dataset, FLAGS.train_split, dataset_dir=FLAGS.dataset_dir)

  with tf.Graph().as_default() as graph:
    with tf.device(config.inputs_device()):
      samples = input_generator.get(
          dataset,
          FLAGS.train_crop_size,
          clone_batch_size,
          min_resize_value=FLAGS.min_resize_value,
          max_resize_value=FLAGS.max_resize_value,
          resize_factor=FLAGS.resize_factor,
          min_scale_factor=FLAGS.min_scale_factor,
          max_scale_factor=FLAGS.max_scale_factor,
          scale_factor_step_size=FLAGS.scale_factor_step_size,
          dataset_split=FLAGS.train_split,
          is_training=True,
          model_variant=FLAGS.model_variant)
      inputs_queue = prefetch_queue.prefetch_queue(
          samples, capacity=128 * config.num_clones)

    samples = inputs_queue.dequeue()

    # Add name to input and label nodes so we can add to summary.
    samples[common.IMAGE] = tf.identity(samples[common.IMAGE], name=common.IMAGE)
    samples[common.LABEL] = tf.identity(samples[common.LABEL], name=common.LABEL)

    print(samples)

    # Create the global step on the device storing the variables.
    with tf.device(config.variables_device()):
      global_step = tf.train.get_or_create_global_step()

    init = tf.global_variables_initializer()
    with tf.Session() as session:
      session.run(init)

      coord = tf.train.Coordinator()
      threads = tf.train.start_queue_runners(coord=coord)

      print('Start verification process...')
      try:
        while True:
          out_image, out_label = session.run([samples[common.IMAGE], samples[common.LABEL]])

          # write_file("out_label.csv", np.squeeze(out_label[0], axis=2))

          cv2.imshow('out_image', cv2.cvtColor(out_image[0] / 255, cv2.COLOR_RGB2BGR))
          cv2.imshow('out_label', np.asarray(out_label[0] * 100, dtype=np.uint8))

          colored_label = get_dataset_colormap.label_to_color_image(
              np.squeeze(out_label[0]), dataset=get_dataset_colormap.get_pascal_name())
          cv2.imshow("colored_label", cv2.cvtColor(colored_label.astype(np.uint8), cv2.COLOR_RGB2BGR))

          alpha = 0.5
          img_add = cv2.addWeighted(out_image[0], alpha, colored_label.astype(np.float32), 1 - alpha, 0)
          cv2.imshow("colored_overlap", cv2.cvtColor(img_add, cv2.COLOR_RGB2BGR) / 255)
          cv2.waitKey(0)

      except tf.errors.OutOfRangeError:
        print("end!")

      coord.request_stop()
      coord.join(threads)


if __name__ == '__main__':
  flags.mark_flag_as_required('dataset_dir')
  tf.app.run()
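
The `min_resize_value`, `max_resize_value`, `resize_factor`, and `model_variant` arguments above are not defined in this file; they come from the flags that `deeplab/common.py` registers when imported. Assuming the script is saved as, say, `train_input_demo.py` (the actual filename is not rendered in this diff), it would be run as `python train_input_demo.py --dataset=pascal_voc_seg --dataset_dir=/path/to/tfrecords`, stepping through one augmented training batch per keypress.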

tf_record_mask_decode_demo.py (+158)
@@ -0,0 +1,158 @@
# Copyright 2018 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Demo that decodes a segmentation TFRecord and previews preprocessed samples."""

import six
import tensorflow as tf

import sys
import os
import cv2
import numpy as np
import PIL.Image as img

from matplotlib import gridspec
from matplotlib import pyplot as plt

# Make the TF models research directory importable.
TF_API = "/home/robin/eclipse-workspace-python/TF_models/models/research"
sys.path.append(os.path.split(TF_API)[0])
sys.path.append(TF_API)

from deeplab import common
from deeplab import model
from deeplab.datasets import segmentation_dataset
from deeplab.utils import input_generator
from deeplab.utils import train_utils
from deeplab.utils import get_dataset_colormap
from deployment import model_deploy
from deeplab.core import preprocess_utils
from deeplab import input_preprocess

slim = tf.contrib.slim

prefetch_queue = slim.prefetch_queue

flags = tf.app.flags

# Dataset settings.
flags.DEFINE_string('dataset', 'ade20k',
                    'Name of the segmentation dataset.')

flags.DEFINE_string('train_split', 'train',
                    'Which split of the dataset to be used for training')

flags.DEFINE_string('dataset_dir', "/home/robin/Dataset/ADE20K/semantic_segmentation_tfrecord", 'Where the dataset resides.')

FLAGS = flags.FLAGS

def vis_segmentation(image, seg_map):
  """Visualizes input image, segmentation map and overlay view."""
  plt.figure(figsize=(16, 8))
  grid_spec = gridspec.GridSpec(1, 4, width_ratios=[6, 6, 6, 1])

  plt.subplot(grid_spec[0])
  plt.imshow(image)
  plt.axis('off')
  plt.title('input image')

  plt.subplot(grid_spec[1])
  plt.imshow(seg_map)
  plt.axis('off')
  plt.title('segmentation map')

  plt.subplot(grid_spec[2])
  plt.imshow(image)
  plt.imshow(seg_map, alpha=0.8)
  plt.axis('off')
  plt.title('segmentation overlay')
  plt.show()

def main(unused_argv):
  tf.logging.set_verbosity(tf.logging.INFO)

  # Get dataset-dependent information.
  dataset = segmentation_dataset.get_dataset(
      FLAGS.dataset, FLAGS.train_split, dataset_dir=FLAGS.dataset_dir)

  data_provider = slim.dataset_data_provider.DatasetDataProvider(
      dataset,
      num_readers=3,
      common_queue_capacity=20 * 1,
      common_queue_min=10 * 1,
      shuffle=False)
  image, label, image_name, height, width = input_generator.get_decode_data(
      data_provider, FLAGS.train_split)
  print(image, label, image_name, height, width)

  original_image, processed_image, label = input_preprocess.preprocess_image_and_label(
      image,
      label,
      crop_height=513,
      crop_width=513,
      min_resize_value=513,
      max_resize_value=513,
      resize_factor=None,
      min_scale_factor=0.5,
      max_scale_factor=2,
      scale_factor_step_size=0.25,
      ignore_label=0,
      is_training=True,
      model_variant="mobilenet_v2")

  init = tf.global_variables_initializer()
  with tf.Session() as session:
    session.run(init)

    coord = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(coord=coord)

    print('Start verification process...')
    for l in range(data_provider._num_samples):
      out_image, out_label, out_image_name, out_height, out_width = session.run(
          [processed_image, label, image_name, height, width])

      # print(out_label.shape)
      # print(out_image, out_label.shape, out_height, out_width)
      # print(out_image, out_label, out_image_name, out_height, out_width)

      # vis_segmentation(out_image/255, np.squeeze(out_label, axis=2))

      colored_label = get_dataset_colormap.label_to_color_image(
          np.squeeze(out_label, axis=2), dataset=get_dataset_colormap.get_ade20k_name())
      colored_label_uint8 = np.asarray(colored_label, dtype=np.uint8)
      cv2.imshow("colored_label", cv2.cvtColor(colored_label_uint8, cv2.COLOR_RGB2BGR))

      colored_label = colored_label.astype(np.float32)
      alpha = 0.3
      img_add = cv2.addWeighted(out_image, alpha, colored_label, 1 - alpha, 0)
      cv2.imshow("colored_overlap", cv2.cvtColor(img_add, cv2.COLOR_RGB2BGR) / 255)
      cv2.waitKey(0)

    coord.request_stop()
    coord.join(threads)


if __name__ == '__main__':
  # flags.mark_flag_as_required('train_logdir')
  # flags.mark_flag_as_required('tf_initial_checkpoint')
  flags.mark_flag_as_required('dataset_dir')
  tf.app.run()
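
Because `preprocess_image_and_label` is called with `is_training=True`, every fetched sample is randomly scaled within [0.5, 2.0] (step 0.25) and cropped to 513x513, so repeated runs show different crops of the same images. A typical invocation would be `python tf_record_mask_decode_demo.py --dataset=ade20k --dataset_dir=/path/to/ade20k_tfrecord`; note that `get_decode_data` does not appear in the stock `deeplab/utils/input_generator.py`, so a locally modified copy of that module seems to be assumed.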
