**Describe the bug**
When using TE PyTorch layers within an fp8 autocast context, an error is raised during the backward pass for some GEMM operations:
```
loss.backward()
  File "/usr/local/lib/python3.12/dist-packages/torch/_tensor.py", line 648, in backward
    torch.autograd.backward(
  File "/usr/local/lib/python3.12/dist-packages/torch/autograd/__init__.py", line 353, in backward
    _engine_run_backward(
  File "/usr/local/lib/python3.12/dist-packages/torch/autograd/graph.py", line 824, in _engine_run_backward
    return Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/torch/autograd/function.py", line 307, in apply
    return user_fn(self, *args)
           ^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/transformer_engine/pytorch/module/linear.py", line 788, in backward
    wgrad, grad_bias_ = wgrad_gemm(inputmat_total, grad_output)
                        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/transformer_engine/pytorch/module/linear.py", line 768, in wgrad_gemm
    dw, db, *_ = general_gemm(x, dy, **wgrad_gemm_kwargs)
                 ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/transformer_engine/pytorch/cpp_extensions/gemm.py", line 105, in general_gemm
    out, bias_grad, gelu_input, extra_output = tex.generic_gemm(*args, **kwargs)
                                               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
RuntimeError: /workspace/TransformerEngine/transformer_engine/common/gemm/cublaslt_gemm.cu:537 in function cublas_gemm: Assertion failed: status != CUBLAS_STATUS_NOT_SUPPORTED. Unable to find suitable cuBLAS GEMM algorithm
```
I see the error raised when using `te.Linear`, and also `te.LayerNormLinear`. The data shape going into these layers is typically `[1, N, 256]`, where N is large-ish: 300k or more. A sketch of the usage pattern is under the Steps/Code section below.
**Steps/Code to reproduce bug**
This is seen when running the Transolver model in NVIDIA physicsnemo. I can share reproducers, but it's in-development code. I haven't stripped it down to a simple reproducer, but might be able to if needed.
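In the meantime, here is a minimal sketch of the usage pattern that hits the error. The layer sizes, FP8 recipe, and input shape are assumptions based on the description above, not the actual physicsnemo code:

```python
# Hedged sketch, not the actual Transolver/physicsnemo code: the layer
# configuration, recipe, and input shape ([1, N, 256] with N ~ 300k) are
# assumptions based on the description above.
import torch
import transformer_engine.pytorch as te
from transformer_engine.common.recipe import DelayedScaling, Format

layer = te.Linear(256, 256, bias=True).cuda()
recipe = DelayedScaling(fp8_format=Format.HYBRID)

x = torch.randn(1, 300_000, 256, device="cuda", requires_grad=True)

with te.fp8_autocast(enabled=True, fp8_recipe=recipe):
    y = layer(x)

# The RuntimeError above is raised here, in the wgrad GEMM of the
# backward pass.
y.sum().backward()
```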
**Expected behavior**
I expected cuBLAS to have a suitable GEMM algorithm for these shapes.
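For comparison (my assumption, not something verified in the original run), the equivalent wgrad-shaped GEMM goes through cuBLAS without complaint when run outside FP8 in plain PyTorch:

```python
# Sketch of the equivalent wgrad GEMM ([256, N] x [N, 256]) in bf16.
# This comparison is an assumption for illustration, not a result from
# the original report.
import torch

x = torch.randn(300_000, 256, device="cuda", dtype=torch.bfloat16)
dy = torch.randn(300_000, 256, device="cuda", dtype=torch.bfloat16)

wgrad = dy.t() @ x  # [256, 256] weight gradient
print(wgrad.shape)
```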
**Environment overview (please complete the following information)**
Running the PyTorch 25.06 Docker container from NGC; no modifications to PyTorch or Transformer Engine.
**Environment details**
An NVIDIA Docker container is used :)
**Device details**
H100s on EOS (NVIDIA-internal machines).
**Additional context**