Skip to content

Commit 72c608c

Browse files
committed
Add support for DSVT
1 parent 02ac3e1 commit 72c608c

28 files changed

+1987
-22
lines changed

README.md

+5
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,8 @@ It is also the official code release of [`[PointRCNN]`](https://arxiv.org/abs/18
2323

2424

2525
## Changelog
26+
[2023-06-xx] **NEW:** Added support for [`DSVT`](https://arxiv.org/abs/2301.06051), which achieves state-of-the-art performance on the large-scale Waymo Open Dataset with real-time inference speed (27 Hz with TensorRT).
27+
2628
[2023-05-13] **NEW:** Added support for the multi-modal 3D object detection models on Nuscenes dataset.
2729
* Support multi-modal Nuscenes detection (See the [GETTING_STARTED.md](docs/GETTING_STARTED.md) to process data).
2830
* Support [TransFusion-Lidar](https://arxiv.org/abs/2203.11496) head, which achieves 69.43% NDS on Nuscenes validation dataset.
@@ -192,6 +194,8 @@ Here we also provide the performance of several models trained on the full train
192194
| [PV-RCNN (CenterHead)](tools/cfgs/waymo_models/pv_rcnn_with_centerhead_rpn.yaml) | 78.00/77.50 | 69.43/68.98 | 79.21/73.03 | 70.42/64.72 | 71.46/70.27 | 68.95/67.79 |
193195
| [PV-RCNN++](tools/cfgs/waymo_models/pv_rcnn_plusplus.yaml) | 79.10/78.63 | 70.34/69.91 | 80.62/74.62 | 71.86/66.30 | 73.49/72.38 | 70.70/69.62 |
194196
| [PV-RCNN++ (ResNet)](tools/cfgs/waymo_models/pv_rcnn_plusplus_resnet.yaml) | 79.25/78.78 | 70.61/70.18 | 81.83/76.28 | 73.17/68.00 | 73.72/72.66 | 71.21/70.19 |
197+
| [DSVT-Pillar](tools/cfgs/waymo_models/dsvt_pillar.yaml) | 79.44/78.97 | 71.24/70.81 | 83.00/77.22 | 75.45/69.95 | 76.70/75.70 | 73.83/72.86 |
198+
| [DSVT-Voxel](tools/cfgs/waymo_models/dsvt_voxel.yaml) | 79.77/79.31 | 71.67/71.25 | 83.75/78.92 | 76.21/71.57 | 77.57/76.58 | 74.70/73.73 |
195199
| [PV-RCNN++ (ResNet, 2 frames)](tools/cfgs/waymo_models/pv_rcnn_plusplus_resnet_2frames.yaml) | 80.17/79.70 | 72.14/71.70 | 83.48/80.42 | 75.54/72.61 | 74.63/73.75 | 72.35/71.50 |
196200
| [MPPNet (4 frames)](docs/guidelines_of_approaches/mppnet.md) | 81.54/81.06 | 74.07/73.61 | 84.56/81.94 | 77.20/74.67 | 77.15/76.50 | 75.01/74.38 |
197201
| [MPPNet (16 frames)](docs/guidelines_of_approaches/mppnet.md) | 82.74/82.28 | 75.41/74.96 | 84.69/82.25 | 77.43/75.06 | 77.28/76.66 | 75.13/74.52 |
@@ -201,6 +205,7 @@ Here we also provide the performance of several models trained on the full train
201205

202206

203207

208+
204209
We could not provide the above pretrained models due to [Waymo Dataset License Agreement](https://waymo.com/open/terms/),
205210
but you could easily achieve similar performance by training with the default configs.
206211

pcdet/datasets/waymo/waymo_dataset.py

+5-1
Original file line numberDiff line numberDiff line change
@@ -200,7 +200,11 @@ def get_lidar(self, sequence_name, sample_idx):
200200
points_all, NLZ_flag = point_features[:, 0:5], point_features[:, 5]
201201
if not self.dataset_cfg.get('DISABLE_NLZ_FLAG_ON_POINTS', False):
202202
points_all = points_all[NLZ_flag == -1]
203-
points_all[:, 3] = np.tanh(points_all[:, 3])
203+
if self.dataset_cfg.get('POINTS_TANH_DIM', None) is None:
204+
points_all[:, 3] = np.tanh(points_all[:, 3])
205+
else:
206+
for dim_idx in self.dataset_cfg.POINTS_TANH_DIM:
207+
points_all[:, dim_idx] = np.tanh(points_all[:, dim_idx])
204208
return points_all
205209

206210
@staticmethod

pcdet/models/backbones_2d/__init__.py

+3-2
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
1-
from .base_bev_backbone import BaseBEVBackbone, BaseBEVBackboneV1
1+
from .base_bev_backbone import BaseBEVBackbone, BaseBEVBackboneV1, BaseBEVResBackbone
22

33
__all__ = {
44
'BaseBEVBackbone': BaseBEVBackbone,
5-
'BaseBEVBackboneV1': BaseBEVBackboneV1
5+
'BaseBEVBackboneV1': BaseBEVBackboneV1,
6+
'BaseBEVResBackbone': BaseBEVResBackbone,
67
}

pcdet/models/backbones_2d/base_bev_backbone.py

+147
Original file line numberDiff line numberDiff line change
@@ -202,3 +202,150 @@ def forward(self, data_dict):
202202
data_dict['spatial_features_2d'] = x
203203

204204
return data_dict
205+
206+
207+
class BasicBlock(nn.Module):
    """Residual block built from two 3x3 convolutions.

    Main path: conv1 -> bn1 -> relu1 -> conv2 -> bn2. The shortcut is the
    input itself, or a 1x1 strided conv + batch norm projection when
    ``downsample`` is true. The two paths are summed and passed through a
    final ReLU.

    Args:
        inplanes: Number of channels of the input feature map.
        planes: Number of channels produced by both convolutions.
        stride: Stride of the first convolution and of the projection.
        padding: Padding of the first convolution.
        downsample: Whether to project the shortcut with a 1x1 convolution.
    """

    expansion: int = 1

    def __init__(
        self,
        inplanes: int,
        planes: int,
        stride: int = 1,
        padding: int = 1,
        downsample: bool = False,
    ) -> None:
        super().__init__()
        # Submodules are created in a fixed order so parameter names and
        # seeded initialisation stay stable.
        self.conv1 = nn.Conv2d(
            inplanes, planes, kernel_size=3, stride=stride, padding=padding, bias=False
        )
        self.bn1 = nn.BatchNorm2d(planes, eps=1e-3, momentum=0.01)
        self.relu1 = nn.ReLU()

        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes, eps=1e-3, momentum=0.01)
        self.relu2 = nn.ReLU()

        self.downsample = downsample
        if downsample:
            projection = [
                nn.Conv2d(inplanes, planes, kernel_size=1, stride=stride, padding=0, bias=False),
                nn.BatchNorm2d(planes, eps=1e-3, momentum=0.01),
            ]
            self.downsample_layer = nn.Sequential(*projection)
        self.stride = stride

    def forward(self, x):
        """Apply the residual block to ``x`` and return the activated sum."""
        out = self.relu1(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))

        shortcut = self.downsample_layer(x) if self.downsample else x

        # In-place add on the freshly computed main-path tensor; ``x`` itself
        # is never modified.
        out += shortcut
        return self.relu2(out)
250+
251+
252+
class BaseBEVResBackbone(nn.Module):
    """Multi-level residual 2D backbone with optional per-level (de)convolution heads.

    Configuration keys read from ``model_cfg``:
        LAYER_NUMS, LAYER_STRIDES, NUM_FILTERS: per-level settings of equal
            length. Level ``idx`` is one ``BasicBlock`` with stride
            ``LAYER_STRIDES[idx]`` and a projected shortcut, followed by
            ``LAYER_NUMS[idx]`` stride-1 ``BasicBlock``s of width
            ``NUM_FILTERS[idx]``.
        UPSAMPLE_STRIDES, NUM_UPSAMPLE_FILTERS: optional, equal length. A
            stride ``>= 1`` builds a ``ConvTranspose2d`` head; a stride
            ``< 1`` builds a strided ``Conv2d`` head with stride
            ``round(1 / stride)``. One extra trailing entry in
            ``UPSAMPLE_STRIDES`` adds a final ``ConvTranspose2d`` applied to
            the concatenated features.

    Attributes:
        blocks: ``nn.ModuleList`` with one ``nn.Sequential`` per level.
        deblocks: ``nn.ModuleList`` of the per-level heads (plus the optional
            trailing one).
        num_bev_features: Channel count of the output feature map.
    """

    def __init__(self, model_cfg, input_channels):
        super().__init__()
        self.model_cfg = model_cfg

        if self.model_cfg.get('LAYER_NUMS', None) is not None:
            assert len(self.model_cfg.LAYER_NUMS) == len(self.model_cfg.LAYER_STRIDES) == len(self.model_cfg.NUM_FILTERS)
            layer_nums = self.model_cfg.LAYER_NUMS
            layer_strides = self.model_cfg.LAYER_STRIDES
            num_filters = self.model_cfg.NUM_FILTERS
        else:
            layer_nums = layer_strides = num_filters = []

        if self.model_cfg.get('UPSAMPLE_STRIDES', None) is not None:
            assert len(self.model_cfg.UPSAMPLE_STRIDES) == len(self.model_cfg.NUM_UPSAMPLE_FILTERS)
            num_upsample_filters = self.model_cfg.NUM_UPSAMPLE_FILTERS
            upsample_strides = self.model_cfg.UPSAMPLE_STRIDES
        else:
            upsample_strides = num_upsample_filters = []

        num_levels = len(layer_nums)
        # Input width of each level: the backbone input, then the previous level's width.
        c_in_list = [input_channels, *num_filters[:-1]]
        self.blocks = nn.ModuleList()
        self.deblocks = nn.ModuleList()
        for idx in range(num_levels):
            # First block changes width/stride and therefore needs a projected shortcut.
            cur_layers = [
                BasicBlock(c_in_list[idx], num_filters[idx], layer_strides[idx], 1, True)
            ]
            cur_layers.extend(
                BasicBlock(num_filters[idx], num_filters[idx])
                for _ in range(layer_nums[idx])
            )
            self.blocks.append(nn.Sequential(*cur_layers))

            if len(upsample_strides) > 0:
                stride = upsample_strides[idx]
                if stride >= 1:
                    self.deblocks.append(nn.Sequential(
                        nn.ConvTranspose2d(
                            num_filters[idx], num_upsample_filters[idx],
                            upsample_strides[idx],
                            stride=upsample_strides[idx], bias=False
                        ),
                        nn.BatchNorm2d(num_upsample_filters[idx], eps=1e-3, momentum=0.01),
                        nn.ReLU()
                    ))
                else:
                    # Fractional stride means downsampling by its reciprocal.
                    # The ``np.int`` alias was removed in NumPy 1.24, so
                    # convert with the builtin ``int`` instead.
                    stride = int(np.round(1 / stride))
                    self.deblocks.append(nn.Sequential(
                        nn.Conv2d(
                            num_filters[idx], num_upsample_filters[idx],
                            stride,
                            stride=stride, bias=False
                        ),
                        nn.BatchNorm2d(num_upsample_filters[idx], eps=1e-3, momentum=0.01),
                        nn.ReLU()
                    ))

        c_in = sum(num_upsample_filters) if len(num_upsample_filters) > 0 else sum(num_filters)
        if len(upsample_strides) > num_levels:
            # Extra trailing stride: one more upsampling of the concatenated features.
            self.deblocks.append(nn.Sequential(
                nn.ConvTranspose2d(c_in, c_in, upsample_strides[-1], stride=upsample_strides[-1], bias=False),
                nn.BatchNorm2d(c_in, eps=1e-3, momentum=0.01),
                nn.ReLU(),
            ))

        self.num_bev_features = c_in

    def forward(self, data_dict):
        """Run the backbone and store the result in ``data_dict``.

        Args:
            data_dict: Mapping containing ``'spatial_features'``, the input
                feature map (presumably laid out as (N, C, H, W), since it is
                fed to 2D convolutions and concatenated on dim 1 -- confirm
                against the caller).

        Returns:
            The same ``data_dict`` with ``'spatial_features_2d'`` added.
        """
        x = data_dict['spatial_features']
        ups = []
        for i, block in enumerate(self.blocks):
            x = block(x)
            # Without heads, the raw level outputs are collected directly.
            if len(self.deblocks) > 0:
                ups.append(self.deblocks[i](x))
            else:
                ups.append(x)

        if len(ups) > 1:
            x = torch.cat(ups, dim=1)
        elif len(ups) == 1:
            x = ups[0]

        if len(self.deblocks) > len(self.blocks):
            x = self.deblocks[-1](x)

        data_dict['spatial_features_2d'] = x

        return data_dict
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,10 @@
11
from .height_compression import HeightCompression
2-
from .pointpillar_scatter import PointPillarScatter
2+
from .pointpillar_scatter import PointPillarScatter, PointPillarScatter3d
33
from .conv2d_collapse import Conv2DCollapse
44

55
__all__ = {
66
'HeightCompression': HeightCompression,
77
'PointPillarScatter': PointPillarScatter,
8-
'Conv2DCollapse': Conv2DCollapse
8+
'Conv2DCollapse': Conv2DCollapse,
9+
'PointPillarScatter3d': PointPillarScatter3d,
910
}

pcdet/models/backbones_2d/map_to_bev/pointpillar_scatter.py

+36
Original file line numberDiff line numberDiff line change
@@ -35,3 +35,39 @@ def forward(self, batch_dict, **kwargs):
3535
batch_spatial_features = batch_spatial_features.view(batch_size, self.num_bev_features * self.nz, self.ny, self.nx)
3636
batch_dict['spatial_features'] = batch_spatial_features
3737
return batch_dict
38+
39+
40+
class PointPillarScatter3d(nn.Module):
    """Scatter sparse pillar/voxel features onto a dense grid and fold z into channels.

    ``model_cfg.INPUT_SHAPE`` supplies ``(nx, ny, nz)`` and
    ``model_cfg.NUM_BEV_FEATURES`` the channel count after folding, so each
    input feature vector is expected to have ``NUM_BEV_FEATURES // nz``
    channels. ``grid_size`` is accepted but not used.
    """

    def __init__(self, model_cfg, grid_size, **kwargs):
        super().__init__()

        self.model_cfg = model_cfg
        input_shape = self.model_cfg.INPUT_SHAPE
        self.nx, self.ny, self.nz = input_shape
        self.num_bev_features = self.model_cfg.NUM_BEV_FEATURES
        self.num_bev_features_before_compression = self.model_cfg.NUM_BEV_FEATURES // self.nz

    def forward(self, batch_dict, **kwargs):
        """Build the dense feature map and store it under ``'spatial_features'``.

        Args:
            batch_dict: Mapping with ``'pillar_features'`` (one row per
                occupied cell) and ``'voxel_coords'`` whose columns are used
                as (batch index, z, y, x).

        Returns:
            ``batch_dict`` with ``'spatial_features'`` of shape
            ``(batch_size, channels * nz, ny, nx)`` added.
        """
        features = batch_dict['pillar_features']
        coords = batch_dict['voxel_coords']

        channels = self.num_bev_features_before_compression
        num_cells = self.nz * self.nx * self.ny
        # Batch size is inferred from the largest batch index present in the coordinates.
        batch_size = coords[:, 0].max().int().item() + 1

        per_sample = []
        for sample_idx in range(batch_size):
            canvas = torch.zeros(
                channels,
                num_cells,
                dtype=features.dtype,
                device=features.device)

            in_sample = coords[:, 0] == sample_idx
            sample_coords = coords[in_sample, :]
            # Flatten (z, y, x) into a single offset, x varying fastest.
            flat_idx = (
                sample_coords[:, 1] * self.ny * self.nx
                + sample_coords[:, 2] * self.nx
                + sample_coords[:, 3]
            )
            flat_idx = flat_idx.type(torch.long)
            canvas[:, flat_idx] = features[in_sample, :].t()
            per_sample.append(canvas)

        dense = torch.stack(per_sample, 0)
        dense = dense.view(batch_size, channels * self.nz, self.ny, self.nx)
        batch_dict['spatial_features'] = dense
        return batch_dict

pcdet/models/backbones_3d/__init__.py

+3-1
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
from .spconv_backbone_voxelnext import VoxelResBackBone8xVoxelNeXt
66
from .spconv_backbone_voxelnext2d import VoxelResBackBone8xVoxelNeXt2D
77
from .spconv_unet import UNetV2
8+
from .dsvt import DSVT
89

910
__all__ = {
1011
'VoxelBackBone8x': VoxelBackBone8x,
@@ -16,5 +17,6 @@
1617
'VoxelResBackBone8xVoxelNeXt': VoxelResBackBone8xVoxelNeXt,
1718
'VoxelResBackBone8xVoxelNeXt2D': VoxelResBackBone8xVoxelNeXt2D,
1819
'PillarBackBone8x': PillarBackBone8x,
19-
'PillarRes18BackBone8x': PillarRes18BackBone8x
20+
'PillarRes18BackBone8x': PillarRes18BackBone8x,
21+
'DSVT': DSVT,
2022
}

0 commit comments

Comments
 (0)