Skip to content

"Internal Error: MyelinCheckException: gvn.cpp:318: CHECK(graph().ssa_validation()) failed." when building engine #4468

Open
@xjy1995

Description

@xjy1995

Description

I encountered the following error when trying to build a TensorRT engine from an ONNX model.

 110947: reshape: /Reshape_32 _ /Transpose_17_reshape_output.1-(f16[__mye111018_proxy.1,1,1025,2,801,2][]so[], mem_prop=0) | /Concat_47_output_0' castIn.1-(f16[__mye111018_proxy.1,1,2050,801,2][]so[], mem_prop=0), stream = 0 // /Reshape_32 _ /Transpose_17_reshape | shape: {__mye111018_proxy.1, 1, 1025, 2, 801, 2} 
 111002: transpose: out'.1-(f16[__mye111018_proxy.1,1,2,1025,801,2][3306528,3306528,1653264,1,2064,1032]so[5,4,3,0,2,1]p[0,0,0,0,0,7], mem_prop=0) | /Reshape_32 _ /Transpose_17_reshape_output.1-(f16[__mye111018_proxy.1,1,1025,2,801,2][]so[], mem_prop=0), stream = 0 // shuffle_output_/Reshape_32 _ /Transpose_17_first_transpose | perm: {0, 1, 3, 2, 4, 5} 
 111010: exit:  | out'.1-(f16[__mye111018_proxy.1,1,2,1025,801,2][3306528,3306528,1653264,1,2064,1032]so[5,4,3,0,2,1]p[0,0,0,0,0,7], mem_prop=0), stream = 0 // __mye111009
 B1: succ[]
 
 Total operations: 4391
791 vs. actual uses: 784
Internal Error: MyelinCheckException: gvn.cpp:318: CHECK(graph().ssa_validation()) failed. 
[05/27/2025-14:27:16] [TRT] [E] Error Code: 9: Skipping tactic 0x0000000000000000 due to exception [myelin_graph.h:attachExceptionMsgToGraph:866] MyelinCheckException: gvn.cpp:318: CHECK(graph().ssa_validation()) failed. 
[05/27/2025-14:27:16] [TRT] [V] {ForeignNode[ONNXTRT_castHelper_10786_output[Constant].../Reshape_32 + /Transpose_17]} (Myelin[0x80000023]) profiling completed in 14.7534 seconds. Fastest Tactic: 0xd15ea5edd15ea5ed Time: inf
[05/27/2025-14:27:16] [TRT] [E] IBuilder::buildSerializedNetwork: Error Code 10: Internal Error (Could not find any implementation for node {ForeignNode[ONNXTRT_castHelper_10786_output[Constant].../Reshape_32 + /Transpose_17]}.)

Command that successfully created the engine using the same build script (static shapes, batch size 1):
python3 build_engine.py --onnx=model.onnx --minShapes=x:1x2x1025x801x2 --optShapes=x:1x2x1025x801x2 --maxShapes=x:1x2x1025x801x2 --fp16 --engine=out.trt

Command that failed (dynamic batch size from 1 to 2; the only change is the batch dimension of optShapes/maxShapes):
python3 build_engine.py --onnx=model.onnx --minShapes=x:1x2x1025x801x2 --optShapes=x:2x2x1025x801x2 --maxShapes=x:2x2x1025x801x2 --fp16 --engine=out.trt

Thank you in advance for your help.

Environment

TensorRT Version: 10.9.0

NVIDIA GPU: A10

NVIDIA Driver Version: 535.161.08

CUDA Version: 12.8

CUDNN Version:

Operating System: Debian 12

Python Version (if applicable): 3.11.2

Tensorflow Version (if applicable):

PyTorch Version (if applicable):

Baremetal or Container (if so, version):

Relevant Files

Model link: https://www.dropbox.com/scl/fi/uo6ltm5slqrguum9ht3va/model.onnx?rlkey=exkpe3fi3d6gse29f7oocha6v&st=dtwf5nrn&dl=0

