Commit 1a4c8f9

Add CUTLASS-based W4A4 (#1515)
* add w4a4
* add test
* hook up to AQT
* fix quant api test
* fix test
* make threadblockswizzle a template param
* re-use s8s4 cutlass template
* add Alex's patch and some changes
* fix aqt test
* remove int4_cutlass.cu
* apply alex's patch
* update benchmark script
* ruff
* add some tuning
* reduce num_stages to fit shared memory of small GPUs (<100kb)
* replace torch timer with triton do_bench
* ruff
* use ZeroPointDomain.NONE
* fix 3.7 typing
* merge Aleksandar changes
* run ruff
* try replace torch/extension.h with torch/library.h
* (alexsamardzic) improve error handling
* ruff format
* add note on cutlass naming
1 parent b2fb664 commit 1a4c8f9
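For context, here is a minimal sketch of how the W4A4 path added by this commit would be reached through the quantization API. This is an illustration, not part of the commit: it assumes a CUDA build of torchao with the CUTLASS kernels compiled in, and uses `int4_dynamic_activation_int4_weight` (exported from `torchao.quantization` per the test changes below) together with torchao's `quantize_` entry point.

```python
import torch

from torchao.quantization import int4_dynamic_activation_int4_weight, quantize_

# A small half-precision model on CUDA; the W4A4 CUTLASS kernel targets
# linear layers with int4 activations and int4 weights.
model = torch.nn.Sequential(torch.nn.Linear(1024, 1024)).half().cuda()

# Rewrites the linear weights as affine-quantized tensors (AQT) so that
# forward passes can dispatch to the CUTLASS-based W4A4 kernel.
quantize_(model, int4_dynamic_activation_int4_weight())
```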

17 files changed: +734 −444 lines
New file: benchmark script for the rowwise-scaled CUTLASS kernels (+70)

```python
import pandas as pd
import torch
from tqdm import tqdm
from triton.testing import do_bench

from torchao.ops import (
    rowwise_scaled_linear_cutlass_s4s4,
    rowwise_scaled_linear_cutlass_s8s4,
)


def benchmark_microseconds(f, *args):
    # do_bench reports milliseconds; scale the median to microseconds.
    return do_bench(lambda: f(*args), return_mode="median") * 1e3


def get_problem(m: int, n: int, k: int, A_nbits: int, B_nbits: int):
    assert A_nbits in (4, 8) and B_nbits in (4, 8)

    dev = torch.device("cuda")
    A = torch.randint(-128, 127, (m, k * A_nbits // 8), dtype=torch.int8, device=dev)
    A_scale = torch.randn((m,), dtype=torch.half, device=dev)
    B = torch.randint(
        -128, 127, size=(n, k * B_nbits // 8), dtype=torch.int8, device=dev
    )
    B_scale = torch.randn((n,), dtype=torch.half, device=dev)
    C = None

    return A, A_scale, B, B_scale, C


def benchmark(m: int, k: int, n: int):
    dev = torch.device("cuda")
    A_ref = torch.randn((m, k), dtype=torch.half, device=dev)
    B_ref = torch.randn((n, k), dtype=torch.half, device=dev)
    fp16_time = benchmark_microseconds(torch.nn.functional.linear, A_ref, B_ref)

    A, A_scale, B, B_scale, C = get_problem(m, n, k, 8, 4)
    rowwise_scaled_linear_cutlass_s8s4_time = benchmark_microseconds(
        rowwise_scaled_linear_cutlass_s8s4, A, A_scale, B, B_scale, C
    )

    A, A_scale, B, B_scale, C = get_problem(m, n, k, 4, 4)
    rowwise_scaled_linear_cutlass_s4s4_time = benchmark_microseconds(
        rowwise_scaled_linear_cutlass_s4s4, A, A_scale, B, B_scale, C
    )

    return {
        "m": m,
        "k": k,
        "n": n,
        "fp16_latency (us)": fp16_time,
        "rowwise_scaled_linear_cutlass_s8s4 latency (us)": rowwise_scaled_linear_cutlass_s8s4_time,
        "s8s4 speedup (d/s)": fp16_time / rowwise_scaled_linear_cutlass_s8s4_time,
        "rowwise_scaled_linear_cutlass_s4s4 latency (us)": rowwise_scaled_linear_cutlass_s4s4_time,
        "s4s4 speedup (d/s)": fp16_time / rowwise_scaled_linear_cutlass_s4s4_time,
    }


if __name__ == "__main__":
    k_vals = (8192, 8192, 8192, 28672)
    n_vals = (8192, 10240, 57344, 8192)

    results = []
    for m in tqdm([1 << i for i in range(10)]):
        for n, k in zip(n_vals, k_vals):
            results.append(benchmark(m, k, n))

    df = pd.DataFrame(results)
    df.to_csv("rowwise_scaled_linear_cutlass_time_results.csv", index=False)
    print(df.to_markdown(index=False))
```
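As a usage note, the `benchmark` entry point above can also be invoked for a single shape; this short example (not part of the commit) assumes a CUDA machine with torchao's CUTLASS kernels and triton available:

```python
# Benchmark one (m, k, n) problem size and inspect the dense-vs-quantized
# speedup columns computed by benchmark() above.
row = benchmark(m=16, k=8192, n=8192)
print(row["s8s4 speedup (d/s)"], row["s4s4 speedup (d/s)"])
```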

benchmarks/benchmark_s8s4_cutlass.py (−52): this file was deleted.

setup.py (+24 −12)

```diff
@@ -240,30 +240,42 @@ def get_extensions():
         extra_compile_args["nvcc"].append("-g")
         extra_link_args.append("/DEBUG")
 
+    this_dir = os.path.dirname(os.path.curdir)
+    extensions_dir = os.path.join(this_dir, "torchao", "csrc")
+    sources = list(glob.glob(os.path.join(extensions_dir, "**/*.cpp"), recursive=True))
+
+    extensions_cuda_dir = os.path.join(extensions_dir, "cuda")
+    cuda_sources = list(
+        glob.glob(os.path.join(extensions_cuda_dir, "**/*.cu"), recursive=True)
+    )
+
+    if use_cuda:
+        sources += cuda_sources
+
     use_cutlass = False
     if use_cuda and not IS_WINDOWS:
         use_cutlass = True
         cutlass_dir = os.path.join(third_party_path, "cutlass")
         cutlass_include_dir = os.path.join(cutlass_dir, "include")
+        cutlass_extensions_include_dir = os.path.join(cwd, extensions_cuda_dir)
     if use_cutlass:
         extra_compile_args["nvcc"].extend(
             [
                 "-DTORCHAO_USE_CUTLASS",
                 "-I" + cutlass_include_dir,
+                "-I" + cutlass_extensions_include_dir,
             ]
         )
-
-    this_dir = os.path.dirname(os.path.curdir)
-    extensions_dir = os.path.join(this_dir, "torchao", "csrc")
-    sources = list(glob.glob(os.path.join(extensions_dir, "**/*.cpp"), recursive=True))
-
-    extensions_cuda_dir = os.path.join(extensions_dir, "cuda")
-    cuda_sources = list(
-        glob.glob(os.path.join(extensions_cuda_dir, "**/*.cu"), recursive=True)
-    )
-
-    if use_cuda:
-        sources += cuda_sources
+    else:
+        # Remove CUTLASS-based kernels from the cuda_sources list. The
+        # assumption is that these files will have "cutlass" in their
+        # names.
+        cutlass_sources = list(
+            glob.glob(
+                os.path.join(extensions_cuda_dir, "**/*cutlass*.cu"), recursive=True
+            )
+        )
+        sources = [s for s in sources if s not in cutlass_sources]
 
     ext_modules = []
     if len(sources) > 0:
```

test/dtypes/test_affine_quantized.py (+2)

```diff
@@ -11,6 +11,7 @@
 from torchao.dtypes import CutlassInt4PackedLayout, Int4CPULayout, SemiSparseLayout
 from torchao.quantization import (
     float8_weight_only,
+    int4_dynamic_activation_int4_weight,
     int4_weight_only,
     int8_dynamic_activation_int4_weight,
     int8_dynamic_activation_int8_weight,
@@ -61,6 +62,7 @@ def get_quantization_functions(
             layout=CutlassInt4PackedLayout(),
         )
     )
+    base_functions.append(int4_dynamic_activation_int4_weight())
 
     if do_sparse:
         base_functions.append(
```
New file: tests for the rowwise-scaled CUTLASS ops (+104)

```python
import itertools

import pytest
import torch

from torchao.ops import (
    rowwise_scaled_linear_cutlass_s4s4,
    rowwise_scaled_linear_cutlass_s8s4,
)
from torchao.quantization.utils import group_quantize_tensor_symmetric

ROWWISE_SCALED_LINEAR_CUTLASS_DTYPE = [torch.float16, torch.bfloat16]
ROWWISE_SCALED_LINEAR_CUTLASS_BATCH_SIZE = [1, 4, 8, 16, 32, 64]
ROWWISE_SCALED_LINEAR_CUTLASS_SIZE_MNK = [
    (2, 512, 128),
    (3, 2048, 2048),
    (4, 3584, 640),
    (13, 8704, 8576),
    (26, 18944, 1664),
    (67, 6656, 1408),
]
ROWWISE_SCALED_LINEAR_CUTLASS_USE_BIAS = [False, True]
ROWWISE_SCALED_LINEAR_CUTLASS_TEST_PARAMS = list(
    itertools.product(
        ROWWISE_SCALED_LINEAR_CUTLASS_DTYPE,
        ROWWISE_SCALED_LINEAR_CUTLASS_BATCH_SIZE,
        ROWWISE_SCALED_LINEAR_CUTLASS_SIZE_MNK,
        ROWWISE_SCALED_LINEAR_CUTLASS_USE_BIAS,
    )
)


def run_test_for_op(op, xq_bits, wq_bits, dtype, batch_size, size_mnk, use_bias):
    assert xq_bits in [4, 8]
    assert wq_bits in [4, 8]

    size_m, size_n, size_k = size_mnk

    x = torch.randn((batch_size, size_m, size_k), dtype=dtype, device="cuda")
    w = torch.rand((size_n, size_k), dtype=dtype, device="cuda")
    bias = torch.rand((size_n,), dtype=dtype, device="cuda") if use_bias else None

    x_2d = x.view(-1, x.shape[-1])
    xq_2d_s8, xq_2d_scales, xq_2d_zeros = group_quantize_tensor_symmetric(
        x_2d, xq_bits, size_k, dtype
    )
    assert torch.all(xq_2d_zeros == 0)
    xq_s8 = xq_2d_s8.reshape(x.shape)
    if xq_bits == 4:
        xq = (xq_s8[..., 1::2] << 4) | (xq_s8[..., 0::2] & 0xF)
    else:
        xq = xq_s8
    xq_scales = xq_2d_scales.reshape(x.shape[:-1])

    # Quantize w rowwise: with group size size_k, each of the size_n rows gets
    # one scale, matching the wq_scales.view(1, size_n) below.
    wq_s8, wq_scales, wq_zeros = group_quantize_tensor_symmetric(
        w, wq_bits, size_k, dtype
    )
    assert torch.all(wq_zeros == 0)
    if wq_bits == 4:
        wq = (wq_s8[:, 1::2] << 4) | (wq_s8[:, 0::2] & 0xF)
    else:
        wq = wq_s8

    # If torch.nn.functional.linear(x, w, bias) were used as the reference,
    # the error would be too big. The calculation below is approximately what
    # the rowwise_scaled_linear_cutlass kernel is doing (except that the
    # matrix multiplication there is performed over integers).
    size_m_2d = x_2d.shape[0]
    output_ref = (
        (xq_2d_s8.float() @ wq_s8.float().T)
        * xq_2d_scales.view(size_m_2d, 1)
        * wq_scales.view(1, size_n)
    )
    if bias is not None:
        output_ref += bias
    output_ref = output_ref.to(dtype).reshape(x.shape[:-1] + (size_n,))

    fn_inputs = (xq, xq_scales, wq, wq_scales, bias)
    try:
        output = op(*fn_inputs)
    except NotImplementedError:
        pytest.xfail("operator not implemented")

    torch.testing.assert_close(output, output_ref)


@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
@pytest.mark.parametrize(
    "dtype, batch_size, size_mnk, use_bias", ROWWISE_SCALED_LINEAR_CUTLASS_TEST_PARAMS
)
def test_rowwise_scaled_linear_cutlass_s4s4(dtype, batch_size, size_mnk, use_bias):
    run_test_for_op(
        rowwise_scaled_linear_cutlass_s4s4, 4, 4, dtype, batch_size, size_mnk, use_bias
    )


@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
@pytest.mark.parametrize(
    "dtype, batch_size, size_mnk, use_bias", ROWWISE_SCALED_LINEAR_CUTLASS_TEST_PARAMS
)
def test_rowwise_scaled_linear_cutlass_s8s4(dtype, batch_size, size_mnk, use_bias):
    run_test_for_op(
        rowwise_scaled_linear_cutlass_s8s4, 8, 4, dtype, batch_size, size_mnk, use_bias
    )
```
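The nibble-packing convention exercised by these tests is easy to get wrong, so here is a small self-contained sketch (not from the commit): two signed 4-bit values go into one int8 byte, with the even-indexed element in the low nibble and the odd-indexed one in the high nibble, and arithmetic shifts recover them with sign extension.

```python
import torch

# Pack two int4 values per int8 byte, mirroring the tests above:
# odd-indexed element in the high nibble, even-indexed in the low nibble.
v = torch.tensor([[-8, 7, 3, -1]], dtype=torch.int8)
packed = (v[..., 1::2] << 4) | (v[..., 0::2] & 0xF)

# Unpack with arithmetic shifts so the int4 values are sign-extended.
low = (packed << 4) >> 4   # low nibble: shift up, then back down
high = packed >> 4         # high nibble: plain arithmetic right shift
unpacked = torch.stack([low, high], dim=-1).reshape(v.shape)
assert torch.equal(unpacked, v)
```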

test/test_s8s4_linear_cutlass.py (−77): this file was deleted.

torchao/csrc/README.md (+2 −1)

```diff
@@ -8,7 +8,6 @@ The goal is that you can focus on just writing your custom CUDA or C++ kernel an
 
 To learn more about custom ops in PyTorch you can refer to the [PyTorch Custom Operators Landing Page](https://pytorch.org/tutorials/advanced/custom_ops_landing_page.html)
 
-
 ## How to add your own kernel in ao
 
 We've integrated several kernels which you can use as a template for your own kernels. `tensor_core_tiled_layout` is the most straight-forward to get started with.
@@ -23,6 +22,8 @@ And that's it! Once CI passes and your code merged you'll be able to point peopl
 
 If you'd like to learn more please check out [torch.library](https://pytorch.org/docs/main/library.html)
 
+Note: All CUTLASS-based kernels should have `cutlass` in the name of their `.cu` files, e.g. `rowwise_scaled_linear_cutlass_s4s4.cu`.
+
 ## Required dependencies
 
 The important dependencies are already taken care of in our CI so feel free to test in CI directly
```
