
Commit 1414628

[xpu/triton] Add triton dequantization kernel
This PR adds an XPU backend and a Triton kernel for dequantization of the nf4 dtype. Triton is an optional import.

Tests:
- tests/test_functional.py::TestQuantize4BitFunctional (supported nf4/fp4 cases)
- tests/test_functional.py::Test8BitBlockwiseQuantizeFunctional (quantize_blockwise implemented with a binary search that runs faster on XPU)
- tests/test_linear4bit.py

Signed-off-by: Dmitrii Makarenko <[email protected]>
1 parent 1d4ea6a commit 1414628


7 files changed, +912 −0 lines changed


bitsandbytes/__init__.py

Lines changed: 3 additions & 0 deletions
@@ -34,6 +34,9 @@
 if torch.cuda.is_available():
     from .backends.cuda import ops as cuda_ops
 
+if torch.xpu.is_available():
+    from .backends.xpu import ops as xpu_ops
+
 
 def _import_backends():
     """

bitsandbytes/backends/triton/__init__.py

Whitespace-only changes.

bitsandbytes/backends/triton/ops.py

Lines changed: 172 additions & 0 deletions
@@ -0,0 +1,172 @@
from collections.abc import Sequence

import torch

# currently codes unused, kept for reference
# Should be the same for quant/dequant
# from bitsandbytes.functional import get_4bit_type
# _FP4_QUANT_TABLE = get_4bit_type("fp4", device="xpu")
# _NF4_QUANT_TABLE = get_4bit_type("nf4", device="xpu")

try:
    from . import triton_kernels

    triton_available = True
except ImportError as e:
    print("Import error:", e)
    triton_available = False


def quantize_blockwise(A: torch.Tensor, code: torch.Tensor, blocksize: int) -> tuple[torch.Tensor, torch.Tensor]:
    torch._check_is_size(blocksize)
    # torch._check(A.dtype == torch.float32, lambda: f"A must be float32 on xpu, got {A.dtype}")

    n = A.numel()
    blocks = -(n // -blocksize)

    absmax = torch.empty((blocks,), device=A.device, dtype=A.dtype)
    out = torch.empty_like(A.flatten(), dtype=torch.uint8)

    triton_kernels.quantize_blockwise_triton(A, blocksize, code, blocks, absmax, out)
    out = out.reshape(A.shape)

    return out, absmax.float()


def dequantize_blockwise(
    A: torch.Tensor, absmax: torch.Tensor, code: torch.Tensor, blocksize: int, dtype: torch.dtype
) -> torch.Tensor:
    torch._check_is_size(blocksize)
    torch._check(A.dtype == torch.uint8, lambda: f"A must be uint8, got {A.dtype}")
    # torch._check(dtype == torch.float32, lambda: f"dtype must be float32 on xpu, got {dtype}")

    out = torch.empty_like(A, dtype=dtype, device=A.device)
    triton_kernels.dequant_int8_blockwise(
        A,
        code,
        absmax,
        out,
        blocksize,
    )

    return out


def dequantize_blockwise_inplace(
    A: torch.Tensor, absmax: torch.Tensor, code: torch.Tensor, blocksize: int, dtype: torch.dtype, out: torch.Tensor
) -> None:
    torch._check_is_size(blocksize)
    torch._check(A.dtype == torch.uint8, lambda: f"A must be uint8, got {A.dtype}")
    torch._check(out.shape == A.shape, lambda: f"Expected out.shape == {A.shape}, got {out.shape}")
    torch._check(out.device == A.device, lambda: f"Expected out.device == {A.device}, got {out.device}")
    torch._check(out.dtype == dtype, lambda: f"Expected out.dtype == {dtype}, got {out.dtype}")

    triton_kernels.dequant_int8_blockwise(
        A,
        code,
        absmax,
        out,
        blocksize,
    )


def quantize_4bit(
    A: torch.Tensor, blocksize: int, quant_type: str, quant_storage: torch.dtype
) -> tuple[torch.Tensor, torch.Tensor]:
    torch._check_is_size(blocksize)
    # torch._check(quant_type == "nf4", lambda: f"quant_type must be nf4 on CPU, got {quant_type}")
    torch._check(
        A.dtype in [torch.bfloat16, torch.float16, torch.float32],
        lambda: f"Blockwise 4bit quantization only supports 16/32-bit floats, but got {A.dtype}",
    )

    n = A.numel()

    # TODO: Support when weight matrix is not divisible by blocksize
    torch._check(n % blocksize == 0, lambda: f"n must be divisible by blocksize, got {n} and {blocksize}")

    blocks = -(n // -(blocksize * 2))

    absmax = torch.empty((blocks * 2,), device=A.device, dtype=A.dtype)
    out = torch.empty((n // 2, 1), device=A.device, dtype=torch.uint8)

    triton_kernels.quantize_4bit_blockwise_triton(
        A, blocksize, quant_type, blocks, absmax, num_elements=n, quantized_out=out
    )
    packed = out

    if quant_storage != torch.uint8:
        packed = out.squeeze().view(quant_storage).unsqueeze(1)

    return packed, absmax.float()


def dequantize_4bit(
    A: torch.Tensor,
    absmax: torch.Tensor,
    blocksize: int,
    quant_type: str,
    shape: Sequence[int],
    dtype: torch.dtype,
) -> torch.Tensor:
    torch._check_is_size(blocksize)
    # torch._check(quant_type == "nf4", lambda: f"quant_type must be nf4 on XPU, got {quant_type}")
    torch._check(
        dtype in [torch.bfloat16, torch.float16, torch.float32],
        lambda: f"Blockwise 4bit dequantization only supports 16/32-bit floats, but got {dtype}",
    )
    # torch._check(
    #     A.dtype == torch.uint8,
    #     lambda: f"Blockwise 4bit dequantization on XPU only supports uint8 storage, got {A.dtype}",
    # )
    # Check if this is fine and fast
    if A.dtype != torch.uint8:
        A = A.squeeze().view(torch.uint8).unsqueeze(1)

    out = torch.empty(shape, dtype=dtype, device=A.device)

    triton_kernels._dequantize_4bit_impl(A, absmax, blocksize, quant_type, dtype, out=out)
    return out


def dequantize_4bit_inplace(
    A: torch.Tensor,
    absmax: torch.Tensor,
    blocksize: int,
    quant_type: str,
    shape: Sequence[int],
    dtype: torch.dtype,
    out: torch.Tensor,
) -> None:
    torch._check(out.shape == shape, lambda: f"Expected out.shape == {shape}, got {out.shape}")
    torch._check(out.dtype == dtype, lambda: f"Expected out.dtype == {dtype}, got {out.dtype}")
    triton_kernels._dequantize_4bit_impl(A, absmax, blocksize, quant_type, dtype, out=out)


def gemv_4bit(
    A: torch.Tensor,
    B: torch.Tensor,
    shapeB: Sequence[int],
    absmax: torch.Tensor,
    code: torch.Tensor,
    blocksize: int,
) -> torch.Tensor:
    if B.dtype != torch.uint8:
        B = B.squeeze().view(torch.uint8).unsqueeze(1)

    B_dq_triton = torch.empty(shapeB, dtype=A.dtype, device=A.device)

    triton_kernels._dequantize_4bit_impl_passing_code(
        B,
        absmax,
        blocksize,
        code,
        dtype=A.dtype,
        out=B_dq_triton,
    )

    return torch.nn.functional.linear(
        A,
        B_dq_triton,
        bias=None,
    )
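The triton_kernels module that ops.py imports (quantize_blockwise_triton, dequant_int8_blockwise, and the _dequantize_4bit_impl helpers) is part of this commit but not reproduced in this excerpt. As a loose sketch of the binary-search blockwise quantization the commit message refers to, and not the committed kernel itself, a Triton program per block could normalize by the block absmax and then binary-search the sorted 256-entry code table; the kernel name, body, and launch details below are assumptions.

import triton
import triton.language as tl


@triton.jit
def _quantize_blockwise_bsearch(a_ptr, code_ptr, absmax_ptr, out_ptr, n_elements, BLOCK_SIZE: tl.constexpr):
    # One program quantizes one block of BLOCK_SIZE elements (power of two).
    pid = tl.program_id(axis=0)
    offs = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = offs < n_elements
    a = tl.load(a_ptr + offs, mask=mask, other=0.0).to(tl.float32)

    # Per-block absmax, then normalize the block into [-1, 1].
    absmax = tl.max(tl.abs(a), axis=0)
    tl.store(absmax_ptr + pid, absmax)
    a = a / tl.maximum(absmax, 1e-8)

    # Binary search over the sorted 256-entry code table: eight halving steps
    # find, for every element in parallel, the first code value >= a.
    lo = tl.zeros((BLOCK_SIZE,), dtype=tl.int32)
    hi = tl.full((BLOCK_SIZE,), 255, dtype=tl.int32)
    for _ in range(8):
        mid = (lo + hi) // 2
        go_right = a > tl.load(code_ptr + mid)
        lo = tl.where(go_right, mid + 1, lo)
        hi = tl.where(go_right, hi, mid)

    # lo == hi now; emit whichever neighboring code value is closer to a.
    upper = tl.minimum(lo, 255)
    lower = tl.maximum(upper - 1, 0)
    pick_upper = tl.abs(tl.load(code_ptr + upper) - a) < tl.abs(a - tl.load(code_ptr + lower))
    idx = tl.where(pick_upper, upper, lower)
    tl.store(out_ptr + offs, idx.to(tl.uint8), mask=mask)


def quantize_blockwise_triton(A, blocksize, code, blocks, absmax, out):
    # Host-side launch: one program per quantization block. The signature mirrors
    # the call in ops.py above; the body here is illustrative, not the committed code.
    _quantize_blockwise_bsearch[(blocks,)](
        A.reshape(-1), code, absmax, out.reshape(-1), A.numel(), BLOCK_SIZE=blocksize
    )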
