
[Bug] Mask2Former conversion to TensorRT fails when setting batch size > 1 #2920


Description

@yzhou-bcom

Checklist

  • I have searched related issues but cannot get the expected help.
  • I have read the FAQ documentation but cannot get the expected help.
  • The bug has not been fixed in the latest version.

Describe the bug

Hello, I tried to convert a Mask2Former PyTorch model to a TensorRT engine file with dynamic input size using the config panoptic-seg_maskformer_tensorrt_dynamic-320x512-1344x1344.py, and it worked. I then tried to enable a dynamic batch size as well, by modifying the backend_config in that config file as follows:

backend_config = dict(model_inputs=[
    dict(
        input_shapes=dict(
            input=dict(
                min_shape=[1, 3, 320, 512],
                opt_shape=[2, 3, 800, 1344],
                max_shape=[4, 3, 1344, 1344])))
])

The conversion crashed with the error message "[shapeCompiler.cpp::evaluateShapeChecks::1276] Error Code 4: Internal Error (kOPT values for profile 0 violate shape constraints: IShuffleLayer /backbone/stages.0/blocks.0/attn/Reshape_7: reshaping failed for tensor: /backbone/stages.0/blocks.0/attn/Transpose_3_output_0 reshape would change volume 13095936 to 6547968)".
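(Note that 13095936 is exactly twice 6547968, i.e. the opt batch size of 2 against a reshape target half that volume, which suggests the batch size of 1 used during tracing was baked into the graph at the Swin window-attention reshape.)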

Is there a way to fix this? I would really like a model that can take multiple images as input.
Many thanks.
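
As a quick sanity check, here is a minimal sketch (the end2end.onnx path is illustrative, adjust it to the work dir used above) of how one could inspect whether the batch dimension survived the ONNX export as dynamic, and which Reshape nodes carry a constant target shape that would pin the traced batch size:

    import onnx
    from onnx import numpy_helper

    # Hypothetical path to the file produced by tools/deploy.py; adjust as needed.
    model = onnx.load("end2end.onnx")

    # Graph inputs: a symbolic dim_param (e.g. "batch") means dynamic,
    # a concrete dim_value of 1 means the batch dimension is fixed.
    for inp in model.graph.input:
        dims = [d.dim_param if d.dim_param else d.dim_value
                for d in inp.type.tensor_type.shape.dim]
        print("input", inp.name, dims)

    # Reshape nodes whose target shape comes from a constant initializer:
    # these are candidates that bake the traced batch size into the graph
    # (as with Reshape_7 mentioned in the error message).
    inits = {i.name: numpy_helper.to_array(i) for i in model.graph.initializer}
    for node in model.graph.node:
        if node.op_type == "Reshape" and len(node.input) > 1 and node.input[1] in inits:
            print("Reshape with constant shape:", node.name, inits[node.input[1]].tolist())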

Reproduction

python ~/mmdeploy/tools/deploy.py \
    ~/mmdeploy/configs/mmdet/panoptic-seg/panoptic-seg_maskformer_tensorrt_dynamic-320x512-1344x1344.py \
    mask2former_swin-s-p4-w7-224_8xb2-lsj-50e_coco.py \
    mask2former_swin-s-p4-w7-224_8xb2-lsj-50e_coco_20220504_001756-c9d0c4f2.pth \
    ~/mmdeploy/demo/ressources/det.jpg \
    --work-dir ./mask2former_swin-s-p4-w7-224_8xb2-lsj-50e_coco_deploy/ \
    --device cuda --show --dump-info

Environment

10/02 11:17:33 - mmengine - INFO - **********Environmental information**********
/home/yzhou/miniconda3/envs/openmmlab/lib/python3.8/site-packages/mmengine/optim/optimizer/zero_optimizer.py:11: DeprecationWarning: `TorchScript` support for functional optimizers is deprecated and will be removed in a future PyTorch release. Consider using the `torch.compile` optimizer instead.
  from torch.distributed.optim import \
10/02 11:17:35 - mmengine - INFO - sys.platform: linux
10/02 11:17:35 - mmengine - INFO - Python: 3.8.20 (default, Oct  3 2024, 15:24:27) [GCC 11.2.0]
10/02 11:17:35 - mmengine - INFO - CUDA available: True
10/02 11:17:35 - mmengine - INFO - MUSA available: False
10/02 11:17:35 - mmengine - INFO - numpy_random_seed: 2147483648
10/02 11:17:35 - mmengine - INFO - GPU 0: Quadro RTX 3000
10/02 11:17:35 - mmengine - INFO - CUDA_HOME: /usr/local/cuda-11.8/
10/02 11:17:35 - mmengine - INFO - NVCC: Cuda compilation tools, release 11.8, V11.8.89
10/02 11:17:35 - mmengine - INFO - GCC: gcc (Ubuntu 9.4.0-1ubuntu1~20.04.2) 9.4.0
10/02 11:17:35 - mmengine - INFO - PyTorch: 2.4.1+cu118
10/02 11:17:35 - mmengine - INFO - PyTorch compiling details: PyTorch built with:
  - GCC 9.3
  - C++ Version: 201703
  - Intel(R) oneAPI Math Kernel Library Version 2022.2-Product Build 20220804 for Intel(R) 64 architecture applications
  - Intel(R) MKL-DNN v3.4.2 (Git Hash 1137e04ec0b5251ca2b4400a4fd3c667ce843d67)
  - OpenMP 201511 (a.k.a. OpenMP 4.5)
  - LAPACK is enabled (usually provided by MKL)
  - NNPACK is enabled
  - CPU capability usage: AVX2
  - CUDA Runtime 11.8
  - NVCC architecture flags: -gencode;arch=compute_50,code=sm_50;-gencode;arch=compute_60,code=sm_60;-gencode;arch=compute_70,code=sm_70;-gencode;arch=compute_75,code=sm_75;-gencode;arch=compute_80,code=sm_80;-gencode;arch=compute_86,code=sm_86;-gencode;arch=compute_37,code=sm_37;-gencode;arch=compute_90,code=sm_90
  - CuDNN 90.1
  - Magma 2.6.1
  - Build settings: BLAS_INFO=mkl, BUILD_TYPE=Release, CUDA_VERSION=11.8, CUDNN_VERSION=9.1.0, CXX_COMPILER=/opt/rh/devtoolset-9/root/usr/bin/c++, CXX_FLAGS= -D_GLIBCXX_USE_CXX11_ABI=0 -fabi-version=11 -fvisibility-inlines-hidden -DUSE_PTHREADPOOL -DNDEBUG -DUSE_KINETO -DLIBKINETO_NOROCTRACER -DUSE_FBGEMM -DUSE_PYTORCH_QNNPACK -DUSE_XNNPACK -DSYMBOLICATE_MOBILE_DEBUG_HANDLE -O2 -fPIC -Wall -Wextra -Werror=return-type -Werror=non-virtual-dtor -Werror=bool-operation -Wnarrowing -Wno-missing-field-initializers -Wno-type-limits -Wno-array-bounds -Wno-unknown-pragmas -Wno-unused-parameter -Wno-unused-function -Wno-unused-result -Wno-strict-overflow -Wno-strict-aliasing -Wno-stringop-overflow -Wsuggest-override -Wno-psabi -Wno-error=pedantic -Wno-error=old-style-cast -Wno-missing-braces -fdiagnostics-color=always -faligned-new -Wno-unused-but-set-variable -Wno-maybe-uninitialized -fno-math-errno -fno-trapping-math -Werror=format -Wno-stringop-overflow, LAPACK_INFO=mkl, PERF_WITH_AVX=1, PERF_WITH_AVX2=1, PERF_WITH_AVX512=1, TORCH_VERSION=2.4.1, USE_CUDA=ON, USE_CUDNN=ON, USE_CUSPARSELT=1, USE_EXCEPTION_PTR=1, USE_GFLAGS=OFF, USE_GLOG=OFF, USE_GLOO=ON, USE_MKL=ON, USE_MKLDNN=ON, USE_MPI=OFF, USE_NCCL=1, USE_NNPACK=ON, USE_OPENMP=ON, USE_ROCM=OFF, USE_ROCM_KERNEL_ASSERT=OFF,

10/02 11:17:35 - mmengine - INFO - TorchVision: 0.19.1+cu118
10/02 11:17:35 - mmengine - INFO - OpenCV: 4.12.0
10/02 11:17:35 - mmengine - INFO - MMEngine: 0.10.7
10/02 11:17:35 - mmengine - INFO - MMCV: 2.1.0
10/02 11:17:35 - mmengine - INFO - MMCV Compiler: GCC 9.4
10/02 11:17:35 - mmengine - INFO - MMCV CUDA Compiler: 11.8
10/02 11:17:35 - mmengine - INFO - MMDeploy: 1.3.1+3f8604b
10/02 11:17:35 - mmengine - INFO -

10/02 11:17:35 - mmengine - INFO - **********Backend information**********
10/02 11:17:35 - mmengine - INFO - tensorrt:    8.6.1
10/02 11:17:35 - mmengine - INFO - tensorrt custom ops: Available
10/02 11:17:35 - mmengine - INFO - ONNXRuntime: 1.8.1
10/02 11:17:35 - mmengine - INFO - ONNXRuntime-gpu:     None
10/02 11:17:35 - mmengine - INFO - ONNXRuntime custom ops:      Available
10/02 11:17:35 - mmengine - INFO - pplnn:       None
10/02 11:17:35 - mmengine - INFO - ncnn:        None
10/02 11:17:35 - mmengine - INFO - snpe:        None
10/02 11:17:35 - mmengine - INFO - openvino:    None
10/02 11:17:35 - mmengine - INFO - torchscript: 2.4.1+cu118
10/02 11:17:35 - mmengine - INFO - torchscript custom ops:      NotAvailable
10/02 11:17:35 - mmengine - INFO - rknn-toolkit:        None
10/02 11:17:35 - mmengine - INFO - rknn-toolkit2:       None
10/02 11:17:35 - mmengine - INFO - ascend:      None
10/02 11:17:35 - mmengine - INFO - coreml:      None
10/02 11:17:35 - mmengine - INFO - tvm: None
10/02 11:17:35 - mmengine - INFO - vacc:        None
10/02 11:17:35 - mmengine - INFO -

10/02 11:17:35 - mmengine - INFO - **********Codebase information**********
10/02 11:17:35 - mmengine - INFO - mmdet:       3.3.0
10/02 11:17:35 - mmengine - INFO - mmseg:       1.2.2
10/02 11:17:35 - mmengine - INFO - mmpretrain:  None
10/02 11:17:35 - mmengine - INFO - mmocr:       None
10/02 11:17:35 - mmengine - INFO - mmagic:      None
10/02 11:17:35 - mmengine - INFO - mmdet3d:     None
10/02 11:17:35 - mmengine - INFO - mmpose:      None
10/02 11:17:35 - mmengine - INFO - mmrotate:    None
10/02 11:17:35 - mmengine - INFO - mmaction:    None
10/02 11:17:35 - mmengine - INFO - mmrazor:     None
10/02 11:17:35 - mmengine - INFO - mmyolo:      None

Error traceback

from torch.distributed.optim import \
10/02 11:07:47 - mmengine - WARNING - Failed to search registry with scope "mmdet" in the "Codebases" registry tree. As a workaround, the current "Codebases" registry in "mmdeploy" is used to build instance. This may cause unexpected failure when running the built modules. Please check whether "mmdet" is a correct scope, or whether the registry is initialized.
10/02 11:07:47 - mmengine - WARNING - Failed to search registry with scope "mmdet" in the "mmdet_tasks" registry tree. As a workaround, the current "mmdet_tasks" registry in "mmdeploy" is used to build instance. This may cause unexpected failure when running the built modules. Please check whether "mmdet" is a correct scope, or whether the registry is initialized.
/home/yzhou/miniconda3/envs/openmmlab/lib/python3.8/site-packages/mmengine/optim/optimizer/zero_optimizer.py:11: DeprecationWarning: `TorchScript` support for functional optimizers is deprecated and will be removed in a future PyTorch release. Consider using the `torch.compile` optimizer instead.
  from torch.distributed.optim import \
10/02 11:07:50 - mmengine - INFO - Start pipeline mmdeploy.apis.pytorch2onnx.torch2onnx in subprocess
10/02 11:07:51 - mmengine - WARNING - Failed to search registry with scope "mmdet" in the "Codebases" registry tree. As a workaround, the current "Codebases" registry in "mmdeploy" is used to build instance. This may cause unexpected failure when running the built modules. Please check whether "mmdet" is a correct scope, or whether the registry is initialized.
10/02 11:07:51 - mmengine - WARNING - Failed to search registry with scope "mmdet" in the "mmdet_tasks" registry tree. As a workaround, the current "mmdet_tasks" registry in "mmdeploy" is used to build instance. This may cause unexpected failure when running the built modules. Please check whether "mmdet" is a correct scope, or whether the registry is initialized.
Loads checkpoint by local backend from path: /mnt/d/SolAR/Data/OpenMMLab/pretrained_models/mask2former/instance/mask2former_swin-s-p4-w7-224_8xb2-lsj-50e_coco_20220504_001756-c9d0c4f2.pth
/home/yzhou/miniconda3/envs/openmmlab/lib/python3.8/site-packages/mmengine/runner/checkpoint.py:347: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.
  checkpoint = torch.load(filename, map_location=map_location)
10/02 11:07:53 - mmengine - WARNING - DeprecationWarning: get_onnx_config will be deprecated in the future.
10/02 11:07:53 - mmengine - INFO - Export PyTorch model to ONNX: /mnt/d/SolAR/Data/OpenMMLab/work_dirs/mask2former_swin-s-p4-w7-224_8xb2-lsj-50e_coco_deploy/end2end.onnx.
10/02 11:07:53 - mmengine - WARNING - Can not find torch.nn.functional._scaled_dot_product_attention, function rewrite will not be applied
10/02 11:07:53 - mmengine - WARNING - Can not find mmdet.models.utils.transformer.PatchMerging.forward, function rewrite will not be applied
/home/yzhou/mmdetection/mmdet/models/layers/transformer/utils.py:167: TracerWarning: Converting a tensor to a Python float might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
  output_h = math.ceil(input_h / stride_h)
/home/yzhou/mmdetection/mmdet/models/layers/transformer/utils.py:168: TracerWarning: Converting a tensor to a Python float might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
  output_w = math.ceil(input_w / stride_w)
/home/yzhou/mmdetection/mmdet/models/layers/transformer/utils.py:169: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
  pad_h = max((output_h - 1) * stride_h +
/home/yzhou/mmdetection/mmdet/models/layers/transformer/utils.py:171: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
  pad_w = max((output_w - 1) * stride_w +
/home/yzhou/mmdetection/mmdet/models/layers/transformer/utils.py:177: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
  if pad_h > 0 or pad_w > 0:
/home/yzhou/mmdeploy/mmdeploy/codebase/mmdet/models/backbones.py:189: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
  assert L == H * W, 'input feature has wrong size'
/home/yzhou/mmdeploy/mmdeploy/codebase/mmdet/models/backbones.py:147: TracerWarning: Converting a tensor to a Python integer might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
  B = int(windows.shape[0] / (H * W / window_size / window_size))
/home/yzhou/mmdetection/mmdet/models/layers/transformer/utils.py:414: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
  assert L == H * W, 'input feature has wrong size'
/home/yzhou/miniconda3/envs/openmmlab/lib/python3.8/site-packages/torch/functional.py:513: UserWarning: torch.meshgrid: in an upcoming release, it will be required to pass the indexing argument. (Triggered internally at ../aten/src/ATen/native/TensorShape.cpp:3609.)
  return _VF.meshgrid(tensors, **kwargs)  # type: ignore[attr-defined]
/home/yzhou/miniconda3/envs/openmmlab/lib/python3.8/site-packages/mmcv/ops/multi_scale_deform_attn.py:335: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
  assert (spatial_shapes[:, 0] * spatial_shapes[:, 1]).sum() == num_value
/home/yzhou/miniconda3/envs/openmmlab/lib/python3.8/site-packages/mmcv/ops/multi_scale_deform_attn.py:351: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
  if reference_points.shape[-1] == 2:
/home/yzhou/miniconda3/envs/openmmlab/lib/python3.8/site-packages/mmcv/ops/multi_scale_deform_attn.py:136: TracerWarning: Iterating over a tensor might cause the trace to be incorrect. Passing a tensor of different shape won't change the number of iterations executed (and might lead to errors or silently give incorrect results).
  value_list = value.split([H_ * W_ for H_, W_ in value_spatial_shapes],
/home/yzhou/miniconda3/envs/openmmlab/lib/python3.8/site-packages/mmcv/ops/multi_scale_deform_attn.py:140: TracerWarning: Iterating over a tensor might cause the trace to be incorrect. Passing a tensor of different shape won't change the number of iterations executed (and might lead to errors or silently give incorrect results).
  for level, (H_, W_) in enumerate(value_spatial_shapes):
[W1002 11:08:03.719314495 shape_type_inference.cpp:1998] Warning: The shape inference of mmdeploy::grid_sampler type is missing, so it may result in wrong shape inference for the exported graph. Please consider adding it in symbolic function. (function UpdateReliable)
[W1002 11:08:03.730073130 shape_type_inference.cpp:1998] Warning: The shape inference of mmdeploy::grid_sampler type is missing, so it may result in wrong shape inference for the exported graph. Please consider adding it in symbolic function. (function UpdateReliable)
[W1002 11:08:03.741138972 shape_type_inference.cpp:1998] Warning: The shape inference of mmdeploy::grid_sampler type is missing, so it may result in wrong shape inference for the exported graph. Please consider adding it in symbolic function. (function UpdateReliable)
[W1002 11:08:03.769611594 shape_type_inference.cpp:1998] Warning: The shape inference of mmdeploy::grid_sampler type is missing, so it may result in wrong shape inference for the exported graph. Please consider adding it in symbolic function. (function UpdateReliable)
[W1002 11:08:03.782496876 shape_type_inference.cpp:1998] Warning: The shape inference of mmdeploy::grid_sampler type is missing, so it may result in wrong shape inference for the exported graph. Please consider adding it in symbolic function. (function UpdateReliable)
[W1002 11:08:03.793162709 shape_type_inference.cpp:1998] Warning: The shape inference of mmdeploy::grid_sampler type is missing, so it may result in wrong shape inference for the exported graph. Please consider adding it in symbolic function. (function UpdateReliable)
[W1002 11:08:03.822239844 shape_type_inference.cpp:1998] Warning: The shape inference of mmdeploy::grid_sampler type is missing, so it may result in wrong shape inference for the exported graph. Please consider adding it in symbolic function. (function UpdateReliable)
[W1002 11:08:03.833910000 shape_type_inference.cpp:1998] Warning: The shape inference of mmdeploy::grid_sampler type is missing, so it may result in wrong shape inference for the exported graph. Please consider adding it in symbolic function. (function UpdateReliable)
[W1002 11:08:03.844908540 shape_type_inference.cpp:1998] Warning: The shape inference of mmdeploy::grid_sampler type is missing, so it may result in wrong shape inference for the exported graph. Please consider adding it in symbolic function. (function UpdateReliable)
[W1002 11:08:03.875648612 shape_type_inference.cpp:1998] Warning: The shape inference of mmdeploy::grid_sampler type is missing, so it may result in wrong shape inference for the exported graph. Please consider adding it in symbolic function. (function UpdateReliable)
[W1002 11:08:03.889014104 shape_type_inference.cpp:1998] Warning: The shape inference of mmdeploy::grid_sampler type is missing, so it may result in wrong shape inference for the exported graph. Please consider adding it in symbolic function. (function UpdateReliable)
[W1002 11:08:03.901121269 shape_type_inference.cpp:1998] Warning: The shape inference of mmdeploy::grid_sampler type is missing, so it may result in wrong shape inference for the exported graph. Please consider adding it in symbolic function. (function UpdateReliable)
[W1002 11:08:03.930752316 shape_type_inference.cpp:1998] Warning: The shape inference of mmdeploy::grid_sampler type is missing, so it may result in wrong shape inference for the exported graph. Please consider adding it in symbolic function. (function UpdateReliable)
[W1002 11:08:03.943638898 shape_type_inference.cpp:1998] Warning: The shape inference of mmdeploy::grid_sampler type is missing, so it may result in wrong shape inference for the exported graph. Please consider adding it in symbolic function. (function UpdateReliable)
[W1002 11:08:03.956234273 shape_type_inference.cpp:1998] Warning: The shape inference of mmdeploy::grid_sampler type is missing, so it may result in wrong shape inference for the exported graph. Please consider adding it in symbolic function. (function UpdateReliable)
[W1002 11:08:03.987100148 shape_type_inference.cpp:1998] Warning: The shape inference of mmdeploy::grid_sampler type is missing, so it may result in wrong shape inference for the exported graph. Please consider adding it in symbolic function. (function UpdateReliable)
[W1002 11:08:03.998061187 shape_type_inference.cpp:1998] Warning: The shape inference of mmdeploy::grid_sampler type is missing, so it may result in wrong shape inference for the exported graph. Please consider adding it in symbolic function. (function UpdateReliable)
[W1002 11:08:03.009634240 shape_type_inference.cpp:1998] Warning: The shape inference of mmdeploy::grid_sampler type is missing, so it may result in wrong shape inference for the exported graph. Please consider adding it in symbolic function. (function UpdateReliable)
10/02 11:08:07 - mmengine - INFO - Execute onnx optimize passes.
10/02 11:08:09 - mmengine - INFO - Finish pipeline mmdeploy.apis.pytorch2onnx.torch2onnx
/home/yzhou/miniconda3/envs/openmmlab/lib/python3.8/site-packages/mmengine/optim/optimizer/zero_optimizer.py:11: DeprecationWarning: `TorchScript` support for functional optimizers is deprecated and will be removed in a future PyTorch release. Consider using the `torch.compile` optimizer instead.
  from torch.distributed.optim import \
10/02 11:08:12 - mmengine - INFO - Start pipeline mmdeploy.apis.utils.utils.to_backend in subprocess
10/02 11:08:12 - mmengine - INFO - Successfully loaded tensorrt plugins from /home/yzhou/mmdeploy/mmdeploy/lib/libmmdeploy_tensorrt_ops.so
[10/02/2025-11:08:12] [TRT] [I] [MemUsageChange] Init CUDA: CPU +14, GPU +0, now: CPU 137, GPU 1036 (MiB)
[10/02/2025-11:08:18] [TRT] [I] [MemUsageChange] Init builder kernel library: CPU +903, GPU +174, now: CPU 1117, GPU 1210 (MiB)
[10/02/2025-11:08:22] [TRT] [I] ----------------------------------------------------------------
[10/02/2025-11:08:22] [TRT] [I] Input filename:   /mnt/d/SolAR/Data/OpenMMLab/work_dirs/mask2former_swin-s-p4-w7-224_8xb2-lsj-50e_coco_deploy/end2end.onnx
[10/02/2025-11:08:22] [TRT] [I] ONNX IR version:  0.0.7
[10/02/2025-11:08:22] [TRT] [I] Opset version:    13
[10/02/2025-11:08:22] [TRT] [I] Producer name:    pytorch
[10/02/2025-11:08:22] [TRT] [I] Producer version: 2.4.1
[10/02/2025-11:08:22] [TRT] [I] Domain:
[10/02/2025-11:08:22] [TRT] [I] Model version:    0
[10/02/2025-11:08:22] [TRT] [I] Doc string:
[10/02/2025-11:08:22] [TRT] [I] ----------------------------------------------------------------
[10/02/2025-11:08:23] [TRT] [W] onnx2trt_utils.cpp:374: Your ONNX model has been generated with INT64 weights, while TensorRT does not natively support INT64. Attempting to cast down to INT32.
[10/02/2025-11:08:23] [TRT] [W] onnx2trt_utils.cpp:400: One or more weights outside the range of INT32 was clamped
[10/02/2025-11:08:23] [TRT] [I] No importer registered for op: TRTInstanceNormalization. Attempting to import as plugin.
[10/02/2025-11:08:23] [TRT] [I] Searching for plugin: TRTInstanceNormalization, plugin_version: 1, plugin_namespace:
[10/02/2025-11:08:23] [TRT] [I] Successfully created plugin: TRTInstanceNormalization
[10/02/2025-11:08:23] [TRT] [I] No importer registered for op: TRTInstanceNormalization. Attempting to import as plugin.
[10/02/2025-11:08:23] [TRT] [I] Searching for plugin: TRTInstanceNormalization, plugin_version: 1, plugin_namespace:
[10/02/2025-11:08:23] [TRT] [I] Successfully created plugin: TRTInstanceNormalization
[10/02/2025-11:08:23] [TRT] [I] No importer registered for op: TRTInstanceNormalization. Attempting to import as plugin.
[10/02/2025-11:08:23] [TRT] [I] Searching for plugin: TRTInstanceNormalization, plugin_version: 1, plugin_namespace:
[10/02/2025-11:08:23] [TRT] [I] Successfully created plugin: TRTInstanceNormalization
[10/02/2025-11:08:23] [TRT] [I] No importer registered for op: grid_sampler. Attempting to import as plugin.
[10/02/2025-11:08:23] [TRT] [I] Searching for plugin: grid_sampler, plugin_version: 1, plugin_namespace:
[10/02/2025-11:08:23] [TRT] [I] Successfully created plugin: grid_sampler
[10/02/2025-11:08:23] [TRT] [I] No importer registered for op: grid_sampler. Attempting to import as plugin.
[10/02/2025-11:08:23] [TRT] [I] Searching for plugin: grid_sampler, plugin_version: 1, plugin_namespace:
[10/02/2025-11:08:23] [TRT] [I] Successfully created plugin: grid_sampler
[10/02/2025-11:08:23] [TRT] [I] No importer registered for op: grid_sampler. Attempting to import as plugin.
[10/02/2025-11:08:23] [TRT] [I] Searching for plugin: grid_sampler, plugin_version: 1, plugin_namespace:
[10/02/2025-11:08:23] [TRT] [I] Successfully created plugin: grid_sampler
[10/02/2025-11:08:23] [TRT] [I] No importer registered for op: grid_sampler. Attempting to import as plugin.
[10/02/2025-11:08:23] [TRT] [I] Searching for plugin: grid_sampler, plugin_version: 1, plugin_namespace:
[10/02/2025-11:08:23] [TRT] [I] Successfully created plugin: grid_sampler
[10/02/2025-11:08:23] [TRT] [I] No importer registered for op: grid_sampler. Attempting to import as plugin.
[10/02/2025-11:08:23] [TRT] [I] Searching for plugin: grid_sampler, plugin_version: 1, plugin_namespace:
[10/02/2025-11:08:23] [TRT] [I] Successfully created plugin: grid_sampler
[10/02/2025-11:08:23] [TRT] [I] No importer registered for op: grid_sampler. Attempting to import as plugin.
[10/02/2025-11:08:23] [TRT] [I] Searching for plugin: grid_sampler, plugin_version: 1, plugin_namespace:
[10/02/2025-11:08:23] [TRT] [I] Successfully created plugin: grid_sampler
[10/02/2025-11:08:23] [TRT] [I] No importer registered for op: grid_sampler. Attempting to import as plugin.
[10/02/2025-11:08:23] [TRT] [I] Searching for plugin: grid_sampler, plugin_version: 1, plugin_namespace:
[10/02/2025-11:08:23] [TRT] [I] Successfully created plugin: grid_sampler
[10/02/2025-11:08:23] [TRT] [I] No importer registered for op: grid_sampler. Attempting to import as plugin.
[10/02/2025-11:08:23] [TRT] [I] Searching for plugin: grid_sampler, plugin_version: 1, plugin_namespace:
[10/02/2025-11:08:23] [TRT] [I] Successfully created plugin: grid_sampler
[10/02/2025-11:08:23] [TRT] [I] No importer registered for op: grid_sampler. Attempting to import as plugin.
[10/02/2025-11:08:23] [TRT] [I] Searching for plugin: grid_sampler, plugin_version: 1, plugin_namespace:
[10/02/2025-11:08:23] [TRT] [I] Successfully created plugin: grid_sampler
[10/02/2025-11:08:23] [TRT] [I] No importer registered for op: grid_sampler. Attempting to import as plugin.
[10/02/2025-11:08:23] [TRT] [I] Searching for plugin: grid_sampler, plugin_version: 1, plugin_namespace:
[10/02/2025-11:08:23] [TRT] [I] Successfully created plugin: grid_sampler
[10/02/2025-11:08:23] [TRT] [I] No importer registered for op: grid_sampler. Attempting to import as plugin.
[10/02/2025-11:08:23] [TRT] [I] Searching for plugin: grid_sampler, plugin_version: 1, plugin_namespace:
[10/02/2025-11:08:23] [TRT] [I] Successfully created plugin: grid_sampler
[10/02/2025-11:08:23] [TRT] [I] No importer registered for op: grid_sampler. Attempting to import as plugin.
[10/02/2025-11:08:23] [TRT] [I] Searching for plugin: grid_sampler, plugin_version: 1, plugin_namespace:
[10/02/2025-11:08:23] [TRT] [I] Successfully created plugin: grid_sampler
[10/02/2025-11:08:23] [TRT] [I] No importer registered for op: grid_sampler. Attempting to import as plugin.
[10/02/2025-11:08:23] [TRT] [I] Searching for plugin: grid_sampler, plugin_version: 1, plugin_namespace:
[10/02/2025-11:08:23] [TRT] [I] Successfully created plugin: grid_sampler
[10/02/2025-11:08:23] [TRT] [I] No importer registered for op: grid_sampler. Attempting to import as plugin.
[10/02/2025-11:08:23] [TRT] [I] Searching for plugin: grid_sampler, plugin_version: 1, plugin_namespace:
[10/02/2025-11:08:23] [TRT] [I] Successfully created plugin: grid_sampler
[10/02/2025-11:08:23] [TRT] [I] No importer registered for op: grid_sampler. Attempting to import as plugin.
[10/02/2025-11:08:23] [TRT] [I] Searching for plugin: grid_sampler, plugin_version: 1, plugin_namespace:
[10/02/2025-11:08:23] [TRT] [I] Successfully created plugin: grid_sampler
[10/02/2025-11:08:23] [TRT] [I] No importer registered for op: grid_sampler. Attempting to import as plugin.
[10/02/2025-11:08:23] [TRT] [I] Searching for plugin: grid_sampler, plugin_version: 1, plugin_namespace:
[10/02/2025-11:08:23] [TRT] [I] Successfully created plugin: grid_sampler
[10/02/2025-11:08:23] [TRT] [I] No importer registered for op: grid_sampler. Attempting to import as plugin.
[10/02/2025-11:08:23] [TRT] [I] Searching for plugin: grid_sampler, plugin_version: 1, plugin_namespace:
[10/02/2025-11:08:23] [TRT] [I] Successfully created plugin: grid_sampler
[10/02/2025-11:08:23] [TRT] [I] No importer registered for op: grid_sampler. Attempting to import as plugin.
[10/02/2025-11:08:23] [TRT] [I] Searching for plugin: grid_sampler, plugin_version: 1, plugin_namespace:
[10/02/2025-11:08:23] [TRT] [I] Successfully created plugin: grid_sampler
[10/02/2025-11:08:23] [TRT] [I] No importer registered for op: TRTInstanceNormalization. Attempting to import as plugin.
[10/02/2025-11:08:23] [TRT] [I] Searching for plugin: TRTInstanceNormalization, plugin_version: 1, plugin_namespace:
[10/02/2025-11:08:23] [TRT] [I] Successfully created plugin: TRTInstanceNormalization
[10/02/2025-11:08:23] [TRT] [I] No importer registered for op: TRTInstanceNormalization. Attempting to import as plugin.
[10/02/2025-11:08:23] [TRT] [I] Searching for plugin: TRTInstanceNormalization, plugin_version: 1, plugin_namespace:
[10/02/2025-11:08:23] [TRT] [I] Successfully created plugin: TRTInstanceNormalization
[10/02/2025-11:08:23] [TRT] [I] BuilderFlag::kTF32 is set but hardware does not support TF32. Disabling TF32.
[10/02/2025-11:08:24] [TRT] [I] Graph optimization time: 0.338246 seconds.
[10/02/2025-11:08:24] [TRT] [I] [MemUsageChange] Init cuBLAS/cuBLASLt: CPU +6, GPU +10, now: CPU 1462, GPU 1220 (MiB)
[10/02/2025-11:08:24] [TRT] [I] [MemUsageChange] Init cuDNN: CPU +1, GPU +8, now: CPU 1463, GPU 1228 (MiB)
[10/02/2025-11:08:24] [TRT] [I] BuilderFlag::kTF32 is set but hardware does not support TF32. Disabling TF32.
[10/02/2025-11:08:24] [TRT] [I] Local timing cache in use. Profiling results in this builder pass will not be stored.
[10/02/2025-11:08:24] [TRT] [E] 4: kOPT values for profile 0 violate shape constraints: IShuffleLayer /backbone/stages.0/blocks.0/attn/Reshape_7: reshaping failed for tensor: /backbone/stages.0/blocks.0/attn/Transpose_3_output_0 reshape would change volume 13095936 to 6547968
[10/02/2025-11:08:24] [TRT] [E] 4: [shapeCompiler.cpp::evaluateShapeChecks::1276] Error Code 4: Internal Error (kOPT values for profile 0 violate shape constraints: IShuffleLayer /backbone/stages.0/blocks.0/attn/Reshape_7: reshaping failed for tensor: /backbone/stages.0/blocks.0/attn/Transpose_3_output_0 reshape would change volume 13095936 to 6547968)
Process Process-3:
Traceback (most recent call last):
  File "/home/yzhou/miniconda3/envs/openmmlab/lib/python3.8/multiprocessing/process.py", line 315, in _bootstrap
    self.run()
  File "/home/yzhou/miniconda3/envs/openmmlab/lib/python3.8/multiprocessing/process.py", line 108, in run
    self._target(*self._args, **self._kwargs)
  File "/home/yzhou/mmdeploy/mmdeploy/apis/core/pipeline_manager.py", line 107, in __call__
    ret = func(*args, **kwargs)
  File "/home/yzhou/mmdeploy/mmdeploy/apis/utils/utils.py", line 98, in to_backend
    return backend_mgr.to_backend(
  File "/home/yzhou/mmdeploy/mmdeploy/backend/tensorrt/backend_manager.py", line 127, in to_backend
    onnx2tensorrt(
  File "/home/yzhou/mmdeploy/mmdeploy/backend/tensorrt/onnx2tensorrt.py", line 79, in onnx2tensorrt
    from_onnx(
  File "/home/yzhou/mmdeploy/mmdeploy/backend/tensorrt/utils.py", line 248, in from_onnx
    assert engine is not None, 'Failed to create TensorRT engine'
AssertionError: Failed to create TensorRT engine
10/02 11:08:25 - mmengine - ERROR - /home/yzhou/mmdeploy/mmdeploy/apis/core/pipeline_manager.py - pop_mp_output - 80 - `mmdeploy.apis.utils.utils.to_backend` with Call id: 1 failed. exit.
