-
Notifications
You must be signed in to change notification settings - Fork 74
/
donkey.py
49 lines (42 loc) · 2.39 KB
/
donkey.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
import tensorflow as tf
class Donkey(object):
    """Builds TF1 queue-based input pipelines over a single TFRecord file.

    Each record is parsed as a raw uint8 ``image`` byte string (reshaped to
    64x64x3), a scalar int64 ``length`` and a 5-element int64 ``digits``
    vector.
    """

    @staticmethod
    def _preprocess(image):
        """Normalize and randomly crop one decoded image.

        Args:
            image: 1-D uint8 tensor of raw pixel bytes; it must hold exactly
                64 * 64 * 3 values or the reshape below fails.

        Returns:
            A float32 tensor of shape [54, 54, 3] with values in [-1, 1].
        """
        # convert_image_dtype rescales uint8 to [0, 1]; shift and scale to [-1, 1].
        image = tf.image.convert_image_dtype(image, dtype=tf.float32)
        image = tf.multiply(tf.subtract(image, 0.5), 2)
        image = tf.reshape(image, [64, 64, 3])
        # NOTE(review): this crop is random for every pipeline, including
        # build_batch(..., shuffled=False); confirm whether evaluation should
        # use a deterministic (e.g. central) crop instead.
        image = tf.random_crop(image, [54, 54, 3])
        return image

    @staticmethod
    def _read_and_decode(filename_queue):
        """Read one serialized example from ``filename_queue`` and decode it.

        Args:
            filename_queue: TF1 queue of TFRecord filenames to read from.

        Returns:
            Tuple ``(image, length, digits)``: the preprocessed image tensor
            (see ``_preprocess``), an int32 scalar ``length`` and an int32
            ``digits`` tensor of shape [5].
        """
        reader = tf.TFRecordReader()
        _, serialized_example = reader.read(filename_queue)
        features = tf.parse_single_example(
            serialized_example,
            features={
                'image': tf.FixedLenFeature([], tf.string),
                'length': tf.FixedLenFeature([], tf.int64),
                'digits': tf.FixedLenFeature([5], tf.int64)
            })
        image = Donkey._preprocess(tf.decode_raw(features['image'], tf.uint8))
        length = tf.cast(features['length'], tf.int32)
        digits = tf.cast(features['digits'], tf.int32)
        return image, length, digits

    @staticmethod
    def build_batch(path_to_tfrecords_file, num_examples, batch_size, shuffled,
                    num_threads=2, min_fraction_of_examples_in_queue=0.4):
        """Create batched ``(image, length, digits)`` tensors from a TFRecord file.

        The file is cycled indefinitely (``num_epochs=None``), so the returned
        tensors never signal end of input. As with any TF1 queue pipeline, the
        queue runners must be started (e.g. ``tf.train.start_queue_runners``)
        before the tensors are evaluated.

        Args:
            path_to_tfrecords_file: path of the TFRecord file to read.
            num_examples: number of examples in the file; used only to size
                the queue.
            batch_size: number of examples per batch.
            shuffled: if true, batches come from ``tf.train.shuffle_batch``;
                otherwise ``tf.train.batch`` is used without shuffling (with
                more than one thread the exact record order is not guaranteed).
            num_threads: number of enqueuing threads. Defaults to 2.
            min_fraction_of_examples_in_queue: fraction of ``num_examples``
                used as the minimum queue fill (``min_after_dequeue`` when
                shuffling). Defaults to 0.4.

        Returns:
            Tuple ``(image_batch, length_batch, digits_batch)``.

        Raises:
            AssertionError: if ``path_to_tfrecords_file`` does not exist.
        """
        # Raised explicitly rather than via ``assert`` so the check survives
        # ``python -O``; AssertionError is kept so existing callers see the
        # same exception type and message.
        if not tf.gfile.Exists(path_to_tfrecords_file):
            raise AssertionError('%s not found' % path_to_tfrecords_file)

        filename_queue = tf.train.string_input_producer([path_to_tfrecords_file], num_epochs=None)
        image, length, digits = Donkey._read_and_decode(filename_queue)

        min_queue_examples = int(min_fraction_of_examples_in_queue * num_examples)
        # Leave headroom of three batches above the minimum fill.
        capacity = min_queue_examples + 3 * batch_size
        tensors = [image, length, digits]

        if shuffled:
            image_batch, length_batch, digits_batch = tf.train.shuffle_batch(
                tensors,
                batch_size=batch_size,
                num_threads=num_threads,
                capacity=capacity,
                min_after_dequeue=min_queue_examples)
        else:
            image_batch, length_batch, digits_batch = tf.train.batch(
                tensors,
                batch_size=batch_size,
                num_threads=num_threads,
                capacity=capacity)
        return image_batch, length_batch, digits_batch