Commit a23026c
[Triton/XPU] Support 4bit dequantization logic on Triton (#1629)
* [xpu/triton] Add Triton dequantization kernel

  This PR adds an XPU backend and a Triton kernel for dequantizing the nf4 dtype. Triton is an optional import.

  Tests:
  - tests/test_functional.py::TestQuantize4BitFunctional — supported nf4/fp4 cases
  - tests/test_functional.py::Test8BitBlockwiseQuantizeFunctional — implemented quantize_blockwise with a binary search that runs faster on XPU
  - tests/test_linear4bit.py

  Signed-off-by: Dmitrii Makarenko <[email protected]>

* align with ipex code
* enable test for ipex
* test_kbit_backprop: skip no longer needed
* remove unused

---------

Signed-off-by: Dmitrii Makarenko <[email protected]>
1 parent d9333aa commit a23026c
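
For orientation, a minimal sketch of the path this commit enables. It assumes an XPU device, a working Triton install, and the public bitsandbytes API (bitsandbytes.functional.quantize_4bit / dequantize_4bit); shapes and blocksize are illustrative, not part of this diff.

    import torch
    import bitsandbytes.functional as F

    W = torch.randn(1024, 1024, dtype=torch.float16, device="xpu")
    packed, state = F.quantize_4bit(W, blocksize=64, quant_type="nf4")
    W_dq = F.dequantize_4bit(packed, state)  # dispatched to the Triton kernel on XPU
    print((W - W_dq).abs().mean())           # small reconstruction error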

File tree

8 files changed: +913 −8 lines changed

bitsandbytes/backends/triton/__init__.py

Whitespace-only changes.

bitsandbytes/backends/triton/ops.py

Lines changed: 166 additions & 0 deletions
@@ -0,0 +1,166 @@
from collections.abc import Sequence

import torch

from . import triton_kernels

# Currently unused code, kept for reference.
# Should be the same for quant/dequant
# from bitsandbytes.functional import get_4bit_type
# _FP4_QUANT_TABLE = get_4bit_type("fp4", device="xpu")
# _NF4_QUANT_TABLE = get_4bit_type("nf4", device="xpu")

def quantize_blockwise(A: torch.Tensor, code: torch.Tensor, blocksize: int) -> tuple[torch.Tensor, torch.Tensor]:
    torch._check_is_size(blocksize)
    # torch._check(A.dtype == torch.float32, lambda: f"A must be float32 on xpu, got {A.dtype}")

    n = A.numel()
    blocks = -(n // -blocksize)  # ceil(n / blocksize) via integer arithmetic

    absmax = torch.empty((blocks,), device=A.device, dtype=A.dtype)
    out = torch.empty_like(A.flatten(), dtype=torch.uint8)

    triton_kernels.quantize_blockwise_triton(A, blocksize, code, blocks, absmax, out)
    out = out.reshape(A.shape)

    return out, absmax.float()
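
For context, a hedged usage sketch of the blockwise int8 pair (this function and dequantize_blockwise below), calling the ops defined in this file. It assumes an XPU device with Triton available; create_dynamic_map is the existing bitsandbytes helper that builds the 256-entry int8 codebook.

    import torch
    from bitsandbytes.functional import create_dynamic_map

    code = create_dynamic_map().to("xpu")  # 256-entry quantization codebook
    A = torch.randn(4096, device="xpu")
    q, absmax = quantize_blockwise(A, code, blocksize=256)
    A_dq = dequantize_blockwise(q, absmax, code, blocksize=256, dtype=torch.float32)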

def dequantize_blockwise(
    A: torch.Tensor, absmax: torch.Tensor, code: torch.Tensor, blocksize: int, dtype: torch.dtype
) -> torch.Tensor:
    torch._check_is_size(blocksize)
    torch._check(A.dtype == torch.uint8, lambda: f"A must be uint8, got {A.dtype}")
    # torch._check(dtype == torch.float32, lambda: f"dtype must be float32 on xpu, got {dtype}")

    out = torch.empty_like(A, dtype=dtype, device=A.device)
    triton_kernels.dequant_int8_blockwise(
        A,
        code,
        absmax,
        out,
        blocksize,
    )

    return out

def dequantize_blockwise_inplace(
    A: torch.Tensor, absmax: torch.Tensor, code: torch.Tensor, blocksize: int, dtype: torch.dtype, out: torch.Tensor
) -> None:
    torch._check_is_size(blocksize)
    torch._check(A.dtype == torch.uint8, lambda: f"A must be uint8, got {A.dtype}")
    torch._check(out.shape == A.shape, lambda: f"Expected out.shape == {A.shape}, got {out.shape}")
    torch._check(out.device == A.device, lambda: f"Expected out.device == {A.device}, got {out.device}")
    torch._check(out.dtype == dtype, lambda: f"Expected out.dtype == {dtype}, got {out.dtype}")

    triton_kernels.dequant_int8_blockwise(
        A,
        code,
        absmax,
        out,
        blocksize,
    )
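
The in-place variant only adds validation of the caller-supplied out tensor. Continuing the earlier blockwise sketch (same assumed XPU/Triton setup and variables):

    out = torch.empty_like(A, dtype=torch.float32)
    dequantize_blockwise_inplace(q, absmax, code, blocksize=256, dtype=torch.float32, out=out)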

def quantize_4bit(
    A: torch.Tensor, blocksize: int, quant_type: str, quant_storage: torch.dtype
) -> tuple[torch.Tensor, torch.Tensor]:
    torch._check_is_size(blocksize)
    # torch._check(quant_type == "nf4", lambda: f"quant_type must be nf4 on CPU, got {quant_type}")
    torch._check(
        A.dtype in [torch.bfloat16, torch.float16, torch.float32],
        lambda: f"Blockwise 4bit quantization only supports 16/32-bit floats, but got {A.dtype}",
    )

    n = A.numel()

    # TODO: Support when weight matrix is not divisible by blocksize
    # torch._check(n % blocksize == 0, lambda: f"n must be divisible by blocksize, got {n} and {blocksize}")

    blocks = -(n // -(blocksize * 2))  # ceil(n / (blocksize * 2)): kernel grid size

    absmax = torch.empty((blocks * 2,), device=A.device, dtype=A.dtype)  # one absmax per blocksize elements
    out = torch.empty((n // 2, 1), device=A.device, dtype=torch.uint8)  # two 4-bit codes packed per byte

    triton_kernels.quantize_4bit_blockwise_triton(
        A, blocksize, quant_type, blocks, absmax, num_elements=n, quantized_out=out
    )
    packed = out

    if quant_storage != torch.uint8:
        packed = out.squeeze().view(quant_storage).unsqueeze(1)

    return packed, absmax.float()
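
A quick worked check of the packing arithmetic above (plain integer math, runnable anywhere; the numbers are illustrative):

    n, blocksize = 4096, 64
    blocks = -(n // -(blocksize * 2))
    print(blocks)      # 32 kernel instances, each covering two blocks
    print(blocks * 2)  # 64 absmax entries, one per 64-element block
    print(n // 2)      # 2048 output bytes, two 4-bit values per byte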

def dequantize_4bit(
    A: torch.Tensor,
    absmax: torch.Tensor,
    blocksize: int,
    quant_type: str,
    shape: Sequence[int],
    dtype: torch.dtype,
) -> torch.Tensor:
    torch._check_is_size(blocksize)
    # torch._check(quant_type == "nf4", lambda: f"quant_type must be nf4 on XPU, got {quant_type}")
    torch._check(
        dtype in [torch.bfloat16, torch.float16, torch.float32],
        lambda: f"Blockwise 4bit dequantization only supports 16/32-bit floats, but got {dtype}",
    )
    # torch._check(
    #     A.dtype == torch.uint8,
    #     lambda: f"Blockwise 4bit dequantization on XPU only supports uint8 storage, got {A.dtype}",
    # )
    # Check if this is fine and fast
    if A.dtype != torch.uint8:
        # reinterpret a non-uint8 quant_storage tensor back to packed bytes
        A = A.squeeze().view(torch.uint8).unsqueeze(1)

    out = torch.empty(shape, dtype=dtype, device=A.device)

    triton_kernels._dequantize_4bit_impl(A, absmax, blocksize, quant_type, dtype, out=out)
    return out
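
A hedged illustration of the quant_storage round-trip between quantize_4bit and dequantize_4bit (pure PyTorch, runnable on CPU; float16 stands in for a non-uint8 quant_storage):

    import torch

    packed_u8 = torch.randint(0, 256, (2048, 1), dtype=torch.uint8)
    as_f16 = packed_u8.squeeze().view(torch.float16).unsqueeze(1)  # what quantize_4bit returns
    back = as_f16.squeeze().view(torch.uint8).unsqueeze(1)         # what dequantize_4bit undoes
    assert torch.equal(packed_u8, back)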

def dequantize_4bit_inplace(
    A: torch.Tensor,
    absmax: torch.Tensor,
    blocksize: int,
    quant_type: str,
    shape: Sequence[int],
    dtype: torch.dtype,
    out: torch.Tensor,
) -> None:
    torch._check(out.shape == shape, lambda: f"Expected out.shape == {shape}, got {out.shape}")
    torch._check(out.dtype == dtype, lambda: f"Expected out.dtype == {dtype}, got {out.dtype}")
    triton_kernels._dequantize_4bit_impl(A, absmax, blocksize, quant_type, dtype, out=out)

def gemv_4bit(
    A: torch.Tensor,
    B: torch.Tensor,
    shapeB: Sequence[int],
    absmax: torch.Tensor,
    code: torch.Tensor,
    blocksize: int,
) -> torch.Tensor:
    if B.dtype != torch.uint8:
        B = B.squeeze().view(torch.uint8).unsqueeze(1)

    B_dq_triton = torch.empty(shapeB, dtype=A.dtype, device=A.device)

    triton_kernels._dequantize_4bit_impl_passing_code(
        B,
        absmax,
        blocksize,
        code,
        dtype=A.dtype,
        out=B_dq_triton,
    )

    return torch.nn.functional.linear(
        A,
        B_dq_triton,
        bias=None,
    )
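
Note that gemv_4bit here is not a fused kernel: it dequantizes B into a dense temporary of A's dtype and falls back to torch.nn.functional.linear, which computes A @ B_dq.T given shapeB == (out_features, in_features). That favors correctness and simplicity over the memory savings a fused 4-bit GEMV would offer, since the full B matrix is materialized on each call.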