Full Log:

log.txt

Steps To Reproduce

Script used (build_engine.py):

#!/usr/bin/env python3
# -*- coding: utf-8 -*-

import os
import argparse
import tensorrt as trt
import numpy as np
import onnx
from typing import Dict, List, Tuple

TRT_LOGGER = trt.Logger(trt.Logger.VERBOSE)

def _should_keep_fp16(layer) -> bool:
    """Return True if *layer* is exempt from being pinned to FP32.

    A layer is exempt when it is a matrix-multiply layer, or when its name
    contains ``/Add`` or ``/Mul``. Each matching rule prints a message, so a
    layer that matches several rules is reported once per rule.
    """
    result = False
    if layer.type == trt.LayerType.MATRIX_MULTIPLY:
        print(f"找到MatMul层: {layer.name}")
        result = True
    if '/Add' in layer.name:
        print(f"找到Add层: {layer.name}")
        result = True
    if '/Mul' in layer.name:
        print(f"找到Mul层: {layer.name}")
        result = True
    return result


def _pin_layers_to_fp32(network) -> None:
    """Force FP32 compute precision on floating-point layers not exempted by _should_keep_fp16.

    Only layers whose output tensors are all FP32/FP16 are pinned. The
    previous check excluded only INT64 outputs, so layers producing INT32,
    BOOL or INT8 tensors (shape arithmetic, comparisons, cast helpers inserted
    by the ONNX parser) were also given ``precision = float32``. FP32 is not a
    meaningful compute precision for those layers.
    """
    float_types = (trt.float32, trt.float16)
    for i in range(network.num_layers):
        layer = network.get_layer(i)
        if _should_keep_fp16(layer):
            continue
        # Use the tensor's inferred dtype rather than the layer's requested output type.
        outputs_are_float = all(
            layer.get_output(j).dtype in float_types for j in range(layer.num_outputs)
        )
        if outputs_are_float:
            layer.precision = trt.float32


def build_engine(
    onnx_file: str,
    min_shapes: Dict[str, List[int]],
    opt_shapes: Dict[str, List[int]],
    max_shapes: Dict[str, List[int]],
    engine_file: str,
    fp16_mode: bool = True,
):
    """Build a TensorRT engine from an ONNX model and write it to disk.

    Args:
        onnx_file: Path to the ONNX model file.
        min_shapes: Minimum shape per input, ``{input_name: shape}``.
        opt_shapes: Optimal shape per input; must contain every key of ``min_shapes``.
        max_shapes: Maximum shape per input; must contain every key of ``min_shapes``.
        engine_file: Path the serialized engine is written to.
        fp16_mode: If True, enable FP16 with OBEY_PRECISION_CONSTRAINTS and pin
            every non-exempt floating-point layer to FP32.

    Returns:
        The serialized engine returned by ``builder.build_serialized_network``.

    Raises:
        RuntimeError: If ONNX parsing fails or the engine build returns None.
        KeyError: If ``opt_shapes``/``max_shapes`` lack an input named in ``min_shapes``.
    """
    print(f"正在从ONNX模型构建TensorRT engine: {onnx_file}")

    builder = trt.Builder(TRT_LOGGER)
    # EXPLICIT_BATCH is required on TensorRT 8.x; newer releases treat it as deprecated/no-op.
    network = builder.create_network(1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH))
    config = builder.create_builder_config()
    config.profiling_verbosity = trt.ProfilingVerbosity.DETAILED

    # Dynamic-shape optimization profile: one (min, opt, max) triple per input.
    profile = builder.create_optimization_profile()
    for input_name, min_shape in min_shapes.items():
        opt_shape = opt_shapes[input_name]
        max_shape = max_shapes[input_name]
        profile.set_shape(input_name, min_shape, opt_shape, max_shape)
        print(f"设置输入 {input_name} 的动态形状: min={min_shape}, opt={opt_shape}, max={max_shape}")
    config.add_optimization_profile(profile)

    parser = trt.OnnxParser(network, TRT_LOGGER)
    with open(onnx_file, 'rb') as model:
        if not parser.parse(model.read()):
            for error in range(parser.num_errors):
                print(f"ONNX解析错误: {parser.get_error(error)}")
            raise RuntimeError("ONNX解析失败")

    print("ONNX模型解析成功")

    if fp16_mode:
        config.set_flag(trt.BuilderFlag.FP16)
        # Make the builder honour the per-layer precisions set below.
        config.set_flag(trt.BuilderFlag.OBEY_PRECISION_CONSTRAINTS)
        _pin_layers_to_fp32(network)

    serialized_engine = builder.build_serialized_network(network, config)
    if serialized_engine is None:
        raise RuntimeError("引擎构建失败")

    with open(engine_file, 'wb') as f:
        f.write(serialized_engine)

    print(f"TensorRT engine已保存至: {engine_file}")
    return serialized_engine

