Commit 14520d1

jiawenliu64 authored and facebook-github-bot committed
Enable FP4 CUTLASS GEMM and CUDA quantization kernels (#4004)
Summary:
X-link: facebookresearch/FBGEMM#1091

Pull Request resolved: #4004

Enable MXFP4 and NVFP4 CUTLASS GEMM and NVFP4 CUDA quantization kernels

Reviewed By: jianyuh

Differential Revision: D69505435

fbshipit-source-id: 7b438628663efec47851bb3908ae2f74ee9e8261
1 parent 517b73d · commit 14520d1

33 files changed: +1928 -3 lines

.github/scripts/fbgemm_gpu_build.bash

Lines changed: 1 addition & 1 deletion
@@ -251,7 +251,7 @@ __configure_fbgemm_gpu_build_cuda () {
     # https://github.com/vllm-project/vllm/blob/main/CMakeLists.txt#L187
     # https://github.com/NVIDIA/cutlass/blob/main/include/cutlass/gemm/kernel/sm90_gemm_tma_warpspecialized.hpp#L224
     if [[ $cuda_version_nvcc == *"V12.8"* ]]; then
-      local arch_list="7.0;8.0;9.0;9.0a;10.0;10.0a;12.0;12.0a"
+      local arch_list="7.0;8.0;9.0;9.0a;10.0a;12.0a"

     elif [[ $cuda_version_nvcc == *"V12.6"* ]] ||
          [[ $cuda_version_nvcc == *"V12.4"* ]] ||

fbgemm_gpu/experimental/gen_ai/bench/quantize_ops.py

Lines changed: 44 additions & 1 deletion
@@ -25,7 +25,10 @@
     grouped_gemm,
     grouped_gemm_fp8_rowwise,
 )
-from fbgemm_gpu.experimental.gen_ai.quantize import quantize_int4_preshuffle
+from fbgemm_gpu.experimental.gen_ai.quantize import (
+    quantize_int4_preshuffle,
+    scaled_fp4_quant,
+)
 
 try:
     from tinygemm.utils import group_quantize_tensor
@@ -1962,3 +1965,43 @@ def hip(self) -> bool:
     def cuda(self) -> bool:
         # This op is not always supported.
         return MACHETE_ENABLED
+
+
+@register_quantize_op
+class FP4Gemm(QuantizeOpBase):
+    """
+    FP4 matmul with block-wise scaling.
+    """
+
+    def quantize(self, x, w):
+        x_global_scale = ((448.0 * 6.0) / torch.amax(x.flatten(), dim=-1)).to(
+            torch.float32
+        )
+        w_global_scale = ((448.0 * 6.0) / torch.amax(w.flatten(), dim=-1)).to(
+            torch.float32
+        )
+        global_scale = 1 / (x_global_scale * w_global_scale)
+
+        xq, x_scale = scaled_fp4_quant(x, x_global_scale)
+        wq, w_scale = scaled_fp4_quant(w, w_global_scale)
+        return xq, wq, x_scale, w_scale, global_scale
+
+    def compute(self, xq, wq, x_scale, w_scale, global_scale):
+        return torch.ops.fbgemm.f4f4bf16(xq, wq, x_scale, w_scale, global_scale)
+
+    def quantize_and_compute(self, x, w):
+        xq, wq, x_scale, w_scale, global_scale = self.quantize(x, w)
+        return self.compute(xq, wq, x_scale, w_scale, global_scale)
+
+    @property
+    def name(self) -> str:
+        return "cutlass_f4f4bf16"
+
+    @property
+    def hip(self) -> bool:
+        # F4F4BF16 is only supported on CUDA.
+        return False
+
+    @property
+    def cuda(self) -> bool:
+        return True
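
For orientation, here is a minimal sketch of exercising the new benchmark op end to end. The shapes are illustrative assumptions, and it presumes an FBGEMM build that includes the FP4 kernels (CUDA >= 12.8) on a supported GPU:

    import torch

    # FP4Gemm is the benchmark op defined in the diff above.
    op = FP4Gemm()
    x = torch.randn(128, 4096, dtype=torch.bfloat16, device="cuda")
    w = torch.randn(4096, 4096, dtype=torch.bfloat16, device="cuda")

    # quantize() derives per-tensor global scales (448.0 * 6.0 is presumably the
    # float8_e4m3 max times the FP4 e2m1 max), quantizes both operands with
    # scaled_fp4_quant, and compute() runs the CUTLASS f4f4bf16 GEMM.
    xq, wq, x_scale, w_scale, global_scale = op.quantize(x, w)
    out = op.compute(xq, wq, x_scale, w_scale, global_scale)  # bf16, shape (128, 4096)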

fbgemm_gpu/experimental/gen_ai/gen_ai/quantize.py

Lines changed: 54 additions & 0 deletions
@@ -162,3 +162,57 @@ def _quantize(
     wq, scales = _quantize(w, dtype=dtype)
 
     return wq, scales
+
+
+def scaled_fp4_quant(
+    input: torch.Tensor, input_global_scale: torch.Tensor
+) -> Tuple[torch.Tensor, torch.Tensor]:
+    """
+    Quantize input tensor to FP4 and return quantized tensor and scale.
+    This function quantizes the last dimension of the given tensor `input`. For
+    every 16 consecutive elements, a single dynamically computed scaling factor
+    is shared. This scaling factor is quantized using the `input_global_scale`
+    and is stored in a swizzled layout (see
+    https://docs.nvidia.com/cuda/parallel-thread-execution/#tcgen05-mma-scale-factor-b-layout-4x).
+    Args:
+        input: The input tensor to be quantized to FP4.
+        input_global_scale: A scalar scaling factor for the entire tensor.
+    Returns:
+        Tuple[torch.Tensor, torch.Tensor]: The output tensor in FP4 (every two
+        values packed into a uint8) and the float8_e4m3 scaling factors
+        in the swizzled layout.
+    """
+    assert input.ndim >= 1, f"input.ndim needs to be >= 1, but got {input.ndim}."
+    other_dims = 1 if input.ndim == 1 else -1
+    input = input.reshape(other_dims, input.shape[-1])
+    m, n = input.shape
+    block_size = 16
+    device = input.device
+
+    assert n % block_size == 0, f"last dim has to be multiple of 16, but got {n}."
+    assert input.dtype in (
+        torch.float16,
+        torch.bfloat16,
+    ), f"input.dtype needs to be fp16 or bf16 but got {input.dtype}."
+
+    # Two fp4 values will be packed into an uint8.
+    output = torch.empty((m, n // 2), device=device, dtype=torch.uint8)
+
+    # We use the rounded values to store the swizzled values. Due to the
+    # requirement of the Tensor Core, the minimum tile is 128x4 for the scales.
+    # So, we first pad the scales to multiples of 128 and 4. Then, the scales
+    # (in float8_e4m3fn) are packed into an int32 for every 4 values. More:
+    # https://docs.nvidia.com/cuda/parallel-thread-execution/#tcgen05-mma-scale-factor-b-layout-4x
+    def round_up(x: int, y: int) -> int:
+        return (x + y - 1) // y * y
+
+    rounded_m = round_up(m, 128)
+    scale_n = n // block_size
+    rounded_n = round_up(scale_n, 4)
+    output_scale = torch.empty(
+        (rounded_m, rounded_n // 4), device=device, dtype=torch.int32
+    )
+
+    torch.ops.fbgemm.scaled_fp4_quant(output, input, output_scale, input_global_scale)
+    output_scale = output_scale.view(torch.float8_e4m3fn)
+    return output, output_scale
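
As a quick sanity check on the layout described in the docstring, a hedged sketch of calling scaled_fp4_quant directly (shapes are assumptions; requires the scaled_fp4_quant CUDA kernel added in this commit):

    import torch
    from fbgemm_gpu.experimental.gen_ai.quantize import scaled_fp4_quant

    x = torch.randn(256, 4096, dtype=torch.bfloat16, device="cuda")
    # Per-tensor global scale, mirroring the FP4Gemm benchmark op above.
    x_global_scale = ((448.0 * 6.0) / torch.amax(x.flatten(), dim=-1)).to(torch.float32)

    xq, x_scale = scaled_fp4_quant(x, x_global_scale)
    # Two FP4 values per byte: (256, 4096) -> (256, 2048), dtype uint8.
    print(xq.shape, xq.dtype)
    # Block scales: one per 16 elements, padded to 128x4 tiles and viewed as
    # float8_e4m3fn -> (256, 256).
    print(x_scale.shape, x_scale.dtype)
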
Lines changed: 197 additions & 0 deletions
@@ -0,0 +1,197 @@
+/*
+ * Copyright (c) Meta Platforms, Inc. and affiliates.
+ * All rights reserved.
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+#include <ATen/ATen.h>
+#include <ATen/cuda/CUDAContext.h>
+#include <cutlass/util/device_memory.h>
+// clang-format on
+
+#if defined(CUDA_VERSION) && (CUDA_VERSION >= 12080)
+#include "f4f4bf16/f4f4bf16_manifest.cuh"
+#endif
+
+namespace fbgemm_gpu {
+
+#if defined(CUDA_VERSION) && (CUDA_VERSION >= 12080)
+
+at::Tensor dispatch_f4f4bf16_kernel(
+    at::Tensor XQ, // FP4
+    at::Tensor WQ, // FP4
+    at::Tensor x_scale,
+    at::Tensor w_scale,
+    at::Tensor global_scale,
+    bool use_mx = false) {
+  auto M = XQ.size(0);
+  auto K = XQ.size(1);
+  auto N = WQ.size(0);
+  auto BLOCK_SIZE = 16;
+  TORCH_CHECK(
+      N % BLOCK_SIZE == 0 && K % BLOCK_SIZE == 0,
+      "Weight dimensions N and K must be multiples of block size 16");
+
+  // MXFP4
+  if (use_mx) {
+    if (M <= 128) {
+      if (N <= 1024) {
+        return f4f4bf16_256_128_2_4_1_t(XQ, WQ, x_scale, w_scale, global_scale);
+      } else if (N <= 2048) {
+        return f4f4bf16_256_192_4_1_1_t(XQ, WQ, x_scale, w_scale, global_scale);
+      } else {
+        return f4f4bf16_128_128_4_1_1_t(XQ, WQ, x_scale, w_scale, global_scale);
+      }
+    } else if (M <= 2048) {
+      if (N <= 2048) {
+        return f4f4bf16_256_128_2_2_1_t(XQ, WQ, x_scale, w_scale, global_scale);
+      } else if (N <= 8192) {
+        return f4f4bf16_128_256_2_1_1_t(XQ, WQ, x_scale, w_scale, global_scale);
+      } else {
+        return f4f4bf16_256_256_2_1_1_t(XQ, WQ, x_scale, w_scale, global_scale);
+      }
+    } else if (M <= 4096) {
+      if (N <= 4096) {
+        return f4f4bf16_256_256_4_1_1_t(XQ, WQ, x_scale, w_scale, global_scale);
+      } else if (N <= 8192) {
+        return f4f4bf16_256_256_2_1_1_t(XQ, WQ, x_scale, w_scale, global_scale);
+      } else {
+        return f4f4bf16_256_128_2_4_1_t(XQ, WQ, x_scale, w_scale, global_scale);
+      }
+    } else if (M <= 8192) {
+      if (N <= 4096) {
+        return f4f4bf16_256_256_2_2_1_t(XQ, WQ, x_scale, w_scale, global_scale);
+      } else if (N <= 8192) {
+        return f4f4bf16_256_256_2_4_1_t(XQ, WQ, x_scale, w_scale, global_scale);
+      } else {
+        return f4f4bf16_128_256_2_1_1_t(XQ, WQ, x_scale, w_scale, global_scale);
+      }
+    } else if (M <= 16384) {
+      if (N <= 2048) {
+        return f4f4bf16_256_256_2_4_1_t(XQ, WQ, x_scale, w_scale, global_scale);
+      } else if (N <= 8192) {
+        return f4f4bf16_128_192_2_2_1_t(XQ, WQ, x_scale, w_scale, global_scale);
+      } else {
+        return f4f4bf16_128_256_2_1_1_t(XQ, WQ, x_scale, w_scale, global_scale);
+      }
+    } else if (M <= 32768) {
+      if (N <= 1024) {
+        return f4f4bf16_256_256_2_1_1_t(XQ, WQ, x_scale, w_scale, global_scale);
+      } else if (N <= 4096) {
+        return f4f4bf16_128_192_2_2_1_t(XQ, WQ, x_scale, w_scale, global_scale);
+      } else {
+        return f4f4bf16_256_192_4_1_1_t(XQ, WQ, x_scale, w_scale, global_scale);
+      }
+    } else if (M <= 65536) {
+      if (N <= 2048) {
+        return f4f4bf16_256_192_2_4_1_t(XQ, WQ, x_scale, w_scale, global_scale);
+      } else if (N <= 4096) {
+        return f4f4bf16_256_192_2_2_1_t(XQ, WQ, x_scale, w_scale, global_scale);
+      } else {
+        return f4f4bf16_256_256_2_1_1_t(XQ, WQ, x_scale, w_scale, global_scale);
+      }
+    } else {
+      if (N <= 1024) {
+        return f4f4bf16_256_192_2_4_1_t(XQ, WQ, x_scale, w_scale, global_scale);
+      } else {
+        return f4f4bf16_256_256_2_2_1_t(XQ, WQ, x_scale, w_scale, global_scale);
+      }
+    }
+  }
+  // NVFP4
+  else {
+    if (M <= 128) {
+      if (N <= 1024) {
+        return f4f4bf16_256_128_2_4_1_f(XQ, WQ, x_scale, w_scale, global_scale);
+      } else if (N <= 2048) {
+        return f4f4bf16_256_192_4_1_1_f(XQ, WQ, x_scale, w_scale, global_scale);
+      } else {
+        return f4f4bf16_128_128_4_1_1_f(XQ, WQ, x_scale, w_scale, global_scale);
+      }
+    } else if (M <= 2048) {
+      if (N <= 2048) {
+        return f4f4bf16_256_128_2_2_1_f(XQ, WQ, x_scale, w_scale, global_scale);
+      } else if (N <= 8192) {
+        return f4f4bf16_128_256_2_1_1_f(XQ, WQ, x_scale, w_scale, global_scale);
+      } else {
+        return f4f4bf16_256_256_2_1_1_f(XQ, WQ, x_scale, w_scale, global_scale);
+      }
+    } else if (M <= 4096) {
+      if (N <= 4096) {
+        return f4f4bf16_256_256_4_1_1_f(XQ, WQ, x_scale, w_scale, global_scale);
+      } else if (N <= 8192) {
+        return f4f4bf16_256_256_2_1_1_f(XQ, WQ, x_scale, w_scale, global_scale);
+      } else {
+        return f4f4bf16_256_128_2_4_1_f(XQ, WQ, x_scale, w_scale, global_scale);
+      }
+    } else if (M <= 8192) {
+      if (N <= 4096) {
+        return f4f4bf16_256_256_2_2_1_f(XQ, WQ, x_scale, w_scale, global_scale);
+      } else if (N <= 8192) {
+        return f4f4bf16_256_256_2_4_1_f(XQ, WQ, x_scale, w_scale, global_scale);
+      } else {
+        return f4f4bf16_128_256_2_1_1_f(XQ, WQ, x_scale, w_scale, global_scale);
+      }
+    } else if (M <= 16384) {
+      if (N <= 2048) {
+        return f4f4bf16_256_256_2_4_1_f(XQ, WQ, x_scale, w_scale, global_scale);
+      } else if (N <= 8192) {
+        return f4f4bf16_128_192_2_2_1_f(XQ, WQ, x_scale, w_scale, global_scale);
+      } else {
+        return f4f4bf16_128_256_2_1_1_f(XQ, WQ, x_scale, w_scale, global_scale);
+      }
+    } else if (M <= 32768) {
+      if (N <= 1024) {
+        return f4f4bf16_256_256_2_1_1_f(XQ, WQ, x_scale, w_scale, global_scale);
+      } else if (N <= 4096) {
+        return f4f4bf16_128_192_2_2_1_f(XQ, WQ, x_scale, w_scale, global_scale);
+      } else {
+        return f4f4bf16_256_192_4_1_1_f(XQ, WQ, x_scale, w_scale, global_scale);
+      }
+    } else if (M <= 65536) {
+      if (N <= 2048) {
+        return f4f4bf16_256_192_2_4_1_f(XQ, WQ, x_scale, w_scale, global_scale);
+      } else if (N <= 4096) {
+        return f4f4bf16_256_192_2_2_1_f(XQ, WQ, x_scale, w_scale, global_scale);
+      } else {
+        return f4f4bf16_256_256_2_1_1_f(XQ, WQ, x_scale, w_scale, global_scale);
+      }
+    } else {
+      if (N <= 1024) {
+        return f4f4bf16_256_192_2_4_1_f(XQ, WQ, x_scale, w_scale, global_scale);
+      } else {
+        return f4f4bf16_256_256_2_2_1_f(XQ, WQ, x_scale, w_scale, global_scale);
+      }
+    }
+  }
+}
+
+at::Tensor f4f4bf16(
+    at::Tensor XQ, // FP4
+    at::Tensor WQ, // FP4
+    at::Tensor x_scale,
+    at::Tensor w_scale,
+    at::Tensor global_scale,
+    bool use_mx = false) {
+  return dispatch_f4f4bf16_kernel(
+      XQ, WQ, x_scale, w_scale, global_scale, use_mx);
+}
+
+#else
+
+at::Tensor f4f4bf16(
+    at::Tensor XQ, // FP4
+    at::Tensor WQ, // FP4
+    at::Tensor x_scale,
+    at::Tensor w_scale,
+    at::Tensor global_scale,
+    bool use_mx = false) {
+  throw std::runtime_error(
+      "CUDA version is older than 12.8"); // requires CUDA>=12.8
+}
+
+#endif
+
+} // namespace fbgemm_gpu
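
Combined with the quantization helper, the new op can be driven from Python roughly as follows. This is a sketch under the assumption that the registered torch.ops.fbgemm.f4f4bf16 schema mirrors the C++ signature above; shapes are illustrative:

    import torch
    from fbgemm_gpu.experimental.gen_ai.quantize import scaled_fp4_quant

    x = torch.randn(2048, 4096, dtype=torch.bfloat16, device="cuda")   # (M, K)
    w = torch.randn(8192, 4096, dtype=torch.bfloat16, device="cuda")   # (N, K)
    x_gs = ((448.0 * 6.0) / torch.amax(x.flatten(), dim=-1)).to(torch.float32)
    w_gs = ((448.0 * 6.0) / torch.amax(w.flatten(), dim=-1)).to(torch.float32)
    global_scale = 1 / (x_gs * w_gs)

    xq, x_scale = scaled_fp4_quant(x, x_gs)
    wq, w_scale = scaled_fp4_quant(w, w_gs)

    # dispatch_f4f4bf16_kernel picks a CUTLASS tile configuration from (M, N);
    # use_mx (default false) would switch to the MXFP4 (_t) instantiations, but
    # scaled_fp4_quant produces NVFP4 data, so the NVFP4 (_f) path applies here.
    out = torch.ops.fbgemm.f4f4bf16(xq, wq, x_scale, w_scale, global_scale)
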
Lines changed: 28 additions & 0 deletions
@@ -0,0 +1,28 @@
+/*
+ * Copyright (c) Meta Platforms, Inc. and affiliates.
+ * All rights reserved.
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+#include "f4f4bf16_common.cuh"
+
+namespace fbgemm_gpu {
+
+#if defined(CUDA_VERSION) && (CUDA_VERSION >= 12080)
+
+at::Tensor f4f4bf16_128_128_4_1_1_f(
+    at::Tensor XQ, // FP4
+    at::Tensor WQ, // FP4
+    at::Tensor x_scale,
+    at::Tensor w_scale,
+    at::Tensor global_scale) {
+  // Dispatch this kernel to the correct underlying implementation.
+  return _f4f4bf16<128, 128, 4, 1, 1, false>(
+      XQ, WQ, x_scale, w_scale, global_scale);
+}
+
+#endif
+
+} // namespace fbgemm_gpu
Lines changed: 28 additions & 0 deletions
@@ -0,0 +1,28 @@
+/*
+ * Copyright (c) Meta Platforms, Inc. and affiliates.
+ * All rights reserved.
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+#include "f4f4bf16_common.cuh"
+
+namespace fbgemm_gpu {
+
+#if defined(CUDA_VERSION) && (CUDA_VERSION >= 12080)
+
+at::Tensor f4f4bf16_128_128_4_1_1_t(
+    at::Tensor XQ, // FP4
+    at::Tensor WQ, // FP4
+    at::Tensor x_scale,
+    at::Tensor w_scale,
+    at::Tensor global_scale) {
+  // Dispatch this kernel to the correct underlying implementation.
+  return _f4f4bf16<128, 128, 4, 1, 1, true>(
+      XQ, WQ, x_scale, w_scale, global_scale);
+}
+
+#endif
+
+} // namespace fbgemm_gpu
