Skip to content

Commit ed2bb81

Browse files
committed
Update configs
1 parent 123d42e commit ed2bb81

File tree

6 files changed

+390
-391
lines changed

6 files changed

+390
-391
lines changed

README.md

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@ It is also the official code release of [`[PointRCNN]`](https://arxiv.org/abs/18
2424
## Changelog
2525
[2023-05-xx] Added support for the multi-modal 3D object detection model [`BEVFusion`](https://arxiv.org/abs/2205.13542) on Nuscenes dataset, which fuses multi-modal information on BEV space and reaches 70.98% NDS on Nuscenes validation dataset. (see the [guideline](docs/guidelines_of_approaches/bevfusion.md) on how to train/test with BEVFusion).
2626
* Support multi-modal Nuscenes detection (See the [GETTING_STARTED.md](docs/GETTING_STARTED.md) to process data).
27-
* Support TransFusion-Lidar head, which ahcieves 69.43% NDS on Nuscenes validation dataset.
27+
* Support [TransFusion-Lidar](https://arxiv.org/abs/2203.11496) head, which ahcieves 69.43% NDS on Nuscenes validation dataset.
2828

2929
[2023-04-02] Added support for [`VoxelNeXt`](https://github.com/dvlab-research/VoxelNeXt) on Nuscenes, Waymo, and Argoverse2 datasets. It is a fully sparse 3D object detection network, which is a clean sparse CNNs network and predicts 3D objects directly upon voxels.
3030

@@ -213,8 +213,8 @@ All models are trained with 8 GPUs and are available for download. For training
213213
| [CenterPoint (voxel_size=0.1)](tools/cfgs/nuscenes_models/cbgs_voxel01_res3d_centerpoint.yaml) | 30.11 | 25.55 | 38.28 | 21.94 | 18.87 | 56.03 | 64.54 | [model-34M](https://drive.google.com/file/d/1Cz-J1c3dw7JAWc25KRG1XQj8yCaOlexQ/view?usp=sharing) |
214214
| [CenterPoint (voxel_size=0.075)](tools/cfgs/nuscenes_models/cbgs_voxel0075_res3d_centerpoint.yaml) | 28.80 | 25.43 | 37.27 | 21.55 | 18.24 | 59.22 | 66.48 | [model-34M](https://drive.google.com/file/d/1XOHAWm1MPkCKr1gqmc3TWi5AYZgPsgxU/view?usp=sharing) |
215215
| [VoxelNeXt (voxel_size=0.075)](tools/cfgs/nuscenes_models/cbgs_voxel0075_voxelnext.yaml) | 30.11 | 25.23 | 40.57 | 21.69 | 18.56 | 60.53 | 66.65 | [model-31M](https://drive.google.com/file/d/1IV7e7G9X-61KXSjMGtQo579pzDNbhwvf/view?usp=share_link) |
216-
| [TransFusion-L*](tools/cfgs/nuscenes_models/cbgs_transfusion_lidar.yaml) | 27.96 | 25.37 | 29.35 | 27.31 | 18.55 | 64.58 | 69.43 | [model-32M](https://drive.google.com/file/d/1cuZ2qdDnxSwTCsiXWwbqCGF-uoazTXbz/view?usp=share_link) |
217-
| [BEVFusion](tools/cfgs/nuscenes_models/cbgs_bevfusion.yaml) | 28.03 | 25.43 | 30.19 | 26.76 | 18.48 | 67.75 | 70.98 | [model-157M](https://drive.google.com/file/d/1X50b-8immqlqD8VPAUkSKI0Ls-4k37g9/view?usp=share_link) |
216+
| [TransFusion-L*](tools/cfgs/nuscenes_models/transfusion_lidar.yaml) | 27.96 | 25.37 | 29.35 | 27.31 | 18.55 | 64.58 | 69.43 | [model-32M](https://drive.google.com/file/d/1cuZ2qdDnxSwTCsiXWwbqCGF-uoazTXbz/view?usp=share_link) |
217+
| [BEVFusion](tools/cfgs/nuscenes_models/bevfusion.yaml) | 28.03 | 25.43 | 30.19 | 26.76 | 18.48 | 67.75 | 70.98 | [model-157M](https://drive.google.com/file/d/1X50b-8immqlqD8VPAUkSKI0Ls-4k37g9/view?usp=share_link) |
218218

219219
*: Use the fade strategy, which disables data augmentations in the last several epochs during training.
220220

docs/guidelines_of_approaches/bevfusion.md

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -11,25 +11,25 @@ Please refer to [GETTING_STARTED.md](../GETTING_STARTED.md) to process the multi
1111

1212
1. Train the lidar branch for BEVFusion:
1313
```shell
14-
bash scripts/dist_train.sh ${NUM_GPUS} --cfg_file cfgs/nuscenes_models/cbgs_transfusion_lidar.yaml \
14+
bash scripts/dist_train.sh ${NUM_GPUS} --cfg_file cfgs/nuscenes_models/transfusion_lidar.yaml \
1515
```
16-
The ckpt will be saved in ../output/nuscenes_models/cbgs_transfusion_lidar/default/ckpt, or you can download pretrained checkpoint directly form [here](https://drive.google.com/file/d/1cuZ2qdDnxSwTCsiXWwbqCGF-uoazTXbz/view?usp=share_link).
16+
The ckpt will be saved in ../output/nuscenes_models/transfusion_lidar/default/ckpt, or you can download pretrained checkpoint directly form [here](https://drive.google.com/file/d/1cuZ2qdDnxSwTCsiXWwbqCGF-uoazTXbz/view?usp=share_link).
1717

18-
1. To train BEVFusion, you need to download pretrained parameters for image backbone [here](www.google.com), and specify the path in [config](../../tools/cfgs/nuscenes_models/cbgs_bevfusion.yaml#L88). Then run the following command:
18+
2. To train BEVFusion, you need to download pretrained parameters for image backbone [here](https://drive.google.com/file/d/1v74WCt4_5ubjO7PciA5T0xhQc9bz_jZu/view?usp=share_link), and specify the path in [config](../../tools/cfgs/nuscenes_models/bevfusion.yaml#L88). Then run the following command:
1919
```shell
20-
bash scripts/dist_train.sh ${NUM_GPUS} --cfg_file cfgs/nuscenes_models/cbgs_bevfusion.yaml \
20+
bash scripts/dist_train.sh ${NUM_GPUS} --cfg_file cfgs/nuscenes_models/bevfusion.yaml \
2121
--pretrained_model path_to_pretrained_lidar_branch_ckpt \
2222
```
2323
## Evaluation
2424
* Test with a pretrained model:
2525
```shell
26-
bash scripts/dist_test.sh ${NUM_GPUS} --cfg_file cfgs/nuscenes_models/cbgs_bevfusion.yaml \
27-
--ckpt ../output/cfgs/nuscenes_models/cbgs_bevfusion/default/ckpt/checkpoint_epoch_6.pth
26+
bash scripts/dist_test.sh ${NUM_GPUS} --cfg_file cfgs/nuscenes_models/bevfusion.yaml \
27+
--ckpt ../output/cfgs/nuscenes_models/bevfusion/default/ckpt/checkpoint_epoch_6.pth
2828
```
2929

3030
## Performance
3131
All models are trained with spconv 1.0, but you can directly load them for testing regardless of the spconv version.
3232
| | mATE | mASE | mAOE | mAVE | mAAE | mAP | NDS | download |
3333
|----------------------------------------------------------------------------------------------------|-------:|:------:|:------:|:-----:|:-----:|:-----:|:------:|:--------------------------------------------------------------------------------------------------:|
34-
| [TransFusion-L](../../tools/cfgs/nuscenes_models/cbgs_transfusion_lidar.yaml) | 27.96 | 25.37 | 29.35 | 27.31 | 18.55 | 64.58 | 69.43 | [model-32M](https://drive.google.com/file/d/1cuZ2qdDnxSwTCsiXWwbqCGF-uoazTXbz/view?usp=share_link) |
35-
| [BEVFusion](../../tools/cfgs/nuscenes_models/cbgs_bevfusion.yaml) | 28.03 | 25.43 | 30.19 | 26.76 | 18.48 | 67.75 | 70.98 | [model-157M](https://drive.google.com/file/d/1X50b-8immqlqD8VPAUkSKI0Ls-4k37g9/view?usp=share_link) |
34+
| [TransFusion-L](../../tools/cfgs/nuscenes_models/transfusion_lidar.yaml) | 27.96 | 25.37 | 29.35 | 27.31 | 18.55 | 64.58 | 69.43 | [model-32M](https://drive.google.com/file/d/1cuZ2qdDnxSwTCsiXWwbqCGF-uoazTXbz/view?usp=share_link) |
35+
| [BEVFusion](../../tools/cfgs/nuscenes_models/bevfusion.yaml) | 28.03 | 25.43 | 30.19 | 26.76 | 18.48 | 67.75 | 70.98 | [model-157M](https://drive.google.com/file/d/1X50b-8immqlqD8VPAUkSKI0Ls-4k37g9/view?usp=share_link) |
Lines changed: 208 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,208 @@
1+
CLASS_NAMES: ['car','truck', 'construction_vehicle', 'bus', 'trailer',
2+
'barrier', 'motorcycle', 'bicycle', 'pedestrian', 'traffic_cone']
3+
4+
DATA_CONFIG:
5+
_BASE_CONFIG_: cfgs/dataset_configs/nuscenes_dataset.yaml
6+
POINT_CLOUD_RANGE: [-54.0, -54.0, -5.0, 54.0, 54.0, 3.0]
7+
CAMERA_CONFIG:
8+
USE_CAMERA: True
9+
IMAGE:
10+
FINAL_DIM: [256,704]
11+
RESIZE_LIM_TRAIN: [0.38, 0.55]
12+
RESIZE_LIM_TEST: [0.48, 0.48]
13+
14+
DATA_AUGMENTOR:
15+
DISABLE_AUG_LIST: ['placeholder']
16+
AUG_CONFIG_LIST:
17+
- NAME: random_world_flip
18+
ALONG_AXIS_LIST: ['x', 'y']
19+
20+
- NAME: random_world_rotation
21+
WORLD_ROT_ANGLE: [-0.78539816, 0.78539816]
22+
23+
- NAME: random_world_scaling
24+
WORLD_SCALE_RANGE: [0.9, 1.1]
25+
26+
- NAME: random_world_translation
27+
NOISE_TRANSLATE_STD: [0.5, 0.5, 0.5]
28+
29+
- NAME: imgaug
30+
ROT_LIM: [-5.4, 5.4]
31+
RAND_FLIP: True
32+
33+
DATA_PROCESSOR:
34+
- NAME: mask_points_and_boxes_outside_range
35+
REMOVE_OUTSIDE_BOXES: True
36+
37+
- NAME: shuffle_points
38+
SHUFFLE_ENABLED: {
39+
'train': True,
40+
'test': True
41+
}
42+
43+
- NAME: transform_points_to_voxels
44+
VOXEL_SIZE: [0.075, 0.075, 0.2]
45+
MAX_POINTS_PER_VOXEL: 10
46+
MAX_NUMBER_OF_VOXELS: {
47+
'train': 120000,
48+
'test': 160000
49+
}
50+
51+
- NAME: image_calibrate
52+
53+
- NAME: image_normalize
54+
mean: [0.485, 0.456, 0.406]
55+
std: [0.229, 0.224, 0.225]
56+
57+
58+
MODEL:
59+
NAME: BevFusion
60+
61+
VFE:
62+
NAME: MeanVFE
63+
64+
BACKBONE_3D:
65+
NAME: VoxelResBackBone8x
66+
USE_BIAS: False
67+
68+
MAP_TO_BEV:
69+
NAME: HeightCompression
70+
NUM_BEV_FEATURES: 256
71+
72+
IMAGE_BACKBONE:
73+
NAME: SwinTransformer
74+
EMBED_DIMS: 96
75+
DEPTHS: [2, 2, 6, 2]
76+
NUM_HEADS: [3, 6, 12, 24]
77+
WINDOW_SIZE: 7
78+
MLP_RATIO: 4
79+
DROP_RATE: 0.
80+
ATTN_DROP_RATE: 0.
81+
DROP_PATH_RATE: 0.2
82+
PATCH_NORM: True
83+
OUT_INDICES: [1, 2, 3]
84+
WITH_CP: False
85+
CONVERT_WEIGHTS: True
86+
INIT_CFG:
87+
type: Pretrained
88+
checkpoint: swint-nuimages-pretrained.pth
89+
90+
NECK:
91+
NAME: GeneralizedLSSFPN
92+
IN_CHANNELS: [192, 384, 768]
93+
OUT_CHANNELS: 256
94+
START_LEVEL: 0
95+
END_LEVEL: -1
96+
NUM_OUTS: 3
97+
98+
VTRANSFORM:
99+
NAME: DepthLSSTransform
100+
IMAGE_SIZE: [256, 704]
101+
IN_CHANNEL: 256
102+
OUT_CHANNEL: 80
103+
FEATURE_SIZE: [32, 88]
104+
XBOUND: [-54.0, 54.0, 0.3]
105+
YBOUND: [-54.0, 54.0, 0.3]
106+
ZBOUND: [-10.0, 10.0, 20.0]
107+
DBOUND: [1.0, 60.0, 0.5]
108+
DOWNSAMPLE: 2
109+
110+
FUSER:
111+
NAME: ConvFuser
112+
IN_CHANNEL: 336
113+
OUT_CHANNEL: 256
114+
115+
BACKBONE_2D:
116+
NAME: BaseBEVBackbone
117+
LAYER_NUMS: [5, 5]
118+
LAYER_STRIDES: [1, 2]
119+
NUM_FILTERS: [128, 256]
120+
UPSAMPLE_STRIDES: [1, 2]
121+
NUM_UPSAMPLE_FILTERS: [256, 256]
122+
USE_CONV_FOR_NO_STRIDE: True
123+
124+
125+
DENSE_HEAD:
126+
CLASS_AGNOSTIC: False
127+
NAME: TransFusionHead
128+
129+
USE_BIAS_BEFORE_NORM: False
130+
131+
NUM_PROPOSALS: 200
132+
HIDDEN_CHANNEL: 128
133+
NUM_CLASSES: 10
134+
NUM_HEADS: 8
135+
NMS_KERNEL_SIZE: 3
136+
FFN_CHANNEL: 256
137+
DROPOUT: 0.1
138+
BN_MOMENTUM: 0.1
139+
ACTIVATION: relu
140+
141+
NUM_HM_CONV: 2
142+
SEPARATE_HEAD_CFG:
143+
HEAD_ORDER: ['center', 'height', 'dim', 'rot', 'vel']
144+
HEAD_DICT: {
145+
'center': {'out_channels': 2, 'num_conv': 2},
146+
'height': {'out_channels': 1, 'num_conv': 2},
147+
'dim': {'out_channels': 3, 'num_conv': 2},
148+
'rot': {'out_channels': 2, 'num_conv': 2},
149+
'vel': {'out_channels': 2, 'num_conv': 2},
150+
}
151+
152+
TARGET_ASSIGNER_CONFIG:
153+
FEATURE_MAP_STRIDE: 8
154+
DATASET: nuScenes
155+
GAUSSIAN_OVERLAP: 0.1
156+
MIN_RADIUS: 2
157+
HUNGARIAN_ASSIGNER:
158+
cls_cost: {'gamma': 2.0, 'alpha': 0.25, 'weight': 0.15}
159+
reg_cost: {'weight': 0.25}
160+
iou_cost: {'weight': 0.25}
161+
162+
LOSS_CONFIG:
163+
LOSS_WEIGHTS: {
164+
'cls_weight': 1.0,
165+
'bbox_weight': 0.25,
166+
'hm_weight': 1.0,
167+
'code_weights': [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.2, 0.2]
168+
}
169+
LOSS_CLS:
170+
use_sigmoid: True
171+
gamma: 2.0
172+
alpha: 0.25
173+
174+
POST_PROCESSING:
175+
SCORE_THRESH: 0.0
176+
POST_CENTER_RANGE: [-61.2, -61.2, -10.0, 61.2, 61.2, 10.0]
177+
178+
POST_PROCESSING:
179+
RECALL_THRESH_LIST: [0.3, 0.5, 0.7]
180+
SCORE_THRESH: 0.1
181+
OUTPUT_RAW_SCORE: False
182+
183+
EVAL_METRIC: kitti
184+
185+
186+
187+
OPTIMIZATION:
188+
BATCH_SIZE_PER_GPU: 3
189+
NUM_EPOCHS: 6
190+
191+
OPTIMIZER: adam_cosineanneal
192+
LR: 0.0001
193+
WEIGHT_DECAY: 0.01
194+
MOMENTUM: 0.9
195+
BETAS: [0.9, 0.999]
196+
197+
MOMS: [0.9, 0.8052631]
198+
PCT_START: 0.4
199+
WARMUP_ITER: 500
200+
201+
DECAY_STEP_LIST: [35, 45]
202+
LR_WARMUP: False
203+
WARMUP_EPOCH: 1
204+
205+
GRAD_NORM_CLIP: 35
206+
207+
LOSS_SCALE_FP16: 32
208+

0 commit comments

Comments
 (0)