
Commit 679cedc

[xpu/triton] Add triton dequantization kernel
This PR adds an XPU backend and a Triton kernel for dequantization of the nf4 dtype. Triton is an optional import.

Tests:
- tests/test_functional.py::TestQuantize4BitFunctional: supported nf4/fp4 cases
- tests/test_functional.py::Test8BitBlockwiseQuantizeFunctional: implemented quantize_blockwise with a binary search, which works faster on XPU
- tests/test_linear4bit.py

Contains a workaround (WA) for gemv_4bit on XPU: for some reason, directly passing `code` causes errors in several tests, for example `tests/test_functional.py::TestQuantize4BitFunctional::test_gemv_4bit[dim=128-uint8-fp32-fc1-fp4-DQ_True-xpu]`.

Signed-off-by: Dmitrii Makarenko <[email protected]>
1 parent 42bc729 commit 679cedc
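
As a quick way to exercise the new path, here is a minimal smoke-test sketch (not part of the commit), assuming an Intel GPU visible as the `xpu` device and the optional triton package installed. `quantize_blockwise` and `dequantize_blockwise` are the functions added in `bitsandbytes/backends/xpu/ops.py` below; `create_dynamic_map` from `bitsandbytes.functional` is assumed to provide the usual 8-bit blockwise quantization code.

```python
import torch
from bitsandbytes.backends.xpu import ops as xpu_ops
from bitsandbytes.functional import create_dynamic_map

# 8-bit dynamic quantization code (lookup table of 256 values), moved to XPU.
code = create_dynamic_map().to("xpu")

A = torch.randn(4096, device="xpu", dtype=torch.float32)

# Quantize in blocks of 256 values, then reconstruct.
q, absmax = xpu_ops.quantize_blockwise(A, code, blocksize=256)
A_dq = xpu_ops.dequantize_blockwise(q, absmax, code, blocksize=256, dtype=torch.float32)

# Blockwise 8-bit quantization is lossy, but the reconstruction should be close.
print((A - A_dq).abs().max())
```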

File tree

7 files changed: +603 -0 lines changed


bitsandbytes/__init__.py

Lines changed: 3 additions & 0 deletions
```diff
@@ -34,6 +34,9 @@
 if torch.cuda.is_available():
     from .backends.cuda import ops as cuda_ops

+if torch.xpu.is_available():
+    from .backends.xpu import ops as xpu_ops
+

 def _import_backends():
     """
```

bitsandbytes/backends/xpu/__init__.py

Whitespace-only changes.

bitsandbytes/backends/xpu/ops.py

Lines changed: 264 additions & 0 deletions
```python
from collections.abc import Sequence
import warnings

import torch

from ..._ops import register_kernel

try:
    from . import triton_kernels

    triton_available = True
except ImportError:
    triton_available = False


@torch.compile
def quantize_blockwise_torch(A, blocksize, code, absmax, quantized_out):
    n = A.numel()
    rem = n % blocksize
    has_rem = rem > 0
    blocks = n // blocksize + has_rem
    A_reshaped = A.reshape(n)
    A_com = A_reshaped[: n - rem]
    A_com_reshaped = A_com.reshape(n // blocksize, blocksize)  # (1, 64)
    # print("A_com_reshaped: ", A_com_reshaped, " shape: ", A_com_reshaped.shape)
    absmax[: blocks - has_rem] = torch.abs(A_com_reshaped).max(dim=-1)[0]
    divi = A_com_reshaped / absmax[: blocks - has_rem].view(-1, 1)
    # print("A divided: ", divi, " absmax: ", absmax[: blocks - has_rem].view(-1, 1), " shape: ", absmax[: blocks - has_rem].view(-1, 1).shape)
    scaled_A = torch.clamp(A_com_reshaped / absmax[: blocks - has_rem].view(-1, 1), -1, 1)
    # print("A normalized: ", scaled_A)
    scaled_A = scaled_A.reshape(-1)
    if has_rem:
        absmax[-1] = torch.abs(A_reshaped[n - rem :]).max()
        scaled_A_rem = torch.clamp((A_reshaped[n - rem :] / absmax[-1]), -1, 1)
        scaled_A = torch.cat([scaled_A, scaled_A_rem], dim=0)

    diff = torch.abs(scaled_A.unsqueeze(-1) - code.to(scaled_A.device))
    # print("ref diff with code: ", diff.dtype, " shape: ", diff.shape)
    # print("ref shape: ", A.shape)
    quantized_out = torch.argmin(diff, dim=-1).to(torch.uint8).to(scaled_A.device).reshape(A.shape)
    # print("quantized out flat: ", quantized_out.flatten())
    # print("quantized_out: ", quantized_out, " shape: ", quantized_out.shape)

    return quantized_out, absmax


def quantize_blockwise(A: torch.Tensor, code: torch.Tensor, blocksize: int) -> tuple[torch.Tensor, torch.Tensor]:
    torch._check_is_size(blocksize)
    # torch._check(A.dtype == torch.float32, lambda: f"A must be float32 on xpu, got {A.dtype}")

    n = A.numel()
    blocks = -(n // -blocksize)

    absmax = torch.empty((blocks,), device=A.device, dtype=A.dtype)
    # print("absmax size: ", absmax.shape)
    out = torch.empty_like(A.flatten(), dtype=torch.uint8)

    # ref_absmax = torch.empty((blocks,), device=A.device, dtype=A.dtype)
    # ref_out = torch.empty_like(A, dtype=torch.uint8)
    # ref_out, ref_absmax = quantize_blockwise_with_code(A, blocksize, code, ref_absmax, ref_out)
    # out, absmax = quantize_blockwise_torch(A, blocksize, code, absmax, out)

    triton_kernels.quantize_blockwise_triton(A, blocksize, code, blocks, absmax, out)
    out = out.reshape(A.shape)
    return out, absmax


def dequantize_blockwise(
    A: torch.Tensor, absmax: torch.Tensor, code: torch.Tensor, blocksize: int, dtype: torch.dtype
) -> torch.Tensor:
    torch._check_is_size(blocksize)
    torch._check(A.dtype == torch.uint8, lambda: f"A must be uint8, got {A.dtype}")
    # torch._check(dtype == torch.float32, lambda: f"dtype must be float32 on xpu, got {dtype}")

    out = torch.empty_like(A, dtype=dtype, device=A.device)
    triton_kernels.dequant_int8_fp16(
        A,
        code,
        absmax,
        out,
        blocksize,
    )

    return out


_NF4_QUANT_TABLE = torch.tensor(
    [
        -1.0,
        -0.6961928009986877,
        -0.5250730514526367,
        -0.39491748809814453,
        -0.28444138169288635,
        -0.18477343022823334,
        -0.09105003625154495,
        0.0,
        0.07958029955625534,
        0.16093020141124725,
        0.24611230194568634,
        0.33791524171829224,
        0.44070982933044434,
        0.5626170039176941,
        0.7229568362236023,
        1.0,
    ],
    dtype=torch.float32,
    device="xpu",
)

_FP4_QUANT_TABLE = torch.tensor(
    [
        0.0000,
        0.0052,
        0.6667,
        1.0000,
        0.3333,
        0.5000,
        0.1667,
        0.2500,
        0.0000,
        -0.0052,
        -0.6667,
        -1.0000,
        -0.3333,
        -0.5000,
        -0.1667,
        -0.2500,
    ],
    dtype=torch.float32,
    device="xpu",
)


@register_kernel("bitsandbytes::quantize_4bit", "xpu")
def _(
    A: torch.Tensor, blocksize: int, quant_type: str, quant_storage: torch.dtype
) -> tuple[torch.Tensor, torch.Tensor]:
    torch._check_is_size(blocksize)
    # torch._check(quant_type == "nf4", lambda: f"quant_type must be nf4 on CPU, got {quant_type}")
    torch._check(
        A.dtype in [torch.bfloat16, torch.float16, torch.float32],
        lambda: f"Blockwise 4bit quantization only supports 16/32-bit floats, but got {A.dtype}",
    )

    n = A.numel()

    # TODO: Support when weight matrix is not divisible by blocksize
    torch._check(n % blocksize == 0, lambda: f"n must be divisible by blocksize, got {n} and {blocksize}")

    # Divide into blocks and normalize
    blocks = A.reshape(-1, blocksize)
    absmax = blocks.abs().max(dim=1).values.float()
    scaled = blocks / absmax.unsqueeze(-1)

    # Quantize with the lookup table
    if quant_type == "fp4":
        quantized = torch.argmin(torch.abs(scaled.view(-1, 1) - _FP4_QUANT_TABLE), dim=-1, keepdim=True).to(
            torch.uint8
        )
    else:
        quantized = torch.argmin(torch.abs(scaled.view(-1, 1) - _NF4_QUANT_TABLE), dim=-1, keepdim=True).to(
            torch.uint8
        )

    # Pack two quantized values per byte
    packed = quantized[::2] << 4 | quantized[1::2]

    if quant_storage != torch.uint8:
        packed = packed.squeeze().view(quant_storage).unsqueeze(1)

    return packed, absmax.float()


def dequantize_4bit(
    A: torch.Tensor,
    absmax: torch.Tensor,
    blocksize: int,
    quant_type: str,
    shape: Sequence[int],
    dtype: torch.dtype,
) -> torch.Tensor:
    torch._check_is_size(blocksize)
    # torch._check(quant_type == "nf4", lambda: f"quant_type must be nf4 on XPU, got {quant_type}")
    torch._check(
        dtype in [torch.bfloat16, torch.float16, torch.float32],
        lambda: f"Blockwise 4bit dequantization only supports 16/32-bit floats, but got {dtype}",
    )
    # torch._check(
    #     A.dtype == torch.uint8,
    #     lambda: f"Blockwise 4bit dequantization on XPU only supports uint8 storage, got {A.dtype}",
    # )
    # Check if this is fine and fast
    if A.dtype != torch.uint8:
        A = A.squeeze().view(torch.uint8).unsqueeze(1)

    out = torch.empty(shape, dtype=dtype, device=A.device)

    triton_kernels._dequantize_4bit_impl(A, absmax, blocksize, quant_type, dtype, out=out)
    return out


def dequantize_4bit_inplace(
    A: torch.Tensor,
    absmax: torch.Tensor,
    blocksize: int,
    quant_type: str,
    shape: Sequence[int],
    dtype: torch.dtype,
    out: torch.Tensor,
) -> None:
    torch._check(out.shape == shape, lambda: f"Expected out.shape == {shape}, got {out.shape}")
    torch._check(out.dtype == dtype, lambda: f"Expected out.dtype == {dtype}, got {out.dtype}")
    triton_kernels._dequantize_4bit_impl(A, absmax, blocksize, quant_type, dtype, out=out)


# tests/test_functional.py::TestQuantize4BitFunctional::test_gemv_4bit[dim=128-uint8-fp32-fc1-fp4-DQ_True-xpu]
def gemv_4bit(
    A: torch.Tensor,
    B: torch.Tensor,
    shapeB: Sequence[int],
    absmax: torch.Tensor,
    code: torch.Tensor,
    blocksize: int,
) -> torch.Tensor:
    # TODO: We need to determine whether `code` is NF4, FP4, or other.
    # Right now we assume NF4, as this is the only one supported on CPU.
    quant_type = "fp4" if code[1] > 0 else "nf4"
    B_dq = dequantize_4bit(B, absmax, blocksize, quant_type, shapeB, A.dtype)

    # For some reason directly passing code causes errors in some cases like:
    # tests/test_functional.py::TestQuantize4BitFunctional::test_gemv_4bit[dim=128-uint8-fp32-fc1-fp4-DQ_True-xpu]
    #
    # B_dq = torch.empty(shapeB, dtype=A.dtype, device=A.device)
    # if B.dtype != torch.uint8:
    #     B = B.squeeze().view(torch.uint8).unsqueeze(1)

    # triton_kernels._dequantize_4bit_impl_passing_code(
    #     B,
    #     absmax,
    #     blocksize,
    #     code,
    #     dtype=A.dtype,
    #     out=B_dq,
    # )

    # User called gemv with B.t(), so we need to transpose it back.
    # if B.shape[0] == 1:
    #     B_dq = B_dq.t()

    return torch.nn.functional.linear(
        A,
        B_dq,
        bias=None,
    )


if triton_available:
    register_kernel("bitsandbytes::quantize_blockwise", "xpu")(quantize_blockwise)
    register_kernel("bitsandbytes::dequantize_blockwise", "xpu")(dequantize_blockwise)
    register_kernel("bitsandbytes::dequantize_4bit.out", "xpu")(dequantize_4bit_inplace)
    register_kernel("bitsandbytes::dequantize_4bit", "xpu")(dequantize_4bit)
    register_kernel("bitsandbytes::gemv_4bit", "xpu")(gemv_4bit)
else:
    warnings.warn("XPU available, but triton package is missing.")
```
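
For reference, the `quantize_4bit` kernel above packs two 4-bit code indices into each `uint8` (high nibble first) via `quantized[::2] << 4 | quantized[1::2]`. Below is a minimal sketch of that packing and the corresponding unpacking; it is purely illustrative and not part of the commit (the real unpacking happens inside `triton_kernels._dequantize_4bit_impl`).

```python
import torch

idx = torch.tensor([3, 15, 0, 9], dtype=torch.uint8)  # 4-bit codebook indices
packed = idx[::2] << 4 | idx[1::2]                     # tensor([63, 9], dtype=torch.uint8)

high = packed >> 4    # tensor([3, 0])  -> indices from even positions
low = packed & 0x0F   # tensor([15, 9]) -> indices from odd positions
assert torch.equal(torch.stack([high, low], dim=1).flatten(), idx)
```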
