PyTorch is not built for maximum efficiency. If you want to squeeze the most performance out of a distributed training setup, you need something faster than stock PyTorch. This is where CUTLASS comes into play: it provides highly optimized CUDA kernels for matmuls and other performance-critical operations.
This project simplifies the complicated distributed GEMM operation by wrapping the logic from the official CUTLASS example 82_blackwell_distributed_gemm into a custom PyTorch C++ extension.
Note that this project depends on the full CUTLASS repository, which is pulled in as a git submodule.
- 8x NVIDIA Blackwell (Compute Capability 10.0+) or newer GPUs with NVLink
- CUDA Toolkit 12.8 or newer
- PyTorch (built with a compatible CUDA version)
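As a convenience, the snippet below (not part of the project) sketches how you might check these requirements from Python using standard `torch.cuda` APIs:

```python
import torch

# Quick sanity check against the requirements above.
assert torch.cuda.is_available(), "CUDA is not available to PyTorch"
print("PyTorch was built with CUDA:", torch.version.cuda)  # should report 12.8 or newer

num_gpus = torch.cuda.device_count()
print("Visible GPUs:", num_gpus)  # 8 GPUs are expected

for i in range(num_gpus):
    major, minor = torch.cuda.get_device_capability(i)
    name = torch.cuda.get_device_name(i)
    # Blackwell corresponds to compute capability 10.0 or newer.
    print(f"GPU {i}: {name} (compute capability {major}.{minor})")
```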
This project includes CUTLASS as a git submodule. The setup process is straightforward:
1. Clone the repository recursively to pull in CUTLASS:
git clone --recursive https://github.com/prateekshukla1108/pytorch-distributed-gemm.git
cd pytorch-distributed-gemm
If you cloned without --recursive, you can initialize the submodule manually:
git submodule update --init --recursive
2. Install the PyTorch extension:
This command will compile the C++/CUDA source file and install the operator in your current Python environment.
pip install .
The installation process may take a few minutes as it compiles the CUTLASS kernels.
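To confirm the extension built correctly, you can try importing it. The module name below is an assumption based on the source file name; the actual name is whatever setup.py registers:

```python
import torch           # load PyTorch first so the extension's symbols resolve
import dist_gemm_ext   # assumed module name (see setup.py for the real one)

print("Loaded extension from:", dist_gemm_ext.__file__)
```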
.
├── cutlass/ # CUTLASS git submodule
├── src/
│ └── dist_gemm_ext.cu # PyTorch C++/CUDA binding
├── setup.py # Build script for the extension
└── README.md # This file
After installation, you can import and use the `dist_matmul` function. The operator expects three contiguous input tensors (`A`, `B`, `C`) and returns the result of `alpha * (A @ B_T) + beta * C`.
Note on Tensor Layouts:
- `A` is expected to have shape `(M, K)`.
- `B` is expected to have shape `(N, K)` and be in a row-major layout. This is different from the standard `torch.matmul`, which expects `B` to be `(K, N)`.
- `C` is the bias tensor of shape `(M, N)`.
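A minimal usage sketch is shown below. The module name, the exact signature of `dist_matmul` (including whether `alpha`/`beta` are exposed as arguments), and the use of `torch.float8_e4m3fn` as the PyTorch-side counterpart of `cutlass::float_e4m3_t` are assumptions; check `src/dist_gemm_ext.cu` for the actual interface:

```python
import torch
import dist_gemm_ext  # assumed module name; defined in setup.py

M, N, K = 8192, 8192, 8192  # example problem sizes

# FP8 inputs; torch.float8_e4m3fn is assumed to be bit-compatible with cutlass::float_e4m3_t.
A = torch.randn(M, K, device="cuda").to(torch.float8_e4m3fn).contiguous()  # shape (M, K)
B = torch.randn(N, K, device="cuda").to(torch.float8_e4m3fn).contiguous()  # shape (N, K), row-major
C = torch.zeros(M, N, device="cuda").to(torch.float8_e4m3fn).contiguous()  # bias, shape (M, N)

# Computes alpha * (A @ B_T) + beta * C; how the work is sharded across the
# 8 NVLink-connected GPUs is handled inside the extension.
D = dist_gemm_ext.dist_matmul(A, B, C)
print(D.shape)  # expected: torch.Size([8192, 8192])
```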
Note that this operator does not have autograd support. It uses `cutlass::float_e4m3_t` as its data type, which is not a dtype the standard PyTorch interface accepts, so ensure memory-layout compatibility or implement the necessary type conversions. Also, to change the tensor parallelism size you need to edit `dist_gemm_ext.cu` and reinstall the package.