def parse_shape(shape_str: str) -> List[int]:
    """Convert an ``x``-separated dimension string such as ``"1x2x3"`` into ``[1, 2, 3]``.

    Raises ValueError if any component is not an integer literal.
    """
    pieces = shape_str.split('x')
    return list(map(int, pieces))

def main():
    """Command-line entry point: parse arguments and build the engine.

    Shape arguments use the trtexec-style syntax
    ``name:d0xd1x...[,name2:d0xd1x...]``. The name is everything before the
    last ``:``, so names containing colons (e.g. ``input:0``) work.
    Malformed or inconsistent shape arguments are reported as usage errors.
    """
    parser = argparse.ArgumentParser(description="构建TensorRT engine")
    parser.add_argument("--onnx", required=True, help="ONNX模型文件路径")
    parser.add_argument("--minShapes", required=True, help="最小形状,格式: 'input_name:dim1xdim2x...'")
    parser.add_argument("--optShapes", required=True, help="最优形状,格式: 'input_name:dim1xdim2x...'")
    parser.add_argument("--maxShapes", required=True, help="最大形状,格式: 'input_name:dim1xdim2x...'")
    parser.add_argument("--engine", default=None, help="输出的engine文件路径")
    parser.add_argument("--fp16", action="store_true", help="是否启用FP16模式")

    args = parser.parse_args()

    def parse_shapes_arg(shape_arg: str) -> Dict[str, List[int]]:
        """Parse ``name:dims[,name:dims...]`` into ``{name: [dims...]}``."""
        shapes: Dict[str, List[int]] = {}
        for spec in shape_arg.split(','):
            spec = spec.strip()
            # rpartition: the shape never contains ':', but the input name may.
            name, sep, shape_str = spec.rpartition(':')
            if not sep or not name:
                parser.error(f"invalid shape specification {spec!r}; expected 'input_name:dim1xdim2x...'")
            try:
                shapes[name] = parse_shape(shape_str)
            except ValueError:
                parser.error(f"invalid dimensions {shape_str!r} for input {name!r}")
        return shapes

    min_shapes = parse_shapes_arg(args.minShapes)
    opt_shapes = parse_shapes_arg(args.optShapes)
    max_shapes = parse_shapes_arg(args.maxShapes)

    # build_engine indexes opt/max by the names in min, so all three must agree.
    if set(opt_shapes) != set(min_shapes) or set(max_shapes) != set(min_shapes):
        parser.error("--minShapes, --optShapes and --maxShapes must name the same inputs")

    if args.engine is None:
        args.engine = os.path.splitext(args.onnx)[0] + ".trt"

    build_engine(
        args.onnx,
        min_shapes,
        opt_shapes,
        max_shapes,
        args.engine,
        args.fp16,
    )


if __name__ == "__main__":
    main()

Metadata

Metadata

Assignees

Labels

Module:Engine Build (Issues with building TensorRT engines), triaged (Issue has been triaged by maintainers)

Type

No type

Projects

No projects

Milestone

No milestone

Relationships

None yet

Development

No branches or pull requests

Issue actions