-
Notifications
You must be signed in to change notification settings - Fork 2
/
Copy pathkitti2015_dataset.py
116 lines (99 loc) · 5.66 KB
/
kitti2015_dataset.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
import os
import random
import cv2
import torch
from .crd_fusion_dataset import CRDFusionDataset
class Kitti2015Dataset(CRDFusionDataset):
    """
    KITTI 2015 dataset built on top of CRDFusionDataset.

    Frame names are read from ``<data_path>/training/train.txt`` (training) or ``val.txt`` (validation). For each
    frame this loads the left/right RGB images, a raw disparity map (``raw_disp/*.npy``), a confidence map
    (``conf/*.npy``) and the occluded / non-occluded ground truth disparities, then crops or pads everything so the
    downscaled size matches ``(resized_height, resized_width)``.
    """

    def __init__(self, data_path, max_disp, downscale, resized_height, resized_width, conf_thres, is_train,
                 imgnet_norm=True, sanity=False):
        """
        Dataset to load and prepare data for training/validation from KITTI 2015 dataset
        :param data_path: directory to the dataset (the "training" sub-directory is appended)
        :param max_disp: maximum disparity before downscaling
        :param downscale: downscaling factor
        :param resized_height: final image height after downscaling and resizing
        :param resized_width: final image width after downscaling and resizing
        :param conf_thres: threshold for confidence score
        :param is_train: flag to indicate if this dataset is for training or not
        :param imgnet_norm: if set to True, the RGB images will be normalized by ImageNet's statistics
        :param sanity: if set to True, only includes 1 data point. Mostly used to debug the model
        """
        super(Kitti2015Dataset, self).__init__(data_path, max_disp, downscale, resized_height, resized_width,
                                               conf_thres, is_train, imgnet_norm, sanity)
        self.data_path = os.path.join(self.data_path, "training")
        split_file = "train.txt" if self.is_train else "val.txt"
        with open(os.path.join(self.data_path, split_file)) as f:
            # strip() also drops a trailing "\r" (Windows line endings); blank lines would be bogus frame names
            self.data_list = [line.strip() for line in f if line.strip()]
        if self.sanity:  # only keep the first entry (in sorted order) for sanity check
            self.data_list.sort()
            self.data_list = [self.data_list[0]]

    def _get_gt_disp(self, disp_path):
        """
        Get the ground truth disparity
        :param disp_path: path to the ground truth disparity image
        :return: ground truth disparity as a PyTorch tensor, divided by the downscaling factor; invalid (NaN/inf)
                 values and values >= max_disp are set to 0
        :raises FileNotFoundError: if the image cannot be read
        """
        disp = cv2.imread(disp_path, cv2.IMREAD_ANYCOLOR | cv2.IMREAD_ANYDEPTH)
        if disp is None:
            # cv2.imread reports a missing/unreadable file by returning None rather than raising
            raise FileNotFoundError(f"Could not read ground truth disparity: {disp_path}")
        disp = disp / 256.0  # stored pixel values are disparity * 256 (presumably the KITTI 16-bit PNG encoding)
        disp = self.to_tensor(disp).float()
        disp[torch.isnan(disp)] = 0  # set all pixels with NaN to zero
        disp[torch.isinf(disp)] = 0  # set all pixels with inf or -inf to zero
        disp[disp >= self.max_disp] = 0  # set all disparity larger than the preset maximum to 0
        disp /= self.downscale
        return disp

    def _fit_inputs_to_target_size(self, raw_inputs):
        """
        Crop or pad the raw inputs so the downscaled size matches (resized_height, resized_width)
        :param raw_inputs: dictionary of raw input tensors
        :return: dictionary produced by _crop_inputs or _pad_inputs
        :raises RuntimeError: if one dimension needs cropping while the other needs padding
        """
        width_margin = self.orig_width // self.downscale - self.resized_width
        height_margin = self.orig_height // self.downscale - self.resized_height
        if width_margin >= 0 and height_margin >= 0:
            # at least as large as the target in both dimensions; an exact match is a no-op crop
            return self._crop_inputs(raw_inputs)
        if width_margin <= 0 and height_margin <= 0:
            return self._pad_inputs(raw_inputs)
        raise RuntimeError("Inconsistent image resizing scheme")

    def __getitem__(self, index):
        """
        Get a data sample
        :param index: index for the data list
        :return: a dictionary of input data in tensor form including left rgb, right rgb, raw disparity, confidence
                 mask, occluded/non-occluded ground truth disparity and frame id; for validation it additionally
                 holds un-normalized copies of the raw disparity and rgb images
        :raises ValueError: if the original image size is not divisible by the downscaling factor
        """
        frame = self.data_list[index]
        do_color_aug = self.is_train and random.random() > 0.5 and (not self.sanity)

        raw_inputs = {}
        l_rgb_path = os.path.join(self.data_path, "image_2", frame)
        r_rgb_path = os.path.join(self.data_path, "image_3", frame)
        disp_path = os.path.join(self.data_path, "raw_disp", frame.replace(".png", ".npy"))
        conf_path = os.path.join(self.data_path, "conf", frame.replace(".png", ".npy"))
        raw_inputs['l_rgb'] = self._get_rgb(l_rgb_path)
        raw_inputs['r_rgb'] = self._get_rgb(r_rgb_path)
        raw_inputs['raw_disp'] = self._get_disp(disp_path)
        raw_inputs['mask'] = self._get_conf(conf_path)

        # Need to override orig_height and orig_width for KITTI since the image size may vary in the dataset
        # (the crop/pad helpers presumably read them — confirm in CRDFusionDataset)
        _, self.orig_height, self.orig_width = raw_inputs['l_rgb'].size()
        if self.orig_width % self.downscale != 0 or self.orig_height % self.downscale != 0:
            raise ValueError("original image size not divisible by downscaling factor")

        # Ground truth is loaded for both training and validation
        gt_occ_disp_path = os.path.join(self.data_path, "disp_occ_0", frame)
        gt_noc_disp_path = os.path.join(self.data_path, "disp_noc_0", frame)
        raw_inputs['gt_disp'] = self._get_gt_disp(gt_occ_disp_path)
        raw_inputs['noc_gt_disp'] = self._get_gt_disp(gt_noc_disp_path)

        inputs = self._fit_inputs_to_target_size(raw_inputs)

        if not self.is_train:  # for conf generation in predict_kitti.py
            inputs['raw_disp_non_norm'] = torch.clone(inputs['raw_disp'])
            inputs['l_rgb_non_norm'] = torch.clone(inputs['l_rgb'])
            inputs['r_rgb_non_norm'] = torch.clone(inputs['r_rgb'])
        inputs['raw_disp'] = self._normalize_disp(inputs['raw_disp'])
        if do_color_aug:
            inputs['l_rgb'], inputs['r_rgb'] = self._data_augmentation(inputs['l_rgb'], inputs['r_rgb'])
        if self.imgnet_norm:
            inputs['l_rgb'], inputs['r_rgb'] = self._normalize_rgb(inputs['l_rgb'], inputs['r_rgb'])
        # for saving predicted disparity
        inputs['frame_id'] = frame
        return inputs