CUBLAS unable to find suitable GEMM in backwards pass #1969

@coreyjadams

Description

Describe the bug

When using TE pytorch layers within an fp8 context, during the backward pass an error is raised for some GEMM operations:

    loss.backward()
  File "/usr/local/lib/python3.12/dist-packages/torch/_tensor.py", line 648, in backward
    torch.autograd.backward(
  File "/usr/local/lib/python3.12/dist-packages/torch/autograd/__init__.py", line 353, in backward
    _engine_run_backward(
  File "/usr/local/lib/python3.12/dist-packages/torch/autograd/graph.py", line 824, in _engine_run_backward
    return Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/torch/autograd/function.py", line 307, in apply
    return user_fn(self, *args)
           ^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/transformer_engine/pytorch/module/linear.py", line 788, in backward
    wgrad, grad_bias_ = wgrad_gemm(inputmat_total, grad_output)
                        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/transformer_engine/pytorch/module/linear.py", line 768, in wgrad_gemm
    dw, db, *_ = general_gemm(x, dy, **wgrad_gemm_kwargs)
                 ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/transformer_engine/pytorch/cpp_extensions/gemm.py", line 105, in general_gemm
    out, bias_grad, gelu_input, extra_output = tex.generic_gemm(*args, **kwargs)
                                               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
RuntimeError: /workspace/TransformerEngine/transformer_engine/common/gemm/cublaslt_gemm.cu:537 in function cublas_gemm: Assertion failed: status != CUBLAS_STATUS_NOT_SUPPORTED. Unable to find suitable cuBLAS GEMM algorithm

I see this error with both te.Linear and te.LayerNormLinear. The input to these layers typically has shape [1, N, 256], where N is large (roughly 300k or more).

Steps/Code to reproduce bug

This occurs when running the Transolver model in NVIDIA physicsnemo. I can share reproducers, but it's in-development code; I haven't stripped it down to a minimal reproducer yet, though I might be able to if needed. A rough sketch of the failing call pattern is below.
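
Something along these lines should exercise the same code path (untested as a standalone script; the FP8 recipe, the hidden size of 256, and N = 300k are my assumptions about what matters):

    # Rough, untested sketch of the failing call pattern. The exact FP8 recipe,
    # hidden size, and sequence length N are assumptions based on my setup.
    import torch
    import transformer_engine.pytorch as te
    from transformer_engine.common import recipe

    fp8_recipe = recipe.DelayedScaling()           # default delayed-scaling recipe
    layer = te.Linear(256, 256, bias=True).cuda()  # te.LayerNormLinear fails the same way

    # Shape [1, N, 256] with large N, as in the Transolver runs
    x = torch.randn(1, 300_000, 256, device="cuda", requires_grad=True)

    with te.fp8_autocast(enabled=True, fp8_recipe=fp8_recipe):
        y = layer(x)

    loss = y.float().sum()
    loss.backward()  # in my runs, the RuntimeError is raised from the wgrad GEMM here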

Expected behavior

I expected cuBLAS to find a suitable GEMM algorithm so the backward pass completes without error.

Environment overview

Running the PyTorch 25.06 container from NGC; no modifications to PyTorch or Transformer Engine.

Environment details

NVIDIA docker container is used :)
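
Version details can be pulled from inside the container with something like this (the transformer_engine.__version__ attribute is an assumption about the attribute name; the torch calls are standard):

    # Print the versions relevant to this report from inside the NGC container.
    import torch

    print("torch:", torch.__version__, "| CUDA:", torch.version.cuda)
    print("GPU:", torch.cuda.get_device_name(0))
    try:
        import transformer_engine
        print("transformer_engine:", transformer_engine.__version__)
    except (ImportError, AttributeError):
        pass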

Device details

H100s on EOS (NVIDIA-internal machines).
