#coding=utf-8
import tensorflow as tf
import sys
import cv2
import os
import random  # needed by plt_bboxes below (missing in the original)
import matplotlib.pyplot as plt
from PIL import Image
import numpy as np
import csv

# Make the TensorFlow models/research directory (and its parent) importable
# so that the object_detection utilities can be found.
TF_API="/home/robin/eclipse-workspace-python/TF_models/models/research"
sys.path.append(os.path.split(TF_API)[0])
sys.path.append(TF_API)

from object_detection.utils import visualization_utils as vis_util
slim = tf.contrib.slim


NUM_CLASSES = 20
SPLITS_TO_SIZES = {
    'train': 5011,
    'test': 4952,
}
ITEMS_TO_DESCRIPTIONS = {
    'image': 'A color image of varying height and width.',
    'shape': 'Shape of the image.',
    'object/bbox': 'A list of bounding boxes, one per object.',
    'object/label': 'A list of labels, one per object.',
}

labels_to_class = ['none', 'aeroplane', 'bicycle', 'bird', 'boat', 'bottle', 'bus',
                   'car', 'cat', 'chair', 'cow', 'diningtable', 'dog', 'horse',
                   'motorbike', 'person', 'pottedplant', 'sheep', 'sofa',
                   'train', 'tvmonitor']

FILE_PATTERN = 'voc_2007_%s_*.tfrecord'

def _get_output_filename(dataset_dir, split_name):
    """Builds the file pattern matching a split's tfrecord shards.
    Args:
      dataset_dir: The dataset directory where the dataset is stored.
      split_name: The name of the train/test split.
    Returns:
      A glob pattern matching the split's tfrecord files (not a single
      absolute path).
    """
    return '%s/%s*.tfrecord' % (dataset_dir, split_name)

def bboxes_draw_on_img(img, classes, bboxes, colors, thickness=2):
    """Draws labeled boxes on `img` in place.

    `bboxes` rows are normalized [ymin, xmin, ymax, xmax]; they are scaled
    back to pixels and reversed into the (x, y) order OpenCV expects.
    """
    shape = img.shape
    for i in range(bboxes.shape[0]):
        bbox = bboxes[i]
        # Draw the bounding box: scale normalized (y, x) corners to pixels.
        p1 = (int(bbox[0] * shape[0]), int(bbox[1] * shape[1]))
        p2 = (int(bbox[2] * shape[0]), int(bbox[3] * shape[1]))
        cv2.rectangle(img, p1[::-1], p2[::-1], colors, thickness)
        # Draw the class name just inside the top-left corner.
        s = '%s' % (labels_to_class[classes[i]])
        p1 = (p1[0] + 15, p1[1] + 5)
        cv2.putText(img, s, p1[::-1], cv2.FONT_HERSHEY_DUPLEX, 0.4, colors, 1)
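
# A hedged usage sketch (not part of the original script): drawing one
# hypothetical detection with OpenCV. The file name, class id, and box
# coordinates are illustrative assumptions.
def example_bboxes_draw(image_path='example.jpg'):
    img = cv2.imread(image_path)                  # BGR image, shape (H, W, 3)
    classes = np.array([15])                      # index 15 -> 'person'
    bboxes = np.array([[0.1, 0.2, 0.9, 0.8]])     # normalized [ymin, xmin, ymax, xmax]
    bboxes_draw_on_img(img, classes, bboxes, colors=(0, 255, 0))
    cv2.imwrite('example_boxed.jpg', img)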

def get_split(split_name, dataset_dir, file_pattern=None, reader=None):
    """Gets a dataset tuple with instructions for reading Pascal VOC data.
    Args:
      split_name: A train/test split name.
      dataset_dir: The base directory of the dataset sources.
      file_pattern: The file pattern to use when matching the dataset sources.
        It is assumed that the pattern contains a '%s' string so that the split
        name can be inserted.
      reader: The TensorFlow reader type.
    Returns:
      A `Dataset` namedtuple.
    Raises:
      ValueError: if `split_name` is not a valid train/test split.
    """
    if split_name not in SPLITS_TO_SIZES:
        raise ValueError('split name %s was not recognized.' % split_name)

    if not file_pattern:
        file_pattern = FILE_PATTERN
    file_pattern = os.path.join(dataset_dir, file_pattern % split_name)

    # Allowing None in the signature so that dataset_factory can use the default.
    if reader is None:
        reader = tf.TFRecordReader
#    # File name pattern:
#    if file_pattern is None:
#        file_pattern = _get_output_filename('tfrecords', 'voc_2007_train')  # adjust to your file names
#        print(file_pattern)

    # Adapter 1: deserialize each Example back into the features it was
    # stored with. Handled by TF itself.
    keys_to_features = {
        'image/encoded': tf.FixedLenFeature((), tf.string, default_value=''),
        'image/format': tf.FixedLenFeature((), tf.string, default_value='jpg'),
        'image/height': tf.FixedLenFeature([1], tf.int64),
        'image/width': tf.FixedLenFeature([1], tf.int64),
        'image/channels': tf.FixedLenFeature([1], tf.int64),
        'image/shape': tf.FixedLenFeature([3], tf.int64),
        'image/object/bbox/xmin': tf.VarLenFeature(dtype=tf.float32),
        'image/object/bbox/ymin': tf.VarLenFeature(dtype=tf.float32),
        'image/object/bbox/xmax': tf.VarLenFeature(dtype=tf.float32),
        'image/object/bbox/ymax': tf.VarLenFeature(dtype=tf.float32),
        'image/object/bbox/label': tf.VarLenFeature(dtype=tf.int64),
        'image/object/bbox/difficult': tf.VarLenFeature(dtype=tf.int64),
        'image/object/bbox/truncated': tf.VarLenFeature(dtype=tf.int64),
    }

    # Adapter 2: assemble the deserialized features into higher-level items.
    # Handled by slim.
    items_to_handlers = {
        'image': slim.tfexample_decoder.Image('image/encoded', 'image/format'),
        'shape': slim.tfexample_decoder.Tensor('image/shape'),
        'object/bbox': slim.tfexample_decoder.BoundingBox(
            ['ymin', 'xmin', 'ymax', 'xmax'], 'image/object/bbox/'),
        'object/label': slim.tfexample_decoder.Tensor('image/object/bbox/label'),
        'object/difficult': slim.tfexample_decoder.Tensor('image/object/bbox/difficult'),
        'object/truncated': slim.tfexample_decoder.Tensor('image/object/bbox/truncated'),
    }
    # Decoder combining the two adapters.
    decoder = slim.tfexample_decoder.TFExampleDecoder(keys_to_features, items_to_handlers)

    # The Dataset object records the metadata: file locations, decoder, sizes.
    dataset = slim.dataset.Dataset(
        data_sources=file_pattern,
        reader=reader,
        num_samples=SPLITS_TO_SIZES[split_name],  # size of the requested split (was hard-coded to 'test')
        decoder=decoder,
        items_to_descriptions=ITEMS_TO_DESCRIPTIONS,
        num_classes=NUM_CLASSES)
    return dataset
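
# A hedged usage sketch (not part of the original script): consuming the slim
# Dataset with a DatasetDataProvider. The 'tfrecords' directory is an
# illustrative assumption.
def example_read_split(dataset_dir='tfrecords'):
    dataset = get_split('test', dataset_dir)
    provider = slim.dataset_data_provider.DatasetDataProvider(
        dataset, num_readers=1, shuffle=False)
    # Tensors for one sample; run them in a session with queue runners started.
    image, shape, bboxes, labels = provider.get(
        ['image', 'shape', 'object/bbox', 'object/label'])
    return image, shape, bboxes, labels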

# Read and decode a single example from a tfrecords file.
def decode_from_tfrecords(filename, num_epoch=None):
    # Large datasets are usually sharded across many files, hence the list of
    # file names handed to the queue.
    filename_queue = tf.train.string_input_producer([filename], num_epochs=num_epoch)
    reader = tf.TFRecordReader()
    _, serialized = reader.read(filename_queue)
    example = tf.parse_single_example(serialized, features={
        'image/height': tf.FixedLenFeature([], tf.int64),
        'image/width': tf.FixedLenFeature([], tf.int64),
        'image/encoded': tf.FixedLenFeature([], tf.string),
        'image/object/class/label': tf.FixedLenFeature([], tf.int64)
    })
    label = tf.cast(example['image/object/class/label'], tf.int32)
    # 'image/encoded' holds compressed JPEG bytes, so decode the image instead
    # of reinterpreting the raw buffer (tf.decode_raw, as in the original,
    # cannot be reshaped to height x width x 3 here).
    image = tf.image.decode_jpeg(example['image/encoded'], channels=3)

    print('decode_from_tfrecords: ', image)
    print('decode_from_tfrecords: ', label)
    return image, label
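
# A hedged usage sketch (not part of the original script): pulling one decoded
# example inside a session. The tfrecord path is an illustrative assumption.
def example_decode_from_tfrecords(filename='tfrecords/voc_2007_train_000.tfrecord'):
    image, label = decode_from_tfrecords(filename)
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        sess.run(tf.local_variables_initializer())
        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(coord=coord)
        image_val, label_val = sess.run([image, label])
        print(image_val.shape, label_val)
        coord.request_stop()
        coord.join(threads)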

def plt_bboxes(img, classes, scores, bboxes, figsize=(10, 10), linewidth=1.5):
    """Visualize bounding boxes. Largely inspired by SSD-MXNET!"""
    fig = plt.figure(figsize=figsize)
    plt.imshow(img)
    height = img.shape[0]
    width = img.shape[1]
    colors = dict()
    for i in range(classes.shape[0]):
        cls_id = int(classes[i])
        if cls_id >= 0:
            score = scores[i]
            # Assign each class a random but stable color.
            if cls_id not in colors:
                colors[cls_id] = (random.random(), random.random(), random.random())
            ymin = int(bboxes[i, 0] * height)
            xmin = int(bboxes[i, 1] * width)
            ymax = int(bboxes[i, 2] * height)
            xmax = int(bboxes[i, 3] * width)
#            crop_img = img[xmin:(xmax - xmin), xmax:(ymax - ymin)]
#            misc.imsave('1.jpg', crop_img)
            rect = plt.Rectangle((xmin, ymin), xmax - xmin,
                                 ymax - ymin, fill=False,
                                 edgecolor=colors[cls_id],
                                 linewidth=linewidth)
            plt.gca().add_patch(rect)
            class_name = labels_to_class[cls_id]  # was the undefined CLASSES
            plt.gca().text(xmin, ymin - 2,
                           '{:s} | {:.3f}'.format(class_name, score),
                           bbox=dict(facecolor=colors[cls_id], alpha=0.5),
                           fontsize=12, color='white')
    plt.show()
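
# A hedged usage sketch (not part of the original script): plotting one
# hypothetical detection. The file name, class id, and score are illustrative.
def example_plt_bboxes(image_path='example.jpg'):
    img = np.array(Image.open(image_path))
    classes = np.array([12])                      # index 12 -> 'dog'
    scores = np.array([0.95])
    bboxes = np.array([[0.1, 0.2, 0.9, 0.8]])     # normalized [ymin, xmin, ymax, xmax]
    plt_bboxes(img, classes, scores, bboxes)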

def write_file(file_name_string, seg):
    # Text mode for Python 3's csv module ('wb' only works on Python 2).
    with open(file_name_string, 'w') as csvfile:
        spamwriter = csv.writer(csvfile, dialect='excel')
        for i in range(seg.shape[0]):
            spamwriter.writerow(seg[i][:])
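
# A hedged usage sketch (not part of the original script): dumping a small
# array to CSV.
# write_file('mask.csv', np.zeros((4, 4), dtype=np.uint8))
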
def test():
    reconstructed_images = []
    record_iterator = tf.python_io.tf_record_iterator(path=
        '/home/robin/Dataset/VOC/VOC2012_VOCtrainval/sematic_segmentation_tfrecord/val-00001-of-00002.tfrecord')
    init = tf.global_variables_initializer()
    with tf.Session() as sess:
        sess.run(init)
        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(coord=coord)

        for string_iterator in record_iterator:
            plt.figure(figsize=(12, 12))
            example = tf.train.Example()
            example.ParseFromString(string_iterator)
            height = example.features.feature['image/height'].int64_list.value[0]
            width = example.features.feature['image/width'].int64_list.value[0]
            encoded_image_string = example.features.feature['image/encoded'].bytes_list.value[0]
            #label = example.features.feature['image/object/class/label'].int64_list.value[0]
            #xmin = example.features.feature['image/object/bbox/xmin'].float_list.value[0]
            #xmax = example.features.feature['image/object/bbox/xmax'].float_list.value[0]
            #ymin = example.features.feature['image/object/bbox/ymin'].float_list.value[0]
            #ymax = example.features.feature['image/object/bbox/ymax'].float_list.value[0]

            encoded_mask_string = example.features.feature['image/segmentation/class/encoded'].bytes_list.value[0]

            # Note: new decode ops are added to the graph on every iteration,
            # which is fine for a quick inspection script.
            plt.subplot(131)
            mask_decode_png = tf.image.decode_png(encoded_mask_string, channels=1)
            # Binarize the mask: every non-background class id becomes 1.
            fix_mask = tf.cast(tf.greater(mask_decode_png, 0), tf.uint8)

            redecode_mask_img = sess.run(mask_decode_png)
#            write_file("mask.csv", redecode_mask_img)
            print(redecode_mask_img.shape)
            redecode_mask = redecode_mask_img * 255
            mask_img = np.squeeze(redecode_mask, axis=2)
            plt.imshow(mask_img)
            plt.title('segmentation map')
            #im = Image.fromarray(mask_img)
            #im.save("pets.png")

            plt.subplot(132)
            decoded_img = tf.image.decode_jpeg(encoded_image_string, channels=3)
            reconstructed_img = sess.run(decoded_img)
            print(reconstructed_img.shape)
            plt.imshow(reconstructed_img)
            plt.title('input image')

            plt.subplot(133)
            vis_util.draw_mask_on_image_array(
                image=reconstructed_img,
                mask=np.squeeze(sess.run(fix_mask), axis=2),
                alpha=0.8)
            plt.imshow(reconstructed_img)
            plt.title('segmentation overlay')

            plt.show()

        coord.request_stop()
        coord.join(threads)


if __name__ == '__main__':
    test()