
Concat error when exporting the Llama-2-7b-chat-ms model #29

Open

mi-tao opened this issue Feb 5, 2024 · 1 comment

mi-tao commented Feb 5, 2024

Error message

Exception occurred: AssertionError (note: full exception trace is shown but execution is paused at: _run_module_as_main)
exception: no description
  File "/usr/local/lib/python3.10/site-packages/torch/onnx/symbolic_opset9.py", line 539, in cat
    assert all(
  File "/usr/local/lib/python3.10/site-packages/torch/onnx/symbolic_helper.py", line 306, in wrapper
    return fn(g, *args, **kwargs)
  File "/usr/local/lib/python3.10/site-packages/torch/onnx/symbolic_opset11.py", line 551, in cat
    return opset9.cat(g, tensor_list, dim)
  File "/usr/local/lib/python3.10/site-packages/torch/onnx/symbolic_helper.py", line 392, in wrapper
    return fn(g, *args, **kwargs)
  File "/usr/local/lib/python3.10/site-packages/torch/onnx/utils.py", line 1891, in _run_symbolic_function
    return symbolic_fn(graph_context, *inputs, **attrs)
  File "/usr/local/lib/python3.10/site-packages/torch/onnx/utils.py", line 665, in _optimize_graph
    graph = _C._jit_pass_onnx(graph, operator_export_type)
  File "/usr/local/lib/python3.10/site-packages/torch/onnx/utils.py", line 1117, in _model_to_graph
    graph = _optimize_graph(
  File "/usr/local/lib/python3.10/site-packages/torch/onnx/utils.py", line 1548, in _export
    graph, params_dict, torch_out = _model_to_graph(
  File "/usr/local/lib/python3.10/site-packages/torch/onnx/utils.py", line 506, in export
    _export(
  File "/work/hu/alg_sources/llm-export/llm_export.py", line 228, in export_block
    torch.onnx.export(
  File "/work/hu/alg_sources/llm-export/llm_export.py", line 251, in export_blocks
    self.export_block(i)
  File "/work/hu/alg_sources/llm-export/llm_export.py", line 868, in <module>
    llm_exporter.export_blocks()
  File "/usr/local/lib/python3.10/runpy.py", line 86, in _run_code
    exec(code, run_globals)
  File "/usr/local/lib/python3.10/runpy.py", line 196, in _run_module_as_main (Current frame)
    return _run_code(code, main_globals, None,
AssertionError:

Location of the failing code (torch/onnx/symbolic_opset9.py)

@_onnx_symbolic("aten::cat")
@symbolic_helper.parse_args("v", "i")
@_beartype.beartype
def cat(g: jit_utils.GraphContext, tensor_list, dim):
    tensors = symbolic_helper._unpack_list(tensor_list)
    # torch.cat ignores empty tensors such as `torch.Tensor([])`
    # These needs to be removed as input from ONNX's concat too, otherwise shape inference
    # will likely fail due to inputs with different ranks (0 for empty tensor, > 0 for anything else)
    nonempty_tensors = []
    for t in tensors:
        if symbolic_helper._is_constant(t) and not symbolic_helper._get_tensor_dim_size(
            t, 0
        ):
            continue
        nonempty_tensors.append(t)
    assert len(nonempty_tensors) > 0
    assert all(
        [
            symbolic_helper._get_tensor_rank(nonempty_tensors[0]) is None
            or symbolic_helper._get_tensor_rank(t) is None
            or symbolic_helper._get_tensor_rank(t)
            == symbolic_helper._get_tensor_rank(nonempty_tensors[0])
            for t in nonempty_tensors
        ]
    )
    tensor_list.node().removeAllInputs()
    for t in nonempty_tensors:
        tensor_list.node().addInput(t)

    tensors = symbolic_helper._unpack_list(tensor_list)
    return g.op("Concat", *tensors, axis_i=dim)
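The assertion fires because the remaining inputs no longer share the same rank, which `cat` requires. The same constraint can be reproduced in eager mode; a minimal sketch using the two shapes from this report:

```python
import torch

# torch.cat requires every non-empty input to have the same number of
# dimensions. The two tensors reported below are rank 4 and rank 6, so
# concatenation fails before any axis sizes are even compared.
a = torch.zeros(1, 32, 0, 128)        # rank 4 (empty along dim 2)
b = torch.zeros(1, 1, 3, 32, 3, 128)  # rank 6

try:
    torch.cat([a, b], dim=2)
    rank_mismatch = False
except RuntimeError as err:
    rank_mismatch = True
    print("cat failed:", err)
```

Note that `a` is not filtered out by the symbolic function above: the empty-tensor check only skips constant tensors whose dim-0 size is zero, while here the zero-size dimension is dim 2.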

Details

The two tensors left in nonempty_tensors have shapes [1, 32, 0, 128] and [1, 1, 3, 32, 3, 128], so they cannot be concatenated.
Llama-2-7b-chat-ms model download: https://modelscope.cn/models/modelscope/Llama-2-7b-ms/files
Python environment:

certifi                       2023.11.17
charset-normalizer            3.3.2
cmake                         3.28.1
coloredlogs                   15.0.1
filelock                      3.13.1
flatbuffers                   23.5.26
fsspec                        2023.12.2
huggingface-hub               0.20.1
humanfriendly                 10.0
idna                          3.6
Jinja2                        3.1.2
lit                           17.0.6
markdown-it-py                3.0.0
MarkupSafe                    2.1.3
mdurl                         0.1.2
mpmath                        1.3.0
networkx                      3.2.1
numpy                         1.25.2
nvidia-cublas-cu11            11.10.3.66
nvidia-cublas-cu12            12.1.3.1
nvidia-cuda-cupti-cu11        11.7.101
nvidia-cuda-cupti-cu12        12.1.105
nvidia-cuda-nvrtc-cu11        11.7.99
nvidia-cuda-nvrtc-cu12        12.1.105
nvidia-cuda-runtime-cu11      11.7.99
nvidia-cuda-runtime-cu12      12.1.105
nvidia-cudnn-cu11             8.5.0.96
nvidia-cudnn-cu12             8.9.2.26
nvidia-cufft-cu11             10.9.0.58
nvidia-cufft-cu12             11.0.2.54
nvidia-curand-cu11            10.2.10.91
nvidia-curand-cu12            10.3.2.106
nvidia-cusolver-cu11          11.4.0.1
nvidia-cusolver-cu12          11.4.5.107
nvidia-cusparse-cu11          11.7.4.91
nvidia-cusparse-cu12          12.1.0.106
nvidia-nccl-cu11              2.14.3
nvidia-nccl-cu12              2.18.1
nvidia-nvjitlink-cu12         12.3.101
nvidia-nvtx-cu11              11.7.91
nvidia-nvtx-cu12              12.1.105
onnx                          1.15.0
onnxruntime                   1.15.1
onnxsim                       0.4.35
packaging                     23.2
pip                           23.3.2
protobuf                      4.25.1
Pygments                      2.17.2
PyYAML                        6.0.1
regex                         2023.12.25
requests                      2.31.0
rich                          13.7.0
safetensors                   0.4.1
sentencepiece                 0.1.99
setuptools                    57.4.0
sympy                         1.12
tabulate                      0.9.0
tokenizers                    0.13.3
torch                         2.0.1
tqdm                          4.66.1
transformers                  4.31.0
transformers-stream-generator 0.0.4
triton                        2.0.0
typing_extensions             4.9.0
urllib3                       2.1.0
wheel                         0.42.0

mi-tao commented Feb 18, 2024

By comparing against the original modeling_llama.py, I located the modified code.
(screenshot of the modified code)
After changing the two calls back to squeeze, the export succeeds.
Stepping through with breakpoints shows that torch.squeeze correctly reduces the tensor to [seq, dim] in the forward pass; the root cause of the failure is still unclear.
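For reference, this is the effect the squeeze-based fix relies on: removing leading size-1 dimensions lowers the rank so that later cat calls see inputs of matching rank. A minimal sketch (the [1, 1, seq, dim] shape here is illustrative, not taken from the model source):

```python
import torch

# A hypothetical [1, 1, seq, dim] tensor, e.g. after two unsqueeze calls.
x = torch.zeros(1, 1, 3, 128)

# squeeze(dim) drops a dimension only if its size is 1, so applying it
# twice to the leading axes brings the tensor back to [seq, dim].
y = x.squeeze(0).squeeze(0)
print(y.shape)  # torch.Size([3, 128])
```

During ONNX export, operations that change a tensor's rank this way determine whether the `aten::cat` symbolic's rank-equality assertion holds, which is consistent with the export succeeding once the squeeze calls are restored.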
