# Author: Wentao Yuan ([email protected]) 05/31/2018
import numpy as np
import tensorflow as tf
from tensorpack import dataflow


def resample_pcd(pcd, n):
    """Drop or duplicate points so that pcd has exactly n points."""
    idx = np.arange(pcd.shape[0])
    if idx.shape[0] < n:
        idx = np.concatenate([idx, np.random.randint(pcd.shape[0], size=n - pcd.shape[0])])
    return pcd[idx[:n]]
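
# Illustrative usage (a sketch added for clarity; the array sizes are hypothetical).
# With the sequential index above, a target n smaller than the cloud keeps the
# first n points, while a larger n pads with randomly duplicated points.
#   pcd = np.random.rand(1500, 3)
#   resample_pcd(pcd, 1024).shape  # (1024, 3)
#   resample_pcd(pcd, 2048).shape  # (2048, 3)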


class PreprocessData(dataflow.ProxyDataFlow):
    def __init__(self, ds, input_size, output_size):
        super(PreprocessData, self).__init__(ds)
        self.input_size = input_size
        self.output_size = output_size

    def get_data(self):
        for id, input, gt in self.ds.get_data():
            input = resample_pcd(input, self.input_size)
            gt = resample_pcd(gt, self.output_size)
            yield id, input, gt
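
# Note (added): PreprocessData expects a dataflow yielding (id, partial, complete)
# tuples and only fixes their point counts; lmdb_dataflow below does not use it,
# since BatchData resamples inside _aggregate_batch instead.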


class BatchData(dataflow.ProxyDataFlow):
    def __init__(self, ds, batch_size, input_size, gt_size, remainder=False, use_list=False):
        super(BatchData, self).__init__(ds)
        self.batch_size = batch_size
        self.input_size = input_size
        self.gt_size = gt_size
        self.remainder = remainder
        self.use_list = use_list

    def __len__(self):
        ds_size = len(self.ds)
        div = ds_size // self.batch_size
        rem = ds_size % self.batch_size
        if rem == 0:
            return div
        return div + int(self.remainder)

    def __iter__(self):
        holder = []
        for data in self.ds:
            holder.append(data)
            if len(holder) == self.batch_size:
                yield self._aggregate_batch(holder, self.use_list)
                del holder[:]
        if self.remainder and len(holder) > 0:
            yield self._aggregate_batch(holder, self.use_list)
    def _aggregate_batch(self, data_holder, use_list=False):
        """Stack ids, resampled inputs and ground truths along the 0-th dimension."""
        ids = np.stack([x[0] for x in data_holder])
        # Commented-out alternative: keep variable-size inputs, concatenate them
        # along the point dimension and record the per-cloud point counts in npts.
        # inputs = [resample_pcd(x[1], self.input_size) if x[1].shape[0] > self.input_size else x[1]
        #           for x in data_holder]
        # inputs = np.expand_dims(np.concatenate([x for x in inputs]), 0).astype(np.float32)
        # npts = np.stack([x[1].shape[0] if x[1].shape[0] < self.input_size else self.input_size
        #                  for x in data_holder]).astype(np.int32)
        inputs = np.stack([resample_pcd(x[1], self.input_size) for x in data_holder]).astype(np.float32)
        npts = self.input_size
        gts = np.stack([resample_pcd(x[2], self.gt_size) for x in data_holder]).astype(np.float32)
        return ids, inputs, npts, gts
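
# For reference (added comment, assuming each point cloud is an (N, 3) array):
# every batch yielded by BatchData is a tuple of
#   ids    (batch_size,)                       sample ids
#   inputs (batch_size, input_size, 3) float32 partial clouds
#   npts   Python int, always equal to input_size
#   gts    (batch_size, gt_size, 3) float32    complete clouds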


def lmdb_dataflow(lmdb_path, batch_size, input_size, output_size, is_training, test_speed=False):
    df = dataflow.LMDBSerializer.load(lmdb_path, shuffle=False)
    # df = dataflow.LMDBData(lmdb_path, shuffle=False)
    size = df.size()
    if is_training:
        df = dataflow.LocallyShuffleData(df, buffer_size=2000)
        df = dataflow.PrefetchData(df, num_prefetch=500, num_proc=1)
    df = BatchData(df, batch_size, input_size, output_size)
    if is_training:
        df = dataflow.PrefetchDataZMQ(df, num_proc=8)
    df = dataflow.RepeatedData(df, -1)
    if test_speed:
        dataflow.TestDataSpeed(df, size=1000).start()
    df.reset_state()
    return df, size
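
# Illustrative call (a sketch; the lmdb path and sizes below are assumptions):
#   df_train, num_train = lmdb_dataflow('data/train.lmdb', batch_size=32,
#                                       input_size=2048, output_size=16384,
#                                       is_training=True)
#   gen = df_train.get_data()  # or iter(df_train), depending on the tensorpack version
#   ids, inputs, npts, gts = next(gen)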


def get_queued_data(generator, dtypes, shapes, queue_capacity=10):
    assert len(dtypes) == len(shapes), 'dtypes and shapes must have the same length'
    queue = tf.FIFOQueue(queue_capacity, dtypes, shapes)
    placeholders = [tf.placeholder(dtype, shape) for dtype, shape in zip(dtypes, shapes)]
    enqueue_op = queue.enqueue(placeholders)
    close_op = queue.close(cancel_pending_enqueues=True)
    feed_fn = lambda: {placeholder: value for placeholder, value in zip(placeholders, next(generator))}
    queue_runner = tf.contrib.training.FeedingQueueRunner(queue, [enqueue_op], close_op, feed_fns=[feed_fn])
    tf.train.add_queue_runner(queue_runner)
    return queue.dequeue()
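

# Minimal end-to-end sketch (added for illustration, not part of the original
# pipeline). The lmdb path, batch/point sizes and the assumption that ids are
# strings are all hypothetical; it relies on the TF1 graph/queue API used above
# and will not run under TF2 eager execution.
if __name__ == '__main__':
    df, num_samples = lmdb_dataflow('data/train.lmdb', batch_size=8,
                                    input_size=2048, output_size=16384,
                                    is_training=True)
    batch_gen = df.get_data()  # or iter(df), depending on the tensorpack version
    ids, inputs, npts, gts = get_queued_data(
        batch_gen,
        dtypes=[tf.string, tf.float32, tf.int32, tf.float32],
        shapes=[[8], [8, 2048, 3], [], [8, 16384, 3]])
    with tf.Session() as sess:
        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(sess=sess, coord=coord)
        batch_inputs = sess.run(inputs)  # (8, 2048, 3)
        print(batch_inputs.shape, num_samples)
        coord.request_stop()
        coord.join(threads)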