
PyTorch Op for Blackwell Distributed GEMM

PyTorch alone is not built for maximum efficiency. If you want to get the most out of a distributed training setup, you need kernels that are more efficient than stock PyTorch. This is where CUTLASS comes into play, providing highly optimized CUDA kernels for matrix multiplication and other performance-heavy operations.

This project further simplifies the complicated distributed GEMM operation by wrapping the logic from the official CUTLASS example 82_blackwell_distributed_gemm into a custom PyTorch C++ extension.

Note that this project also depends on the entire CUTLASS repository, which is pulled in as a git submodule.

Requirements

  • 8x NVIDIA Blackwell GPUs (Compute Capability 10.0 or newer) connected with NVLink
  • CUDA Toolkit 12.8 or newer
  • PyTorch (built with a compatible CUDA version)
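
A quick way to check the GPU side of these requirements from Python (this only verifies device count and compute capability, not the NVLink topology):

import torch

# The kernel needs 8 Blackwell-class GPUs (compute capability 10.0 or newer).
assert torch.cuda.device_count() >= 8, "fewer than 8 GPUs visible"
for i in range(torch.cuda.device_count()):
    major, minor = torch.cuda.get_device_capability(i)
    assert major >= 10, f"GPU {i} has compute capability {major}.{minor}"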

Installation

This project includes CUTLASS as a git submodule. The setup process is straightforward:

1. Clone the repository recursively to pull in CUTLASS:

git clone --recursive https://github.com/prateekshukla1108/pytorch-distributed-gemm.git
cd pytorch-distributed-gemm

If you cloned without --recursive, you can initialize the submodule manually:

git submodule update --init --recursive

2. Install the PyTorch extension:

This command will compile the C++/CUDA source file and install the operator in your current Python environment.

pip install .

The installation process may take a few minutes as it compiles the CUTLASS kernels.
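
To confirm the build, try importing the extension from Python. The module name below (dist_gemm_ext, after the source file) is an assumption; check setup.py for the name this package actually registers:

python -c "import torch, dist_gemm_ext; print('distributed GEMM extension loaded')"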

Project Structure

.
├── cutlass/                # CUTLASS git submodule
├── src/
│   └── dist_gemm_ext.cu      # PyTorch C++/CUDA binding
├── setup.py                # Build script for the extension
└── README.md               # This file

Usage

After installation, you can import and use the dist_matmul function. The operator expects three contiguous input tensors (A, B, C) and returns the result of alpha * (A @ B_T) + beta * C. A call sketch follows the layout notes below.

Note on Tensor Layouts:

  • A is expected to have shape (M, K).
  • B is expected to have shape (N, K) and be in a row-major layout. This is different from the standard torch.matmul which expects B to be (K, N).
  • C is the bias tensor of shape (M, N).
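
A minimal sketch of what a call might look like. Several details are assumptions rather than documented behaviour: the module name dist_gemm_ext, the exact signature of dist_matmul (alpha and beta may be fixed inside the kernel rather than passed in), and the use of torch.float8_e4m3fn to match the kernel's cutlass::float_e4m3_t input type (see the limitations below).

import torch
import dist_gemm_ext  # hypothetical module name; check setup.py for the real one

M, N, K = 8192, 8192, 8192

# A is (M, K); B is (N, K) row-major, i.e. already the transposed operand; C is the (M, N) bias.
A = torch.randn(M, K, device="cuda").to(torch.float8_e4m3fn).contiguous()
B = torch.randn(N, K, device="cuda").to(torch.float8_e4m3fn).contiguous()
C = torch.zeros(M, N, device="cuda").contiguous()  # expected dtype of C depends on the kernel configuration

# D = alpha * (A @ B^T) + beta * C, computed by the CUTLASS distributed GEMM kernel across the GPUs.
D = dist_gemm_ext.dist_matmul(A, B, C)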

Limitations & Important Notes

  • There is no autograd support; the operator is forward-only.
  • The kernel uses cutlass::float_e4m3_t as its data type, which the PyTorch interface does not accept directly, so ensure memory-layout compatibility or implement type conversions before calling the operator.
  • The tensor parallelism size is hard-coded; to change it, edit src/dist_gemm_ext.cu and reinstall the package.
