diff --git a/demo/visualization_demo/bev_vis_multi_frame_demo.py b/demo/visualization_demo/bev_vis_multi_frame_demo.py new file mode 100644 index 00000000..69ec8bf8 --- /dev/null +++ b/demo/visualization_demo/bev_vis_multi_frame_demo.py @@ -0,0 +1,111 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import argparse +import os +import numpy as np + +from paddle3d.apis.infer import Infer +from paddle3d.apis.config import Config +from paddle3d.slim import get_qat_config +from paddle3d.utils.checkpoint import load_pretrained_model + + +def parse_args(): + """ + """ + parser = argparse.ArgumentParser(description='Model evaluation') + # params of training + parser.add_argument( + "--config", dest="cfg", help="The config file.", default=None, type=str) + parser.add_argument( + '--batch_size', + dest='batch_size', + help='Mini batch size of one gpu or cpu', + type=int, + default=None) + parser.add_argument( + '--model', + dest='model', + help='pretrained parameters of the model', + type=str, + default=None) + parser.add_argument( + '--num_workers', + dest='num_workers', + help='Num workers for data loader', + type=int, + default=2) + parser.add_argument( + '--quant_config', + dest='quant_config', + help='Config for quant model.', + default=None, + type=str) + + return parser.parse_args() + + +def worker_init_fn(worker_id): + np.random.seed(1024) + + +def main(args): + """ + """ + if args.cfg is None: + raise RuntimeError("No configuration file specified!") + + if not os.path.exists(args.cfg): + raise RuntimeError("Config file `{}` does not exist!".format(args.cfg)) + + cfg = Config(path=args.cfg, batch_size=args.batch_size) + + if cfg.val_dataset is None: + raise RuntimeError( + 'The validation dataset is not specified in the configuration file!' + ) + elif len(cfg.val_dataset) == 0: + raise ValueError( + 'The length of validation dataset is 0. Please check if your dataset is valid!' + ) + + dic = cfg.to_dict() + batch_size = dic.pop('batch_size') + dic.update({ + 'dataloader_fn': { + 'batch_size': batch_size, + 'num_workers': args.num_workers, + 'worker_init_fn': worker_init_fn + } + }) + + if args.quant_config: + quant_config = get_qat_config(args.quant_config) + cfg.model.build_slim_model(quant_config['quant_config']) + + if args.model is not None: + load_pretrained_model(cfg.model, args.model) + dic['checkpoint'] = None + dic['resume'] = False + else: + dic['resume'] = True + + infer = Infer(**dic) + infer.infer('bev') + + +if __name__ == '__main__': + args = parse_args() + main(args) diff --git a/demo/visualization_demo/bev_vis_single_frame_demo.py b/demo/visualization_demo/bev_vis_single_frame_demo.py new file mode 100644 index 00000000..f28aa04d --- /dev/null +++ b/demo/visualization_demo/bev_vis_single_frame_demo.py @@ -0,0 +1,188 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import argparse +import numpy as np +import paddle +from paddle.inference import Config, create_predictor +from paddle3d.ops.iou3d_nms_cuda import nms_gpu +from demo.visualization_demo.vis_utils import preprocess, Calibration, show_bev_with_boxes + + +def parse_args(): + parser = argparse.ArgumentParser() + parser.add_argument( + "--model_file", + type=str, + help="Model filename, Specify this when your model is a combined model.", + required=True) + parser.add_argument( + "--params_file", + type=str, + help= + "Parameter filename, Specify this when your model is a combined model.", + required=True) + parser.add_argument( + '--lidar_file', type=str, help='The lidar path.', required=True) + parser.add_argument( + '--calib_file', type=str, help='The lidar path.', required=True) + parser.add_argument( + "--num_point_dim", + type=int, + default=4, + help="Dimension of a point in the lidar file.") + parser.add_argument( + "--point_cloud_range", + dest='point_cloud_range', + nargs='+', + help="Range of point cloud for voxelize operation.", + type=float, + default=None) + parser.add_argument( + "--voxel_size", + dest='voxel_size', + nargs='+', + help="Size of voxels for voxelize operation.", + type=float, + default=None) + parser.add_argument( + "--max_points_in_voxel", + type=int, + default=100, + help="Maximum number of points in a voxel.") + parser.add_argument( + "--max_voxel_num", + type=int, + default=12000, + help="Maximum number of voxels.") + parser.add_argument("--gpu_id", type=int, default=0, help="GPU card id.") + parser.add_argument( + "--use_trt", + type=int, + default=0, + help="Whether to use tensorrt to accelerate when using gpu.") + parser.add_argument( + "--trt_precision", + type=int, + default=0, + help="Precision type of tensorrt, 0: kFloat32, 1: kHalf.") + parser.add_argument( + "--trt_use_static", + type=int, + default=0, + help="Whether to load the tensorrt graph optimization from a disk path." 
+ ) + parser.add_argument( + "--trt_static_dir", + type=str, + help="Path of a tensorrt graph optimization directory.") + parser.add_argument( + "--collect_shape_info", + type=int, + default=0, + help="Whether to collect dynamic shape before using tensorrt.") + parser.add_argument( + "--dynamic_shape_file", + type=str, + default="", + help="Path of a dynamic shape file for tensorrt.") + + return parser.parse_args() + + +def init_predictor(model_file, + params_file, + gpu_id=0, + use_trt=False, + trt_precision=0, + trt_use_static=False, + trt_static_dir=None, + collect_shape_info=False, + dynamic_shape_file=None): + config = Config(model_file, params_file) + config.enable_memory_optim() + config.enable_use_gpu(1000, gpu_id) + if use_trt: + precision_mode = paddle.inference.PrecisionType.Float32 + if trt_precision == 1: + precision_mode = paddle.inference.PrecisionType.Half + config.enable_tensorrt_engine( + workspace_size=1 << 30, + max_batch_size=1, + min_subgraph_size=10, + precision_mode=precision_mode, + use_static=trt_use_static, + use_calib_mode=False) + if collect_shape_info: + config.collect_shape_range_info(dynamic_shape_file) + else: + config.enable_tuned_tensorrt_dynamic_shape(dynamic_shape_file, True) + if trt_use_static: + config.set_optim_cache_dir(trt_static_dir) + + predictor = create_predictor(config) + return predictor + + +def run(predictor, voxels, coords, num_points_per_voxel): + input_names = predictor.get_input_names() + for i, name in enumerate(input_names): + input_tensor = predictor.get_input_handle(name) + if name == "voxels": + input_tensor.reshape(voxels.shape) + input_tensor.copy_from_cpu(voxels.copy()) + elif name == "coords": + input_tensor.reshape(coords.shape) + input_tensor.copy_from_cpu(coords.copy()) + elif name == "num_points_per_voxel": + input_tensor.reshape(num_points_per_voxel.shape) + input_tensor.copy_from_cpu(num_points_per_voxel.copy()) + + # do the inference + predictor.run() + + # get out data from output tensor + output_names = predictor.get_output_names() + for i, name in enumerate(output_names): + output_tensor = predictor.get_output_handle(name) + if i == 0: + box3d_lidar = output_tensor.copy_to_cpu() + elif i == 1: + label_preds = output_tensor.copy_to_cpu() + elif i == 2: + scores = output_tensor.copy_to_cpu() + return box3d_lidar, label_preds, scores + + +if __name__ == '__main__': + args = parse_args() + + predictor = init_predictor(args.model_file, args.params_file, args.gpu_id, + args.use_trt, args.trt_precision, + args.trt_use_static, args.trt_static_dir, + args.collect_shape_info, args.dynamic_shape_file) + voxels, coords, num_points_per_voxel = preprocess( + args.lidar_file, args.num_point_dim, args.point_cloud_range, + args.voxel_size, args.max_points_in_voxel, args.max_voxel_num) + box3d_lidar, label_preds, scores = run(predictor, voxels, coords, + num_points_per_voxel) + + scan = np.fromfile(args.lidar_file, dtype=np.float32) + pc_velo = scan.reshape((-1, 4)) + + # Obtain calibration information about Kitti + calib = Calibration(args.calib_file) + + # Plot box in lidar cloud + show_bev_with_boxes(pc_velo, box3d_lidar, scores, calib) diff --git a/demo/visualization_demo/dataset_vis_demo.py b/demo/visualization_demo/dataset_vis_demo.py new file mode 100644 index 00000000..2079590a --- /dev/null +++ b/demo/visualization_demo/dataset_vis_demo.py @@ -0,0 +1,85 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import numpy as np +import cv2 + +from paddle3d.datasets.kitti.kitti_utils import camera_record_to_object + +from demo.visualization_demo.vis_utils import Calibration, show_lidar_with_boxes, total_imgpred_by_conf_to_kitti_records, \ + make_imgpts_list, draw_mono_3d, show_bev_with_boxes + +pth = '../datasets/KITTI/training' # Kitti dataset path + +files = os.listdir(os.path.join(pth, 'image_2')) +files = sorted(files) + +mode = 'bev' + +assert mode in ['bev', 'image', 'pcd'], '' + +for img in files: + id = img[:-4] + label_file = os.path.join(pth, 'label_2', f'{id}.txt') + calib_file = os.path.join(pth, 'calib', f'{id}.txt') + img_file = os.path.join(pth, 'image_2', f'{id}.png') + pcd_file = os.path.join(pth, 'velodyne', f'{id}.bin') + + label_lines = open(label_file).readlines() + kitti_records_list = [line.strip().split(' ') for line in label_lines] + + if mode == 'pcd': + box3d_list = [] + for itm in kitti_records_list: + itm = [float(i) for i in itm[8:]] + # [z, -x, -y, w, l, h, ry] + box3d_list.append( + [itm[5], -itm[3], -itm[4], itm[1], itm[2], itm[0], itm[6]]) + box3d = np.asarray(box3d_list) + scan = np.fromfile(pcd_file, dtype=np.float32) + pc_velo = scan.reshape((-1, 4)) + # Obtain calibration information about Kitti + calib = Calibration(calib_file) + # Plot box in lidar cloud + # show_lidar_with_boxes(pc_velo, result['bboxes_3d'], result['confidences'], calib) + show_lidar_with_boxes(pc_velo, box3d, np.ones(box3d.shape[0]), calib) + + if mode == 'image': + kitti_records = np.array(kitti_records_list) + bboxes_2d, bboxes_3d, labels = camera_record_to_object(kitti_records) + # read origin image + img_origin = cv2.imread(img_file) + # to 8 points on image + itms = open(calib_file).readlines()[2] + P2 = itms[4:].strip().split(' ') + K = np.asarray([float(i) for i in P2]).reshape(3, 4)[:, :3] + imgpts_list = make_imgpts_list(bboxes_3d, K) + # draw smoke result to photo + draw_mono_3d(img_origin, imgpts_list) + + if mode == 'bev': + box3d_list = [] + for itm in kitti_records_list: + itm = [float(i) for i in itm[8:]] + # [z, -x, -y, w, l, h, ry] + box3d_list.append( + [itm[5], -itm[3], -itm[4], itm[1], itm[2], itm[0], itm[6]]) + box3d = np.asarray(box3d_list) + scan = np.fromfile(pcd_file, dtype=np.float32) + pc_velo = scan.reshape((-1, 4)) + # Obtain calibration information about Kitti + calib = Calibration(calib_file) + # Plot box in lidar cloud (bev) + show_bev_with_boxes(pc_velo, box3d, np.ones(box3d.shape[0]), calib) diff --git a/demo/visualization_demo/img/bev.png b/demo/visualization_demo/img/bev.png new file mode 100644 index 00000000..88a6bde0 Binary files /dev/null and b/demo/visualization_demo/img/bev.png differ diff --git a/demo/visualization_demo/img/mono.jpg b/demo/visualization_demo/img/mono.jpg new file mode 100644 index 00000000..c2eb7348 Binary files /dev/null and b/demo/visualization_demo/img/mono.jpg differ diff --git a/demo/visualization_demo/img/pc.png b/demo/visualization_demo/img/pc.png new file mode 
100644 index 00000000..2b101bac Binary files /dev/null and b/demo/visualization_demo/img/pc.png differ diff --git a/demo/visualization_demo/mono_vis_multi_frame_demo.py b/demo/visualization_demo/mono_vis_multi_frame_demo.py new file mode 100644 index 00000000..6f78e3b7 --- /dev/null +++ b/demo/visualization_demo/mono_vis_multi_frame_demo.py @@ -0,0 +1,115 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import argparse +import os +import cv2 +import numpy as np + +from demo.visualization_demo.vis_utils import make_imgpts_list, draw_mono_3d, total_imgpred_by_conf_to_kitti_records + +from paddle3d.apis.infer import Infer +from paddle3d.apis.config import Config +from paddle3d.slim import get_qat_config +from paddle3d.utils.checkpoint import load_pretrained_model +from paddle3d.datasets.kitti.kitti_utils import camera_record_to_object + + +def parse_args(): + """ + """ + parser = argparse.ArgumentParser(description='Model evaluation') + # params of training + parser.add_argument( + "--config", dest="cfg", help="The config file.", default=None, type=str) + parser.add_argument( + '--batch_size', + dest='batch_size', + help='Mini batch size of one gpu or cpu', + type=int, + default=None) + parser.add_argument( + '--model', + dest='model', + help='pretrained parameters of the model', + type=str, + default=None) + parser.add_argument( + '--num_workers', + dest='num_workers', + help='Num workers for data loader', + type=int, + default=2) + parser.add_argument( + '--quant_config', + dest='quant_config', + help='Config for quant model.', + default=None, + type=str) + + return parser.parse_args() + + +def worker_init_fn(worker_id): + np.random.seed(1024) + + +def main(args): + """ + """ + if args.cfg is None: + raise RuntimeError("No configuration file specified!") + + if not os.path.exists(args.cfg): + raise RuntimeError("Config file `{}` does not exist!".format(args.cfg)) + + cfg = Config(path=args.cfg, batch_size=args.batch_size) + + if cfg.val_dataset is None: + raise RuntimeError( + 'The validation dataset is not specified in the configuration file!' + ) + elif len(cfg.val_dataset) == 0: + raise ValueError( + 'The length of validation dataset is 0. Please check if your dataset is valid!' 
+ ) + + dic = cfg.to_dict() + batch_size = dic.pop('batch_size') + dic.update({ + 'dataloader_fn': { + 'batch_size': batch_size, + 'num_workers': args.num_workers, + 'worker_init_fn': worker_init_fn + } + }) + + if args.quant_config: + quant_config = get_qat_config(args.quant_config) + cfg.model.build_slim_model(quant_config['quant_config']) + + if args.model is not None: + load_pretrained_model(cfg.model, args.model) + dic['checkpoint'] = None + dic['resume'] = False + else: + dic['resume'] = True + + infer = Infer(**dic) + infer.infer('image') + + +if __name__ == '__main__': + args = parse_args() + main(args) diff --git a/demo/visualization_demo/mono_vis_single_frame_demo.py b/demo/visualization_demo/mono_vis_single_frame_demo.py new file mode 100644 index 00000000..55b11a97 --- /dev/null +++ b/demo/visualization_demo/mono_vis_single_frame_demo.py @@ -0,0 +1,137 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import argparse + +import cv2 +import numpy as np + +from paddle.inference import Config, PrecisionType, create_predictor +from paddle3d.datasets.kitti.kitti_utils import camera_record_to_object +from demo.visualization_demo.vis_utils import get_img, get_ratio, total_pred_by_conf_to_kitti_records, make_imgpts_list, draw_mono_3d + + +def parse_args(): + parser = argparse.ArgumentParser() + parser.add_argument( + "--model_file", + type=str, + help="Model filename, Specify this when your model is a combined model.", + required=True) + parser.add_argument( + "--params_file", + type=str, + help= + "Parameter filename, Specify this when your model is a combined model.", + required=True) + parser.add_argument( + '--image', dest='image', help='The image path', type=str, required=True) + parser.add_argument( + "--use_gpu", action='store_true', help="Whether use gpu.") + parser.add_argument( + "--use_trt", action='store_true', help="Whether use trt.") + parser.add_argument( + "--collect_dynamic_shape_info", + action='store_true', + help="Whether to collect dynamic shape before using tensorrt.") + parser.add_argument( + "--dynamic_shape_file", + dest='dynamic_shape_file', + help='The image path', + type=str, + default="dynamic_shape_info.txt") + return parser.parse_args() + + +def init_predictor(args): + config = Config(args.model_file, args.params_file) + config.enable_memory_optim() + if args.use_gpu: + config.enable_use_gpu(1000, 0) + else: + # If not specific mkldnn, you can set the blas thread. + # The thread num should not be greater than the number of cores in the CPU. 
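+        # 4 threads is a conservative default here; it can be raised up to the number of physical CPU cores on the machine.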
+ config.set_cpu_math_library_num_threads(4) + config.enable_mkldnn() + + if args.collect_dynamic_shape_info: + config.collect_shape_range_info(args.dynamic_shape_file) + elif args.use_trt: + allow_build_at_runtime = True + config.enable_tuned_tensorrt_dynamic_shape(args.dynamic_shape_file, + allow_build_at_runtime) + + config.enable_tensorrt_engine( + workspace_size=1 << 20, + max_batch_size=1, + min_subgraph_size=3, + precision_mode=PrecisionType.Float32) + + predictor = create_predictor(config) + return predictor + + +def run(predictor, image, K, down_ratio): + # copy img data to input tensor + input_names = predictor.get_input_names() + for i, name in enumerate(input_names): + input_tensor = predictor.get_input_handle(name) + if name == "images": + input_tensor.reshape(image.shape) + input_tensor.copy_from_cpu(image.copy()) + elif name == "trans_cam_to_img": + input_tensor.reshape(K.shape) + input_tensor.copy_from_cpu(K.copy()) + elif name == "down_ratios": + input_tensor.reshape(down_ratio.shape) + input_tensor.copy_from_cpu(down_ratio.copy()) + + # do the inference + predictor.run() + + results = [] + # get out data from output tensor + output_names = predictor.get_output_names() + for i, name in enumerate(output_names): + output_tensor = predictor.get_output_handle(name) + output_data = output_tensor.copy_to_cpu() + results.append(output_data) + + return results + + +if __name__ == '__main__': + args = parse_args() + pred = init_predictor(args) + # Listed below are camera intrinsic parameter of the kitti dataset + # If the model is trained on other datasets, please replace the relevant data + K = np.array([[[721.53771973, 0., 609.55932617], + [0., 721.53771973, 172.85400391], [0, 0, 1]]], np.float32) + + img, ori_img_size, output_size = get_img(args.image) + ratio = get_ratio(ori_img_size, output_size) + + results = run(pred, img, K, ratio) + + total_pred = results[0] + print(total_pred) + # convert pred to bboxes_2d, bboxes_3d + kitti_records = total_pred_by_conf_to_kitti_records(total_pred, conf=0.5) + bboxes_2d, bboxes_3d, labels = camera_record_to_object(kitti_records) + # read origin image + img_origin = cv2.imread(args.image) + # to 8 points on image + imgpts_list = make_imgpts_list(bboxes_3d, K[0]) + # draw smoke result to photo + draw_mono_3d(img_origin, imgpts_list) diff --git a/demo/visualization_demo/pcd_vis_multi_frame_demo.py b/demo/visualization_demo/pcd_vis_multi_frame_demo.py new file mode 100644 index 00000000..9fee1af0 --- /dev/null +++ b/demo/visualization_demo/pcd_vis_multi_frame_demo.py @@ -0,0 +1,113 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
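+
+# Multi-frame LiDAR point cloud visualization demo: builds the validation dataloader from a
+# Paddle3D config, loads the pretrained weights, and calls Infer(**dic).infer('pcd') to draw
+# the predicted 3D boxes in the point cloud frame by frame.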
+ +import argparse +import os +import cv2 +import numpy as np + +from paddle3d.apis.infer import Infer +from paddle3d.apis.config import Config +from paddle3d.slim import get_qat_config +from paddle3d.utils.checkpoint import load_pretrained_model +from paddle3d.datasets.kitti.kitti_utils import camera_record_to_object + + +def parse_args(): + """ + """ + parser = argparse.ArgumentParser(description='Model evaluation') + # params of training + parser.add_argument( + "--config", dest="cfg", help="The config file.", default=None, type=str) + parser.add_argument( + '--batch_size', + dest='batch_size', + help='Mini batch size of one gpu or cpu', + type=int, + default=None) + parser.add_argument( + '--model', + dest='model', + help='pretrained parameters of the model', + type=str, + default=None) + parser.add_argument( + '--num_workers', + dest='num_workers', + help='Num workers for data loader', + type=int, + default=2) + parser.add_argument( + '--quant_config', + dest='quant_config', + help='Config for quant model.', + default=None, + type=str) + + return parser.parse_args() + + +def worker_init_fn(worker_id): + np.random.seed(1024) + + +def main(args): + """ + """ + if args.cfg is None: + raise RuntimeError("No configuration file specified!") + + if not os.path.exists(args.cfg): + raise RuntimeError("Config file `{}` does not exist!".format(args.cfg)) + + cfg = Config(path=args.cfg, batch_size=args.batch_size) + print(args.cfg) + if cfg.val_dataset is None: + raise RuntimeError( + 'The validation dataset is not specified in the configuration file!' + ) + elif len(cfg.val_dataset) == 0: + raise ValueError( + 'The length of validation dataset is 0. Please check if your dataset is valid!' + ) + + dic = cfg.to_dict() + batch_size = dic.pop('batch_size') + dic.update({ + 'dataloader_fn': { + 'batch_size': batch_size, + 'num_workers': args.num_workers, + 'worker_init_fn': worker_init_fn + } + }) + + if args.quant_config: + quant_config = get_qat_config(args.quant_config) + cfg.model.build_slim_model(quant_config['quant_config']) + + if args.model is not None: + load_pretrained_model(cfg.model, args.model) + dic['checkpoint'] = None + dic['resume'] = False + else: + dic['resume'] = True + + infer = Infer(**dic) + infer.infer('pcd') + + +if __name__ == '__main__': + args = parse_args() + main(args) diff --git a/demo/visualization_demo/pcd_vis_single_frame_demo.py b/demo/visualization_demo/pcd_vis_single_frame_demo.py new file mode 100644 index 00000000..501061ba --- /dev/null +++ b/demo/visualization_demo/pcd_vis_single_frame_demo.py @@ -0,0 +1,188 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
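+
+# Single-frame LiDAR point cloud visualization demo: voxelizes one .bin point cloud, runs a
+# Paddle inference predictor (optionally accelerated with TensorRT), and plots the predicted
+# 3D boxes over the point cloud with show_lidar_with_boxes using the KITTI calibration file.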
+ +import argparse +import numpy as np +import paddle +from paddle.inference import Config, create_predictor +from paddle3d.ops.iou3d_nms_cuda import nms_gpu +from demo.visualization_demo.vis_utils import preprocess, Calibration, show_lidar_with_boxes + + +def parse_args(): + parser = argparse.ArgumentParser() + parser.add_argument( + "--model_file", + type=str, + help="Model filename, Specify this when your model is a combined model.", + required=True) + parser.add_argument( + "--params_file", + type=str, + help= + "Parameter filename, Specify this when your model is a combined model.", + required=True) + parser.add_argument( + '--lidar_file', type=str, help='The lidar path.', required=True) + parser.add_argument( + '--calib_file', type=str, help='The lidar path.', required=True) + parser.add_argument( + "--num_point_dim", + type=int, + default=4, + help="Dimension of a point in the lidar file.") + parser.add_argument( + "--point_cloud_range", + dest='point_cloud_range', + nargs='+', + help="Range of point cloud for voxelize operation.", + type=float, + default=None) + parser.add_argument( + "--voxel_size", + dest='voxel_size', + nargs='+', + help="Size of voxels for voxelize operation.", + type=float, + default=None) + parser.add_argument( + "--max_points_in_voxel", + type=int, + default=100, + help="Maximum number of points in a voxel.") + parser.add_argument( + "--max_voxel_num", + type=int, + default=12000, + help="Maximum number of voxels.") + parser.add_argument("--gpu_id", type=int, default=0, help="GPU card id.") + parser.add_argument( + "--use_trt", + type=int, + default=0, + help="Whether to use tensorrt to accelerate when using gpu.") + parser.add_argument( + "--trt_precision", + type=int, + default=0, + help="Precision type of tensorrt, 0: kFloat32, 1: kHalf.") + parser.add_argument( + "--trt_use_static", + type=int, + default=0, + help="Whether to load the tensorrt graph optimization from a disk path." 
+ ) + parser.add_argument( + "--trt_static_dir", + type=str, + help="Path of a tensorrt graph optimization directory.") + parser.add_argument( + "--collect_shape_info", + type=int, + default=0, + help="Whether to collect dynamic shape before using tensorrt.") + parser.add_argument( + "--dynamic_shape_file", + type=str, + default="", + help="Path of a dynamic shape file for tensorrt.") + + return parser.parse_args() + + +def init_predictor(model_file, + params_file, + gpu_id=0, + use_trt=False, + trt_precision=0, + trt_use_static=False, + trt_static_dir=None, + collect_shape_info=False, + dynamic_shape_file=None): + config = Config(model_file, params_file) + config.enable_memory_optim() + config.enable_use_gpu(1000, gpu_id) + if use_trt: + precision_mode = paddle.inference.PrecisionType.Float32 + if trt_precision == 1: + precision_mode = paddle.inference.PrecisionType.Half + config.enable_tensorrt_engine( + workspace_size=1 << 30, + max_batch_size=1, + min_subgraph_size=10, + precision_mode=precision_mode, + use_static=trt_use_static, + use_calib_mode=False) + if collect_shape_info: + config.collect_shape_range_info(dynamic_shape_file) + else: + config.enable_tuned_tensorrt_dynamic_shape(dynamic_shape_file, True) + if trt_use_static: + config.set_optim_cache_dir(trt_static_dir) + + predictor = create_predictor(config) + return predictor + + +def run(predictor, voxels, coords, num_points_per_voxel): + input_names = predictor.get_input_names() + for i, name in enumerate(input_names): + input_tensor = predictor.get_input_handle(name) + if name == "voxels": + input_tensor.reshape(voxels.shape) + input_tensor.copy_from_cpu(voxels.copy()) + elif name == "coords": + input_tensor.reshape(coords.shape) + input_tensor.copy_from_cpu(coords.copy()) + elif name == "num_points_per_voxel": + input_tensor.reshape(num_points_per_voxel.shape) + input_tensor.copy_from_cpu(num_points_per_voxel.copy()) + + # do the inference + predictor.run() + + # get out data from output tensor + output_names = predictor.get_output_names() + for i, name in enumerate(output_names): + output_tensor = predictor.get_output_handle(name) + if i == 0: + box3d_lidar = output_tensor.copy_to_cpu() + elif i == 1: + label_preds = output_tensor.copy_to_cpu() + elif i == 2: + scores = output_tensor.copy_to_cpu() + return box3d_lidar, label_preds, scores + + +if __name__ == '__main__': + args = parse_args() + + predictor = init_predictor(args.model_file, args.params_file, args.gpu_id, + args.use_trt, args.trt_precision, + args.trt_use_static, args.trt_static_dir, + args.collect_shape_info, args.dynamic_shape_file) + voxels, coords, num_points_per_voxel = preprocess( + args.lidar_file, args.num_point_dim, args.point_cloud_range, + args.voxel_size, args.max_points_in_voxel, args.max_voxel_num) + box3d_lidar, label_preds, scores = run(predictor, voxels, coords, + num_points_per_voxel) + + scan = np.fromfile(args.lidar_file, dtype=np.float32) + pc_velo = scan.reshape((-1, 4)) + + # Obtain calibration information about Kitti + calib = Calibration(args.calib_file) + + # Plot box in lidar cloud + show_lidar_with_boxes(pc_velo, box3d_lidar, scores, calib) diff --git a/demo/visualization_demo/readme.md b/demo/visualization_demo/readme.md new file mode 100644 index 00000000..85bd098c --- /dev/null +++ b/demo/visualization_demo/readme.md @@ -0,0 +1,159 @@ +## 激光雷达点云/BEV和相机图像的3D可视化示例 +### 环境配置 +按照 [官方文档](https://github.com/PaddlePaddle/Paddle3D/blob/develop/docs/installation.md) 安装paddle3D依赖,然后安装`mayavi`用于激光点云可视化 +``` +pip install vtk==8.1.2 +pip 
install mayavi==4.7.4 +pip install PyQt5 +``` +### 相机图像的3D可视化示例 +相机图像的3D可视化文件保存在`demo/visualization_demo/`下,提供了单帧图像的3D可视化示例程序`mono_vis_single_frame_demo.py`和多帧图像的3D可视化示例程序`mono_vis_multi_frame_demo.py`。两者使用的可视化接口相同,对应的代码在`paddle3d.apis.infer`中。 + +`mono_vis_single_frame_demo.py`和`mono_vis_multi_frame_demo.py`的实现方法不同,以提供更多的可视化示例方法。其中`mono_vis_single_frame_demo.py`利用paddle推理部署的方式完成可视化,`mono_vis_multi_frame_demo.py`可视化通过对图像构建`dataloader`来完成逐帧读取和推理。 + +`mono_vis_single_frame_demo.py`使用方式如下: +``` +cd demo/visualization_demo +python mono_vis_single_frame_demo.py \ + --model_file model/smoke.pdmodel \ + --params_file model/smoke.pdiparams \ + --image data/image_2/000008.png +``` +`--model_file`和`--params_file`是使用的模型参数文件对应的路径 + +`--image`则是输入图像的路径 + +`mono_vis_multi_frame_demo.py`使用方式如下: + +``` +python mono_vis_multi_frame_demo.py \ + --config configs/smoke/smoke_dla34_no_dcn_kitti.yml \ + --model demo/smoke.pdparams \ + --batch_size 1 +``` + +`--config` 是模型配置文件路径 + +`--model` 是使用的模型参数文件对应的路径 + +`--batch_size` 是推理的batch数 + + +最终的单目可视化输出如下: + + +### 激光雷达点云的3D可视化示例 +激光雷达点云的3D可视化文件保存在`demo/visualization_demo/`下,提供了单帧激光雷达点云的3D可视化示例程序`pcd_vis_single_frame_demo.py`和多帧激光雷达点云的3D可视化示例程序`pcd_vis_multi_frame_demo.py`。两者使用的可视化接口相同,对应的代码在`paddle3d.apis.infer`中。 + +`pcd_vis_single_frame_demo.py`和`pcd_vis_multi_frame_demo.py`的实现方法不同,以提供更多的可视化示例方法。其中`pcd_vis_single_frame_demo.py`利用paddle推理部署的方式完成可视化,`pcd_vis_multi_frame_demo.py`可视化通过对图像构建`dataloader`来完成逐帧读取和推理。 + +`pcd_vis_single_frame_demo.py`使用方式如下: + +``` +cd demo/visualization_demo +python pcd_vis_single_frame_demo.py \ + --model_file model/pointpillars.pdmodel \ + --params_file model/pointpillars.pdiparams \ + --lidar_file data/velodyne/000008.bin \ + --calib_file data/calib/000008.txt \ + --point_cloud_range 0 -39.68 -3 69.12 39.68 1 \ + --voxel_size .16 .16 4 \ + --max_points_in_voxel 32 \ + --max_voxel_num 40000 +``` + +`--model_file`和`--params_file` 是使用的模型参数文件对应的路径 + +`--lidar_file` `--calib_file` 是激光雷达点云的路径和对应的校准文件路径 + +`--point_cloud_range` 表示激光雷达点云的`(x,y,z)`范围区间 + +`--voxel_size` 表示进行voxel处理时的尺寸大小 + +`--max_points_in_voxel` 每个voxel中最大的激光点云数目 + +`--max_voxel_num` voxel的最大数目 + +`pcd_vis_multi_frame_demo.py`使用方式如下: + +``` +python pcd_vis_multi_frame_demo.py \ + --config configs/pointpillars/pointpillars_xyres16_kitti_car.yml \ + --model demo/pointpillars.pdparams \ + --batch_size 1 +``` + +`--config` 是模型配置文件路径 + +`--model` 是使用的模型参数文件对应的路径 + +`--batch_size` 是推理的batch数 + +最终的激光雷达点云可视化输出如下: + + +### 激光雷达BEV的3D可视化示例 +激光雷达BEV的3D可视化文件保存在`demo/visualization_demo/`下,提供了单帧激光雷达点云BEV的3D可视化示例程序`bev_vis_single_frame_demo.py`和多帧激光雷达点云BEV的3D可视化示例程序`bev_vis_multi_frame_demo.py`。两者使用的可视化接口相同,对应的代码在`paddle3d.apis.infer`中。 + +`bev_vis_single_frame_demo.py`和`bev_vis_multi_frame_demo.py`的实现方法不同,以提供更多的可视化示例方法。其中`bev_vis_single_frame_demo.py`利用paddle推理部署的方式完成可视化,`bev_vis_multi_frame_demo.py`可视化通过对图像构建`dataloader`来完成逐帧读取和推理。 + +`bev_vis_single_frame_demo.py`使用方式如下: + +``` +cd demo/visualization_demo +python bev_vis_single_frame_demo.py \ + --model_file model/pointpillars.pdmodel \ + --params_file model/pointpillars.pdiparams \ + --lidar_file data/velodyne/000008.bin \ + --calib_file data/calib/000008.txt \ + --point_cloud_range 0 -39.68 -3 69.12 39.68 1 \ + --voxel_size .16 .16 4 \ + --max_points_in_voxel 32 \ + --max_voxel_num 40000 +``` +`--model_file`和`--params_file` 是使用的模型参数文件对应的路径 + +`--lidar_file` `--calib_file` 是激光雷达点云的路径和对应的校准文件路径 + +`--point_cloud_range` 表示激光雷达点云的`(x,y,z)`范围区间 + +`--voxel_size` 表示进行voxel处理时的尺寸大小 + +`--max_points_in_voxel` 每个voxel中最大的激光点云数目 + 
+`--max_voxel_num` voxel的最大数目 + +`bev_vis_multi_frame_demo.py`使用方式如下: + +``` +python bev_vis_multi_frame_demo.py \ + --config configs/pointpillars/pointpillars_xyres16_kitti_car.yml \ + --model demo/pointpillars.pdparams \ + --batch_size 1 +``` + +`--config` 是模型配置文件路径 + +`--model` 是使用的模型参数文件对应的路径 + +`--batch_size` 是推理的batch数 + +最终的激光雷达BEV可视化输出如下: + + + +### 数据集和LOG文件的可视化接口 +可视化接口对应的代码在`paddle3d.apis.infer`中,提供了一种调用示例 + +``` +cd demo/visualization_demo +python dataset_vis_demo.py +``` + +--- +如果遇到如下问题,可参考Ref1和Ref2的解决方案: + +`qt.qpa.plugin: Could not load the Qt Platform plugin 'xcb' in ..` + +[Ref1](https://blog.csdn.net/qq_39938666/article/details/120452028?spm=1001.2101.3001.6650.2&utm_medium=distribute.pc_relevant.none-task-blog-2%7Edefault%7ECTRLIST%7ERate-2-120452028-blog-112303826.pc_relevant_3mothn_strategy_recovery&depth_1-utm_source=distribute.pc_relevant.none-task-blog-2%7Edefault%7ECTRLIST%7ERate-2-120452028-blog-112303826.pc_relevant_3mothn_strategy_recovery&utm_relevant_index=3) +& [Ref2](https://blog.csdn.net/weixin_41794514/article/details/128578166?spm=1001.2101.3001.6650.3&utm_medium=distribute.pc_relevant.none-task-blog-2%7Edefault%7EYuanLiJiHua%7EPosition-3-128578166-blog-119480436.pc_relevant_landingrelevant&depth_1-utm_source=distribute.pc_relevant.none-task-blog-2%7Edefault%7EYuanLiJiHua%7EPosition-3-128578166-blog-119480436.pc_relevant_landingrelevant) diff --git a/demo/visualization_demo/vis_utils.py b/demo/visualization_demo/vis_utils.py new file mode 100644 index 00000000..99d1019d --- /dev/null +++ b/demo/visualization_demo/vis_utils.py @@ -0,0 +1,719 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import cv2 +import numba +import numpy as np +import mayavi.mlab as mlab + +from paddle3d.transforms.target_generator import encode_label + + +class Calibration(object): + ''' Calibration matrices and utils + 3d XYZ in <label>.txt are in rect camera coord. + 2d box xy are in image2 coord + Points in <lidar>.bin are in Velodyne coord. + y_image2 = P^2_rect * x_rect + y_image2 = P^2_rect * R0_rect * Tr_velo_to_cam * x_velo + x_ref = Tr_velo_to_cam * x_velo + x_rect = R0_rect * x_ref + P^2_rect = [f^2_u, 0, c^2_u, -f^2_u b^2_x; + 0, f^2_v, c^2_v, -f^2_v b^2_y; + 0, 0, 1, 0] + = K * [1|t] + image2 coord: + ----> x-axis (u) + | + | + v y-axis (v) + velodyne coord: + front x, left y, up z + rect/ref camera coord: + right x, down y, front z + Ref (KITTI paper): http://www.cvlibs.net/publications/Geiger2013IJRR.pdf + TODO(rqi): do matrix multiplication only once for each projection. 
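+        Example usage (illustrative):
+            calib = Calibration('data/calib/000008.txt')
+            pts_2d = calib.project_velo_to_image(pts_velo[:, 0:3])  # nx3 velodyne XYZ -> nx2 image2 pixels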
+ ''' + + def __init__(self, calib_filepath, from_video=False): + if from_video: + calibs = self.read_calib_from_video(calib_filepath) + else: + calibs = self.read_calib_file(calib_filepath) + # Projection matrix from rect camera coord to image2 coord + self.P = calibs['P2'] + self.P = np.reshape(self.P, [3, 4]) + # Rigid transform from Velodyne coord to reference camera coord + self.V2C = calibs['Tr_velo_to_cam'] + self.V2C = np.reshape(self.V2C, [3, 4]) + self.C2V = inverse_rigid_trans(self.V2C) + # Rotation from reference camera coord to rect camera coord + self.R0 = calibs['R0_rect'] + self.R0 = np.reshape(self.R0, [3, 3]) + + # Camera intrinsics and extrinsics + self.c_u = self.P[0, 2] + self.c_v = self.P[1, 2] + self.f_u = self.P[0, 0] + self.f_v = self.P[1, 1] + self.b_x = self.P[0, 3] / (-self.f_u) # relative + self.b_y = self.P[1, 3] / (-self.f_v) + + def read_calib_file(self, filepath): + ''' Read in a calibration file and parse into a dictionary. + Ref: https://github.com/utiasSTARS/pykitti/blob/master/pykitti/utils.py + ''' + data = {} + with open(filepath, 'r') as f: + for line in f.readlines(): + line = line.rstrip() + if len(line) == 0: continue + key, value = line.split(':', 1) + # The only non-float values in these files are dates, which + # we don't care about anyway + try: + data[key] = np.array([float(x) for x in value.split()]) + except ValueError: + pass + + return data + + def read_calib_from_video(self, calib_root_dir): + ''' Read calibration for camera 2 from video calib files. + there are calib_cam_to_cam and calib_velo_to_cam under the calib_root_dir + ''' + data = {} + cam2cam = self.read_calib_file( + os.path.join(calib_root_dir, 'calib_cam_to_cam.txt')) + velo2cam = self.read_calib_file( + os.path.join(calib_root_dir, 'calib_velo_to_cam.txt')) + Tr_velo_to_cam = np.zeros((3, 4)) + Tr_velo_to_cam[0:3, 0:3] = np.reshape(velo2cam['R'], [3, 3]) + Tr_velo_to_cam[:, 3] = velo2cam['T'] + data['Tr_velo_to_cam'] = np.reshape(Tr_velo_to_cam, [12]) + data['R0_rect'] = cam2cam['R_rect_00'] + data['P2'] = cam2cam['P_rect_02'] + return data + + def cart2hom(self, pts_3d): + ''' Input: nx3 points in Cartesian + Oupput: nx4 points in Homogeneous by pending 1 + ''' + n = pts_3d.shape[0] + pts_3d_hom = np.hstack((pts_3d, np.ones((n, 1)))) + return pts_3d_hom + + # 3d to 3d + def project_velo_to_ref(self, pts_3d_velo): + pts_3d_velo = self.cart2hom(pts_3d_velo) # nx4 + return np.dot(pts_3d_velo, np.transpose(self.V2C)) + + def project_ref_to_velo(self, pts_3d_ref): + pts_3d_ref = self.cart2hom(pts_3d_ref) # nx4 + return np.dot(pts_3d_ref, np.transpose(self.C2V)) + + def project_rect_to_ref(self, pts_3d_rect): + ''' Input and Output are nx3 points ''' + return np.transpose( + np.dot(np.linalg.inv(self.R0), np.transpose(pts_3d_rect))) + + def project_ref_to_rect(self, pts_3d_ref): + ''' Input and Output are nx3 points ''' + return np.transpose(np.dot(self.R0, np.transpose(pts_3d_ref))) + + def project_rect_to_velo(self, pts_3d_rect): + ''' Input: nx3 points in rect camera coord. + Output: nx3 points in velodyne coord. + ''' + pts_3d_ref = self.project_rect_to_ref(pts_3d_rect) + return self.project_ref_to_velo(pts_3d_ref) + + def project_velo_to_rect(self, pts_3d_velo): + pts_3d_ref = self.project_velo_to_ref(pts_3d_velo) + return self.project_ref_to_rect(pts_3d_ref) + + # 3d to 2d + def project_rect_to_image(self, pts_3d_rect): + ''' Input: nx3 points in rect camera coord. + Output: nx2 points in image2 coord. 
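+            The projection computes P * [x; y; z; 1] and divides the first two rows by the third to obtain (u, v).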
+ ''' + pts_3d_rect = self.cart2hom(pts_3d_rect) + pts_2d = np.dot(pts_3d_rect, np.transpose(self.P)) # nx3 + pts_2d[:, 0] /= pts_2d[:, 2] + pts_2d[:, 1] /= pts_2d[:, 2] + return pts_2d[:, 0:2] + + def project_velo_to_image(self, pts_3d_velo): + ''' Input: nx3 points in velodyne coord. + Output: nx2 points in image2 coord. + ''' + pts_3d_rect = self.project_velo_to_rect(pts_3d_velo) + return self.project_rect_to_image(pts_3d_rect) + + # 2d to 3d + def project_image_to_rect(self, uv_depth): + ''' Input: nx3 first two channels are uv, 3rd channel + is depth in rect camera coord. + Output: nx3 points in rect camera coord. + ''' + n = uv_depth.shape[0] + x = ((uv_depth[:, 0] - self.c_u) * uv_depth[:, 2]) / self.f_u + self.b_x + y = ((uv_depth[:, 1] - self.c_v) * uv_depth[:, 2]) / self.f_v + self.b_y + pts_3d_rect = np.zeros((n, 3)) + pts_3d_rect[:, 0] = x + pts_3d_rect[:, 1] = y + pts_3d_rect[:, 2] = uv_depth[:, 2] + return pts_3d_rect + + def project_image_to_velo(self, uv_depth): + pts_3d_rect = self.project_image_to_rect(uv_depth) + return self.project_rect_to_velo(pts_3d_rect) + + +def get_ratio(ori_img_size, output_size, down_ratio=(4, 4)): + return np.array([[ + down_ratio[1] * ori_img_size[1] / output_size[1], + down_ratio[0] * ori_img_size[0] / output_size[0] + ]], np.float32) + + +def get_img(img_path): + img = cv2.imread(img_path) + origin_shape = img.shape + img = cv2.resize(img, (1280, 384)) + + target_shape = img.shape + img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) + + img = img / 255.0 + img = np.subtract(img, np.array([0.485, 0.456, 0.406])) + img = np.true_divide(img, np.array([0.229, 0.224, 0.225])) + img = np.array(img, np.float32) + + img = img.transpose(2, 0, 1) + img = img[None, :, :, :] + + return img, origin_shape, target_shape + + +def total_pred_by_conf_to_kitti_records( + total_pred, conf, class_names=["Car", "Cyclist", "Pedestrian"]): + """convert total_pred to kitti_records""" + kitti_records_list = [] + for p in total_pred: + if p[-1] > conf: + p = list(p) + p[0] = class_names[int(p[0])] + # default, to kitti_records formate + p.insert(1, 0.0) + p.insert(2, 0) + kitti_records_list.append(p) + kitti_records = np.array(kitti_records_list) + + return kitti_records + + +def total_imgpred_by_conf_to_kitti_records( + total_pred, conf, class_names=["Car", "Cyclist", "Pedestrian"]): + """convert total_pred to kitti_records""" + kitti_records_list = [] + for idx in range(len(total_pred['confidences'])): + box2d = total_pred['bboxes_2d'][idx] + box3d = total_pred['bboxes_3d'][idx] + label = total_pred['labels'][idx] + cnf = total_pred['confidences'][idx] + if cnf > conf: + p = [] + p.append(class_names[int(label)]) + # default, to kitti_records formate + p.extend([0.0, 0.0, 0.0]) + p.extend(box2d) + p.extend([ + box3d[5], box3d[4], box3d[3], box3d[0], box3d[1], box3d[2], + box3d[-1] + ]) + kitti_records_list.append(p) + kitti_records = np.array(kitti_records_list) + + return kitti_records + + +def make_imgpts_list(bboxes_3d, K): + """to 8 points on image""" + # external parameters do not transform + rvec = np.array([[1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0]]) + tvec = np.array([[0.0], [0.0], [0.0]]) + + imgpts_list = [] + for box3d in bboxes_3d: + + locs = np.array(box3d[0:3]) + rot_y = np.array(box3d[6]) + + height, width, length = box3d[3:6] + _, box2d, box3d = encode_label(K, rot_y, + np.array([length, height, width]), locs) + + if np.all(box2d == 0): + continue + + imgpts, _ = cv2.projectPoints(box3d.T, rvec, tvec, K, 0) + imgpts_list.append(imgpts) + + return 
imgpts_list + + +def draw_mono_3d(img, imgpts_list): + """draw smoke result to photo""" + connect_line_id = [ + [1, 0], + [2, 7], + [3, 6], + [4, 5], + [1, 2], + [2, 3], + [3, 4], + [4, 1], + [0, 7], + [7, 6], + [6, 5], + [5, 0], + ] + + img_draw = img.copy() + + for imgpts in imgpts_list: + for p in imgpts: + p_x, p_y = int(p[0][0]), int(p[0][1]) + cv2.circle(img_draw, (p_x, p_y), 1, (0, 255, 0), -1) + for i, line_id in enumerate(connect_line_id): + + p1 = (int(imgpts[line_id[0]][0][0]), int(imgpts[line_id[0]][0][1])) + p2 = (int(imgpts[line_id[1]][0][0]), int(imgpts[line_id[1]][0][1])) + + if i <= 3: # body + color = (255, 0, 0) + elif i <= 7: # head + color = (0, 0, 255) + else: # tail + color = (255, 255, 0) + + cv2.line(img_draw, p1, p2, color, 1) + + cv2.imshow('output', img_draw) + cv2.waitKey(0) + cv2.destroyWindow('output') + + +def read_point(file_path, num_point_dim): + points = np.fromfile(file_path, np.float32).reshape(-1, num_point_dim) + points = points[:, :4] + return points + + +@numba.jit(nopython=True) +def _points_to_voxel(points, voxel_size, point_cloud_range, grid_size, voxels, + coords, num_points_per_voxel, grid_idx_to_voxel_idx, + max_points_in_voxel, max_voxel_num): + num_voxels = 0 + num_points = points.shape[0] + # x, y, z + coord = np.zeros(shape=(3, ), dtype=np.int32) + + for point_idx in range(num_points): + outside = False + for i in range(3): + coord[i] = np.floor( + (points[point_idx, i] - point_cloud_range[i]) / voxel_size[i]) + if coord[i] < 0 or coord[i] >= grid_size[i]: + outside = True + break + if outside: + continue + voxel_idx = grid_idx_to_voxel_idx[coord[2], coord[1], coord[0]] + if voxel_idx == -1: + voxel_idx = num_voxels + if num_voxels >= max_voxel_num: + continue + num_voxels += 1 + grid_idx_to_voxel_idx[coord[2], coord[1], coord[0]] = voxel_idx + coords[voxel_idx, 0:3] = coord[::-1] + curr_num_point = num_points_per_voxel[voxel_idx] + if curr_num_point < max_points_in_voxel: + voxels[voxel_idx, curr_num_point] = points[point_idx] + num_points_per_voxel[voxel_idx] = curr_num_point + 1 + + return num_voxels + + +def hardvoxelize(points, point_cloud_range, voxel_size, max_points_in_voxel, + max_voxel_num): + num_points, num_point_dim = points.shape[0:2] + point_cloud_range = np.array(point_cloud_range) + voxel_size = np.array(voxel_size) + voxels = np.zeros((max_voxel_num, max_points_in_voxel, num_point_dim), + dtype=points.dtype) + coords = np.zeros((max_voxel_num, 3), dtype=np.int32) + num_points_per_voxel = np.zeros((max_voxel_num, ), dtype=np.int32) + grid_size = np.round((point_cloud_range[3:6] - point_cloud_range[0:3]) / + voxel_size).astype('int32') + + grid_size_x, grid_size_y, grid_size_z = grid_size + + grid_idx_to_voxel_idx = np.full((grid_size_z, grid_size_y, grid_size_x), + -1, + dtype=np.int32) + + num_voxels = _points_to_voxel(points, voxel_size, point_cloud_range, + grid_size, voxels, coords, + num_points_per_voxel, grid_idx_to_voxel_idx, + max_points_in_voxel, max_voxel_num) + + voxels = voxels[:num_voxels] + coords = coords[:num_voxels] + num_points_per_voxel = num_points_per_voxel[:num_voxels] + + return voxels, coords, num_points_per_voxel + + +def preprocess(file_path, num_point_dim, point_cloud_range, voxel_size, + max_points_in_voxel, max_voxel_num): + points = read_point(file_path, num_point_dim) + voxels, coords, num_points_per_voxel = hardvoxelize( + points, point_cloud_range, voxel_size, max_points_in_voxel, + max_voxel_num) + + return voxels, coords, num_points_per_voxel + + +def roty(t): + ''' Rotation about the 
y-axis. ''' + c = np.cos(t) + s = np.sin(t) + return np.array([[c, 0, s], [0, 1, 0], [-s, 0, c]]) + + +def inverse_rigid_trans(Tr): + ''' Inverse a rigid body transform matrix (3x4 as [R|t]) + [R'|-R't; 0|1] + ''' + inv_Tr = np.zeros_like(Tr) # 3x4 + inv_Tr[0:3, 0:3] = np.transpose(Tr[0:3, 0:3]) + inv_Tr[0:3, 3] = np.dot(-np.transpose(Tr[0:3, 0:3]), Tr[0:3, 3]) + return inv_Tr + + +def compute_box_3d(obj): + ''' Takes an object and a projection matrix (P) and projects the 3d + bounding box into the image plane. + Returns: + corners_2d: (8,2) array in left image coord. + corners_3d: (8,3) array in in rect camera coord. + ''' + # compute rotational matrix around yaw axis + R = roty(obj[-1]) + + # 3d bounding box dimensions + l = obj[4] + w = obj[3] + h = obj[5] + + # 3d bounding box corners + x_corners = [l / 2, l / 2, -l / 2, -l / 2, l / 2, l / 2, -l / 2, -l / 2] + y_corners = [0, 0, 0, 0, -h, -h, -h, -h] + z_corners = [w / 2, -w / 2, -w / 2, w / 2, w / 2, -w / 2, -w / 2, w / 2] + + # rotate and translate 3d bounding box + corners_3d = np.dot(R, np.vstack([x_corners, y_corners, z_corners])) + # print corners_3d.shape + corners_3d[0, :] = corners_3d[0, :] - obj[1] + corners_3d[1, :] = corners_3d[1, :] - obj[2] + corners_3d[2, :] = corners_3d[2, :] + obj[0] + + return np.transpose(corners_3d) + + +def draw_lidar(pc, + color=None, + fig=None, + bgcolor=(0, 0, 0), + pts_scale=1, + pts_mode='point', + pts_color=None): + ''' Draw lidar points + Args: + pc: numpy array (n,3) of XYZ + color: numpy array (n) of intensity or whatever + fig: mayavi figure handler, if None create new one otherwise will use it + Returns: + fig: created or used fig + ''' + if fig is None: + fig = mlab.figure( + figure=None, + bgcolor=bgcolor, + fgcolor=None, + engine=None, + size=(1600, 1000)) + if color is None: color = pc[:, 2] + mlab.points3d( + pc[:, 0], + pc[:, 1], + pc[:, 2], + color, + color=pts_color, + mode=pts_mode, + colormap='gnuplot', + scale_factor=pts_scale, + figure=fig) + + # draw origin + mlab.points3d(0, 0, 0, color=(1, 1, 1), mode='sphere', scale_factor=0.2) + + # draw axis + axes = np.array([ + [2., 0., 0., 0.], + [0., 2., 0., 0.], + [0., 0., 2., 0.], + ], + dtype=np.float64) + mlab.plot3d([0, axes[0, 0]], [0, axes[0, 1]], [0, axes[0, 2]], + color=(1, 0, 0), + tube_radius=None, + figure=fig) + mlab.plot3d([0, axes[1, 0]], [0, axes[1, 1]], [0, axes[1, 2]], + color=(0, 1, 0), + tube_radius=None, + figure=fig) + mlab.plot3d([0, axes[2, 0]], [0, axes[2, 1]], [0, axes[2, 2]], + color=(0, 0, 1), + tube_radius=None, + figure=fig) + + # draw fov (todo: update to real sensor spec.) 
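+    # the two rays below end at (20, 20) and (20, -20), i.e. roughly +/-45 degrees around the forward x axis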
+ fov = np.array( + [ # 45 degree + [20., 20., 0., 0.], + [20., -20., 0., 0.], + ], + dtype=np.float64) + + mlab.plot3d([0, fov[0, 0]], [0, fov[0, 1]], [0, fov[0, 2]], + color=(1, 1, 1), + tube_radius=None, + line_width=1, + figure=fig) + mlab.plot3d([0, fov[1, 0]], [0, fov[1, 1]], [0, fov[1, 2]], + color=(1, 1, 1), + tube_radius=None, + line_width=1, + figure=fig) + + # draw square region + TOP_Y_MIN = -20 + TOP_Y_MAX = 20 + TOP_X_MIN = 0 + TOP_X_MAX = 40 + TOP_Z_MIN = -2.0 + TOP_Z_MAX = 0.4 + + x1 = TOP_X_MIN + x2 = TOP_X_MAX + y1 = TOP_Y_MIN + y2 = TOP_Y_MAX + mlab.plot3d([x1, x1], [y1, y2], [0, 0], + color=(0.5, 0.5, 0.5), + tube_radius=0.1, + line_width=1, + figure=fig) + mlab.plot3d([x2, x2], [y1, y2], [0, 0], + color=(0.5, 0.5, 0.5), + tube_radius=0.1, + line_width=1, + figure=fig) + mlab.plot3d([x1, x2], [y1, y1], [0, 0], + color=(0.5, 0.5, 0.5), + tube_radius=0.1, + line_width=1, + figure=fig) + mlab.plot3d([x1, x2], [y2, y2], [0, 0], + color=(0.5, 0.5, 0.5), + tube_radius=0.1, + line_width=1, + figure=fig) + + # mlab.orientation_axes() + mlab.view( + azimuth=180, + elevation=70, + focalpoint=[12.0909996, -1.04700089, -2.03249991], + distance=62.0, + figure=fig) + return fig + + +def draw_gt_boxes3d(gt_boxes3d, + fig, + color=(1, 1, 1), + line_width=1, + draw_text=True, + text_scale=(1, 1, 1), + color_list=None): + ''' Draw 3D bounding boxes + Args: + gt_boxes3d: numpy array (n,8,3) for XYZs of the box corners + fig: mayavi figure handler + color: RGB value tuple in range (0,1), box line color + line_width: box line width + draw_text: boolean, if true, write box indices beside boxes + text_scale: three number tuple + color_list: a list of RGB tuple, if not None, overwrite color. + Returns: + fig: updated fig + ''' + num = len(gt_boxes3d) + for n in range(num): + b = gt_boxes3d[n] + if color_list is not None: + color = color_list[n] + if draw_text: + mlab.text3d( + b[4, 0], + b[4, 1], + b[4, 2], + '%d' % n, + scale=text_scale, + color=color, + figure=fig) + for k in range(0, 4): + # http://docs.enthought.com/mayavi/mayavi/auto/mlab_helper_functions.html + i, j = k, (k + 1) % 4 + mlab.plot3d([b[i, 0], b[j, 0]], [b[i, 1], b[j, 1]], + [b[i, 2], b[j, 2]], + color=color, + tube_radius=None, + line_width=line_width, + figure=fig) + + i, j = k + 4, (k + 1) % 4 + 4 + mlab.plot3d([b[i, 0], b[j, 0]], [b[i, 1], b[j, 1]], + [b[i, 2], b[j, 2]], + color=color, + tube_radius=None, + line_width=line_width, + figure=fig) + + i, j = k, k + 4 + mlab.plot3d([b[i, 0], b[j, 0]], [b[i, 1], b[j, 1]], + [b[i, 2], b[j, 2]], + color=color, + tube_radius=None, + line_width=line_width, + figure=fig) + return + + +def show_lidar_with_boxes(pc_velo, objects, scores, calib): + ''' Show all LiDAR points. 
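+        Boxes with a confidence score <= 0.3 are skipped before drawing.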
+ Draw 3d box in LiDAR point cloud (in velo coord system) ''' + + fig = mlab.figure( + figure=None, + bgcolor=(0, 0, 0), + fgcolor=None, + engine=None, + size=(1000, 500)) + + draw_lidar(pc_velo, fig=fig) + + num_bbox3d, bbox3d_dims = objects.shape + for box_idx in range(num_bbox3d): + # filter fake results: score = -1 + if scores[box_idx] <= 0.3: + continue + obj = objects[box_idx] + # Draw 3d bounding box + box3d_pts_3d = compute_box_3d(obj) + box3d_pts_3d_velo = calib.project_rect_to_velo(box3d_pts_3d) + + draw_gt_boxes3d([box3d_pts_3d_velo], fig=fig) + mlab.show() + + +def pts2bev(pts): + side_range = (-75, 75) + fwd_range = (0, 75) + height_range = (-2, 5) + x, y, z = pts[:, 0], pts[:, 1], pts[:, 2] + + f_filter = np.logical_and(x > fwd_range[0], x < fwd_range[1]) + s_filter = np.logical_and(y > side_range[0], y < side_range[1]) + h_filter = np.logical_and(z > height_range[0], z < height_range[1]) + filter = np.logical_and(f_filter, s_filter) + filter = np.logical_and(filter, h_filter) + indices = np.argwhere(filter).flatten() + x, y, z = x[indices], y[indices], z[indices] + + res = 0.25 + x_img = (-y / res).astype(np.int32) + y_img = (-x / res).astype(np.int32) + x_img = x_img - int(np.floor(side_range[0]) / res) + y_img = y_img + int(np.floor(fwd_range[1]) / res) + + pixel_value = [255, 255, 255] + + x_max = int((side_range[1] - side_range[0]) / res) + 1 + y_max = int((fwd_range[1] - fwd_range[0]) / res) + 1 + + im = np.zeros([y_max, x_max, 3], dtype=np.uint8) + im[y_img, x_img] = pixel_value + + return im[:, :] + + +def show_bev_with_boxes(pc_velo, objects, scores, calib): + bev_im = pts2bev(pc_velo) + num_bbox3d, bbox3d_dims = objects.shape + for box_idx in range(num_bbox3d): + # filter results + if scores[box_idx] <= 0.3: + continue + obj = objects[box_idx] + # Draw bev bounding box + box3d_pts_3d = compute_box_3d(obj) + box3d_pts_3d_velo = calib.project_rect_to_velo(box3d_pts_3d) + + bpts = box3d_pts_3d_velo[:4, :2] + bpts = bpts[:, [1, 0]] + + cv2.line(bev_im, + (int(-bpts[0, 0] * 4 + 300), int(300 - bpts[0, 1] * 4)), + (int(-bpts[1, 0] * 4 + 300), int(300 - bpts[1, 1] * 4)), + (0, 0, 255), 2) + cv2.line(bev_im, + (int(-bpts[1, 0] * 4 + 300), int(300 - bpts[1, 1] * 4)), + (int(-bpts[2, 0] * 4 + 300), int(300 - bpts[2, 1] * 4)), + (0, 0, 255), 2) + cv2.line(bev_im, + (int(-bpts[2, 0] * 4 + 300), int(300 - bpts[2, 1] * 4)), + (int(-bpts[3, 0] * 4 + 300), int(300 - bpts[3, 1] * 4)), + (0, 0, 255), 2) + cv2.line(bev_im, + (int(-bpts[3, 0] * 4 + 300), int(300 - bpts[3, 1] * 4)), + (int(-bpts[0, 0] * 4 + 300), int(300 - bpts[0, 1] * 4)), + (0, 0, 255), 2) + cv2.imshow('bev', bev_im) + cv2.waitKey(0) + cv2.destroyWindow('bev') diff --git a/docs/visualization/visualization_cn.md b/docs/visualization/visualization_cn.md new file mode 100644 index 00000000..85bd098c --- /dev/null +++ b/docs/visualization/visualization_cn.md @@ -0,0 +1,159 @@ +## 激光雷达点云/BEV和相机图像的3D可视化示例 +### 环境配置 +按照 [官方文档](https://github.com/PaddlePaddle/Paddle3D/blob/develop/docs/installation.md) 安装paddle3D依赖,然后安装`mayavi`用于激光点云可视化 +``` +pip install vtk==8.1.2 +pip install mayavi==4.7.4 +pip install PyQt5 +``` +### 相机图像的3D可视化示例 +相机图像的3D可视化文件保存在`demo/visualization_demo/`下,提供了单帧图像的3D可视化示例程序`mono_vis_single_frame_demo.py`和多帧图像的3D可视化示例程序`mono_vis_multi_frame_demo.py`。两者使用的可视化接口相同,对应的代码在`paddle3d.apis.infer`中。 + +`mono_vis_single_frame_demo.py`和`mono_vis_multi_frame_demo.py`的实现方法不同,以提供更多的可视化示例方法。其中`mono_vis_single_frame_demo.py`利用paddle推理部署的方式完成可视化,`mono_vis_multi_frame_demo.py`可视化通过对图像构建`dataloader`来完成逐帧读取和推理。 + 
+`mono_vis_single_frame_demo.py`使用方式如下: +``` +cd demo/visualization_demo +python mono_vis_single_frame_demo.py \ + --model_file model/smoke.pdmodel \ + --params_file model/smoke.pdiparams \ + --image data/image_2/000008.png +``` +`--model_file`和`--params_file`是使用的模型参数文件对应的路径 + +`--image`则是输入图像的路径 + +`mono_vis_multi_frame_demo.py`使用方式如下: + +``` +python mono_vis_multi_frame_demo.py \ + --config configs/smoke/smoke_dla34_no_dcn_kitti.yml \ + --model demo/smoke.pdparams \ + --batch_size 1 +``` + +`--config` 是模型配置文件路径 + +`--model` 是使用的模型参数文件对应的路径 + +`--batch_size` 是推理的batch数 + + +最终的单目可视化输出如下: + + +### 激光雷达点云的3D可视化示例 +激光雷达点云的3D可视化文件保存在`demo/visualization_demo/`下,提供了单帧激光雷达点云的3D可视化示例程序`pcd_vis_single_frame_demo.py`和多帧激光雷达点云的3D可视化示例程序`pcd_vis_multi_frame_demo.py`。两者使用的可视化接口相同,对应的代码在`paddle3d.apis.infer`中。 + +`pcd_vis_single_frame_demo.py`和`pcd_vis_multi_frame_demo.py`的实现方法不同,以提供更多的可视化示例方法。其中`pcd_vis_single_frame_demo.py`利用paddle推理部署的方式完成可视化,`pcd_vis_multi_frame_demo.py`可视化通过对图像构建`dataloader`来完成逐帧读取和推理。 + +`pcd_vis_single_frame_demo.py`使用方式如下: + +``` +cd demo/visualization_demo +python pcd_vis_single_frame_demo.py \ + --model_file model/pointpillars.pdmodel \ + --params_file model/pointpillars.pdiparams \ + --lidar_file data/velodyne/000008.bin \ + --calib_file data/calib/000008.txt \ + --point_cloud_range 0 -39.68 -3 69.12 39.68 1 \ + --voxel_size .16 .16 4 \ + --max_points_in_voxel 32 \ + --max_voxel_num 40000 +``` + +`--model_file`和`--params_file` 是使用的模型参数文件对应的路径 + +`--lidar_file` `--calib_file` 是激光雷达点云的路径和对应的校准文件路径 + +`--point_cloud_range` 表示激光雷达点云的`(x,y,z)`范围区间 + +`--voxel_size` 表示进行voxel处理时的尺寸大小 + +`--max_points_in_voxel` 每个voxel中最大的激光点云数目 + +`--max_voxel_num` voxel的最大数目 + +`pcd_vis_multi_frame_demo.py`使用方式如下: + +``` +python pcd_vis_multi_frame_demo.py \ + --config configs/pointpillars/pointpillars_xyres16_kitti_car.yml \ + --model demo/pointpillars.pdparams \ + --batch_size 1 +``` + +`--config` 是模型配置文件路径 + +`--model` 是使用的模型参数文件对应的路径 + +`--batch_size` 是推理的batch数 + +最终的激光雷达点云可视化输出如下: + + +### 激光雷达BEV的3D可视化示例 +激光雷达BEV的3D可视化文件保存在`demo/visualization_demo/`下,提供了单帧激光雷达点云BEV的3D可视化示例程序`bev_vis_single_frame_demo.py`和多帧激光雷达点云BEV的3D可视化示例程序`bev_vis_multi_frame_demo.py`。两者使用的可视化接口相同,对应的代码在`paddle3d.apis.infer`中。 + +`bev_vis_single_frame_demo.py`和`bev_vis_multi_frame_demo.py`的实现方法不同,以提供更多的可视化示例方法。其中`bev_vis_single_frame_demo.py`利用paddle推理部署的方式完成可视化,`bev_vis_multi_frame_demo.py`可视化通过对图像构建`dataloader`来完成逐帧读取和推理。 + +`bev_vis_single_frame_demo.py`使用方式如下: + +``` +cd demo/visualization_demo +python bev_vis_single_frame_demo.py \ + --model_file model/pointpillars.pdmodel \ + --params_file model/pointpillars.pdiparams \ + --lidar_file data/velodyne/000008.bin \ + --calib_file data/calib/000008.txt \ + --point_cloud_range 0 -39.68 -3 69.12 39.68 1 \ + --voxel_size .16 .16 4 \ + --max_points_in_voxel 32 \ + --max_voxel_num 40000 +``` +`--model_file`和`--params_file` 是使用的模型参数文件对应的路径 + +`--lidar_file` `--calib_file` 是激光雷达点云的路径和对应的校准文件路径 + +`--point_cloud_range` 表示激光雷达点云的`(x,y,z)`范围区间 + +`--voxel_size` 表示进行voxel处理时的尺寸大小 + +`--max_points_in_voxel` 每个voxel中最大的激光点云数目 + +`--max_voxel_num` voxel的最大数目 + +`bev_vis_multi_frame_demo.py`使用方式如下: + +``` +python bev_vis_multi_frame_demo.py \ + --config configs/pointpillars/pointpillars_xyres16_kitti_car.yml \ + --model demo/pointpillars.pdparams \ + --batch_size 1 +``` + +`--config` 是模型配置文件路径 + +`--model` 是使用的模型参数文件对应的路径 + +`--batch_size` 是推理的batch数 + +最终的激光雷达BEV可视化输出如下: + + + +### 数据集和LOG文件的可视化接口 +可视化接口对应的代码在`paddle3d.apis.infer`中,提供了一种调用示例 + +``` +cd demo/visualization_demo 
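+# 提示:dataset_vis_demo.py 默认读取 ../datasets/KITTI/training 下的KITTI数据,可在脚本内将 mode 设为 'bev'、'image' 或 'pcd'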
+python dataset_vis_demo.py +``` + +--- +如果遇到如下问题,可参考Ref1和Ref2的解决方案: + +`qt.qpa.plugin: Could not load the Qt Platform plugin 'xcb' in ..` + +[Ref1](https://blog.csdn.net/qq_39938666/article/details/120452028?spm=1001.2101.3001.6650.2&utm_medium=distribute.pc_relevant.none-task-blog-2%7Edefault%7ECTRLIST%7ERate-2-120452028-blog-112303826.pc_relevant_3mothn_strategy_recovery&depth_1-utm_source=distribute.pc_relevant.none-task-blog-2%7Edefault%7ECTRLIST%7ERate-2-120452028-blog-112303826.pc_relevant_3mothn_strategy_recovery&utm_relevant_index=3) +& [Ref2](https://blog.csdn.net/weixin_41794514/article/details/128578166?spm=1001.2101.3001.6650.3&utm_medium=distribute.pc_relevant.none-task-blog-2%7Edefault%7EYuanLiJiHua%7EPosition-3-128578166-blog-119480436.pc_relevant_landingrelevant&depth_1-utm_source=distribute.pc_relevant.none-task-blog-2%7Edefault%7EYuanLiJiHua%7EPosition-3-128578166-blog-119480436.pc_relevant_landingrelevant) diff --git a/docs/visualization/visualization_en.md b/docs/visualization/visualization_en.md new file mode 100644 index 00000000..12f9ba09 --- /dev/null +++ b/docs/visualization/visualization_en.md @@ -0,0 +1,159 @@ +## 3D Visualization Examples of LiDAR Point Cloud/BEV and Camera Images +### Environment Setup +Follow the [official documentation](https://github.com/PaddlePaddle/Paddle3D/blob/develop/docs/installation.md) to install the dependencies for Paddle3D. Then install mayavi for laser point cloud visualization using the following commands: +``` +pip install vtk==8.1.2 +pip install mayavi==4.7.4 +pip install PyQt5 +``` +### 3D Visualization Example of Camera Images +The 3D visualization files for camera images are located in `demo/visualization_demo/`. The folder provides two examples: `mono_vis_single_frame_demo.py` for visualizing a single frame and `mono_vis_multi_frame_demo.py` for visualizing multiple frames. Both examples use the same visualization interface, which is defined in `paddle3d.apis.infer`. + +The implementation methods of `mono_vis_single_frame_demo.py` and `mono_vis_multi_frame_demo.py` are different to provide more visualization options. In `mono_vis_single_frame_demo.py`, visualization is achieved using the inference deployment with Paddle, while `mono_vis_multi_frame_demo.py` performs sequential frame reading and inference by constructing a dataloader for visualization. + +To use `mono_vis_single_frame_demo.py`, use the following command: +``` +cd demo/visualization_demo +python mono_vis_single_frame_demo.py \ + --model_file model/smoke.pdmodel \ + --params_file model/smoke.pdiparams \ + --image data/image_2/000008.png +``` +`--model_file` and `--params_file` are the paths to the model parameter files being used. + +`--image` is the path of the input image. + +To use `mono_vis_multi_frame_demo.py`, use the following command: + +``` +python mono_vis_multi_frame_demo.py \ + --config configs/smoke/smoke_dla34_no_dcn_kitti.yml \ + --model demo/smoke.pdparams \ + --batch_size 1 +``` + +`--config` is the path to the model configuration file. + +`--model` is the path to the model parameter file being used. + +`--batch_size` is the batch size for inference. + +The final output of the monocular visualization is shown below: + + +### 3D Visualization Example of LiDAR Point Cloud +The 3D visualization files for LiDAR point clouds are located in `demo/visualization_demo/`. The folder provides two examples: `pcd_vis_single_frame_demo.py` for visualizing a single frame and `pcd_vis_multi_frame_demo.py` for visualizing multiple frames. 
Both examples use the same visualization interface, which is defined in `paddle3d.apis.infer`. + +The implementation methods of `pcd_vis_single_frame_demo.py` and `pcd_vis_multi_frame_demo.py` are different to provide more visualization options. In `pcd_vis_single_frame_demo.py`, visualization is achieved using the inference deployment with Paddle, while `pcd_vis_multi_frame_demo.py` performs sequential frame reading and inference by constructing a dataloader for visualization. + +To use `pcd_vis_single_frame_demo.py`, use the following command: + +``` +cd demo/visualization_demo +python pcd_vis_single_frame_demo.py \ + --model_file model/pointpillars.pdmodel \ + --params_file model/pointpillars.pdiparams \ + --lidar_file data/velodyne/000008.bin \ + --calib_file data/calib/000008.txt \ + --point_cloud_range 0 -39.68 -3 69.12 39.68 1 \ + --voxel_size .16 .16 4 \ + --max_points_in_voxel 32 \ + --max_voxel_num 40000 +``` +`--model_file` and `--params_file` are the paths to the model parameter files being used. + +`--lidar_file` and `--calib_file` are the paths to the LiDAR point cloud and its corresponding calibration file. + +`--point_cloud_range` represents the range of the LiDAR point cloud in `(x, y, z)` coordinates. + +`--voxel_size` represents the size of the voxels used in processing. + +`--max_points_in_voxel` is the maximum number of LiDAR points in each voxel. + +`--max_voxel_num` is the maximum number of voxels. + + +To use `pcd_vis_multi_frame_demo.py`, use the following command: + +``` +python pcd_vis_multi_frame_demo.py \ + --config configs/pointpillars/pointpillars_xyres16_kitti_car.yml \ + --model demo/pointpillars.pdparams \ + --batch_size 1 +``` + +`--config` is the path to the model configuration file. + +`--model` is the path to the model parameter file being used. + +`--batch_size` is the batch size for inference. + +The final output of the lidar point cloud visualization is shown below: + + +### 3D Visualization Example of BEV +The 3D visualization files for LiDAR BEV are located in `demo/visualization_demo/`. The folder provides two examples: `bev_vis_single_frame_demo.py` for visualizing a single frame and `bev_vis_multi_frame_demo.py` for visualizing multiple frames. Both examples use the same visualization interface, which is defined in `paddle3d.apis.infer`. + +The implementation methods of `bev_vis_single_frame_demo.py` and `bev_vis_multi_frame_demo.py` are different to provide more visualization options. In `bev_vis_single_frame_demo.py`, visualization is achieved using the inference deployment with Paddle, while `bev_vis_multi_frame_demo.py` performs sequential frame reading and inference by constructing a dataloader for visualization. + +To use `bev_vis_single_frame_demo.py`, use the following command: + +``` +cd demo/visualization_demo +python bev_vis_single_frame_demo.py \ + --model_file model/pointpillars.pdmodel \ + --params_file model/pointpillars.pdiparams \ + --lidar_file data/velodyne/000008.bin \ + --calib_file data/calib/000008.txt \ + --point_cloud_range 0 -39.68 -3 69.12 39.68 1 \ + --voxel_size .16 .16 4 \ + --max_points_in_voxel 32 \ + --max_voxel_num 40000 +``` +`--model_file` and `--params_file` are the paths to the model parameter files being used. + +`--lidar_file` and `--calib_file` are the paths to the LiDAR point cloud and its corresponding calibration file. + +`--point_cloud_range` represents the range of the LiDAR point cloud in `(x, y, z)` coordinates. + +`--voxel_size` represents the size of the voxels used in processing. 
+ +`--max_points_in_voxel` is the maximum number of LiDAR points in each voxel. + +`--max_voxel_num` is the maximum number of voxels. + + +To use `bev_vis_multi_frame_demo.py`, use the following command: + +``` +python bev_vis_multi_frame_demo.py \ + --config configs/pointpillars/pointpillars_xyres16_kitti_car.yml \ + --model demo/pointpillars.pdparams \ + --batch_size 1 +``` + +`--config` is the path to the model configuration file. + +`--model` is the path to the model parameter file being used. + +`--batch_size` is the batch size for inference. + +The final output of the lidar BEV visualization is shown below: + + + +### Visualization interface for datasets and LOG files +The code of the visualization interface is defined in `paddle3d.apis.infer`, we provide an example as follows: + +``` +cd demo/visualization_demo +python dataset_vis_demo.py +``` + +--- +If you encounter the following problems, refer to Ref1 and Ref2 solutions: + +`qt.qpa.plugin: Could not load the Qt Platform plugin 'xcb' in ..` + +[Ref1](https://blog.csdn.net/qq_39938666/article/details/120452028?spm=1001.2101.3001.6650.2&utm_medium=distribute.pc_relevant.none-task-blog-2%7Edefault%7ECTRLIST%7ERate-2-120452028-blog-112303826.pc_relevant_3mothn_strategy_recovery&depth_1-utm_source=distribute.pc_relevant.none-task-blog-2%7Edefault%7ECTRLIST%7ERate-2-120452028-blog-112303826.pc_relevant_3mothn_strategy_recovery&utm_relevant_index=3) +& [Ref2](https://blog.csdn.net/weixin_41794514/article/details/128578166?spm=1001.2101.3001.6650.3&utm_medium=distribute.pc_relevant.none-task-blog-2%7Edefault%7EYuanLiJiHua%7EPosition-3-128578166-blog-119480436.pc_relevant_landingrelevant&depth_1-utm_source=distribute.pc_relevant.none-task-blog-2%7Edefault%7EYuanLiJiHua%7EPosition-3-128578166-blog-119480436.pc_relevant_landingrelevant) diff --git a/paddle3d/apis/infer.py b/paddle3d/apis/infer.py new file mode 100644 index 00000000..fc9a0472 --- /dev/null +++ b/paddle3d/apis/infer.py @@ -0,0 +1,124 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
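+# `Infer` below drives the multi-frame visualization demos: it reuses the
+# Trainer setup, runs `validation_step` over the evaluation dataloader, and
+# hands each prediction to a point-cloud ('pcd'), camera-image ('image') or
+# bird's-eye-view ('bev') visualizer from the demo utilities.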
+
+import copy
+import os
+import cv2
+import sys
+import numpy as np
+from collections import defaultdict
+from typing import Callable, Optional, Union
+
+import paddle
+from visualdl import LogWriter
+
+import paddle3d.env as env
+from paddle3d.apis.checkpoint import Checkpoint, CheckpointABC
+from paddle3d.apis.pipeline import training_step, validation_step
+from paddle3d.apis.scheduler import Scheduler, SchedulerABC
+from paddle3d.utils.logger import logger
+from paddle3d.utils.shm_utils import _get_shared_memory_size_in_M
+from paddle3d.utils.timer import Timer
+
+from paddle3d.datasets.kitti.kitti_utils import camera_record_to_object
+
+from paddle3d.apis.trainer import Trainer
+
+from demo.utils import Calibration, show_lidar_with_boxes, total_imgpred_by_conf_to_kitti_records, \
+    make_imgpts_list, draw_mono_3d, show_bev_with_boxes
+
+
+class Infer(Trainer):
+    """Run a trained model on the validation dataset and visualize the predictions.
+    """
+
+    def __init__(
+            self,
+            model: paddle.nn.Layer,
+            optimizer: paddle.optimizer.Optimizer,
+            iters: Optional[int] = None,
+            epochs: Optional[int] = None,
+            train_dataset: Optional[paddle.io.Dataset] = None,
+            val_dataset: Optional[paddle.io.Dataset] = None,
+            resume: bool = False,
+            # TODO: Default parameters should not use mutable objects, there is a risk
+            checkpoint: Union[dict, CheckpointABC] = dict(),
+            scheduler: Union[dict, SchedulerABC] = dict(),
+            dataloader_fn: Union[dict, Callable] = dict(),
+            amp_cfg: Optional[dict] = None):
+        super(Infer, self).__init__(
+            model, optimizer, iters, epochs, train_dataset, val_dataset, resume,
+            checkpoint, scheduler, dataloader_fn, amp_cfg)
+
+    def infer(self, mode: str) -> None:
+        """Run inference on the validation dataset and visualize every
+        prediction in the given mode ('pcd', 'image' or 'bev').
+        """
+        sync_bn = (getattr(self.model, 'sync_bn', False) and env.nranks > 1)
+        if sync_bn:
+            sparse_conv = False
+            for layer in self.model.sublayers():
+                if 'sparse' in str(type(layer)):
+                    sparse_conv = True
+                    break
+            if sparse_conv:
+                self.model = paddle.sparse.nn.SyncBatchNorm.convert_sync_batchnorm(
+                    self.model)
+            else:
+                self.model = paddle.nn.SyncBatchNorm.convert_sync_batchnorm(
+                    self.model)
+
+        if self.val_dataset is None:
+            raise RuntimeError('No evaluation dataset specified!')
+        msg = 'evaluate on validation dataset'
+        metric_obj = self.val_dataset.metric
+
+        for idx, sample in logger.enumerate(self.eval_dataloader, msg=msg):
+            results = validation_step(self.model, sample)
+
+            if mode == 'pcd':
+                for result in results:
+                    scan = np.fromfile(result['path'], dtype=np.float32)
+                    pc_velo = scan.reshape((-1, 4))
+                    # Obtain the KITTI calibration for this frame
+                    calib = Calibration(result['path'].replace(
+                        'velodyne', 'calib').replace('bin', 'txt'))
+                    # Plot boxes in the LiDAR point cloud
+                    show_lidar_with_boxes(pc_velo, result['bboxes_3d'],
+                                          result['confidences'], calib)
+
+            if mode == 'image':
+                for result in results:
+                    kitti_records = total_imgpred_by_conf_to_kitti_records(
+                        result, 0.3)
+                    bboxes_2d, bboxes_3d, labels = camera_record_to_object(
+                        kitti_records)
+                    # read the original image
+                    img_origin = cv2.imread(result['path'])
+                    # project each 3D box to its 8 corner points on the image
+                    K = np.array(result['meta']['camera_intrinsic'])
+                    imgpts_list = make_imgpts_list(bboxes_3d, K)
+                    # draw the SMOKE detections on the image
+                    draw_mono_3d(img_origin, imgpts_list)
+
+            if mode == 'bev':
+                for result in results:
+                    scan = np.fromfile(result['path'], dtype=np.float32)
+                    pc_velo = scan.reshape((-1, 4))
+                    # Obtain the KITTI calibration for this frame
+                    calib = Calibration(result['path'].replace(
+                        'velodyne', 'calib').replace('bin', 'txt'))
+                    # Plot boxes in the LiDAR bird's-eye view (BEV)
+                    show_bev_with_boxes(pc_velo, result['bboxes_3d'],
+                                        result['confidences'], calib)
diff --git a/requirements.txt b/requirements.txt
index 99c41a91..a78e016d 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -16,3 +16,7 @@ scikit-image
 scikit-learn
 visualdl
 h5py
+einops
+vtk==8.1.2
+mayavi==4.7.4
+PyQt5
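As a quick, display-free sanity check of the BEV rasterization used by `show_bev_with_boxes`, the sketch below calls the `pts2bev` helper from this patch directly and writes the raster to disk. The import path and the sample `.bin` path are assumptions for illustration; adjust them to the actual demo layout.

```python
import cv2
import numpy as np

# Assumed module path for the BEV helpers added in this patch; adjust if needed.
from demo.visualization_demo.vis_utils import pts2bev

# Hypothetical KITTI-style scan: an Nx4 float32 buffer of x, y, z, reflectance.
scan = np.fromfile('data/velodyne/000008.bin', dtype=np.float32)
pc_velo = scan.reshape((-1, 4))

# pts2bev keeps points with x in [0, 75] m, y in [-75, 75] m, z in [-2, 5] m
# and rasterizes them at 0.25 m per pixel into a 3-channel uint8 image.
bev = pts2bev(pc_velo)
cv2.imwrite('bev_000008.png', bev)  # save instead of cv2.imshow on headless machines
```

Writing the image to disk avoids the Qt `xcb` plugin issue mentioned in the docs when running over SSH or in a container without a display.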