Commit fb48d76

cleanup
1 parent e1abec9

3 files changed: +1 -133 lines changed

bitsandbytes/backends/triton/ops.py

Lines changed: 0 additions & 60 deletions
@@ -13,39 +13,6 @@
 triton_available = False
 
 
-# torch compile:
-# 1.53s call tests/test_functional.py::Test8BitBlockwiseQuantizeFunctional::test_dynamic_blockwise_quantization[signed=F-256-nested=T-bf16-xpu]
-#
-# triton:
-# 1.07s call tests/test_functional.py::Test8BitBlockwiseQuantizeFunctional::test_dynamic_blockwise_quantization[signed=F-256-nested=T-bf16-xpu]
-@torch.compile
-def quantize_blockwise_torch(A, code, blocksize):
-    n = A.numel()
-    blocks = -(n // -blocksize)
-
-    absmax = torch.empty((blocks,), device=A.device, dtype=A.dtype)
-    quantized_out = torch.empty_like(A.flatten(), dtype=torch.uint8)
-
-    rem = n % blocksize
-    has_rem = rem > 0
-    blocks = n // blocksize + has_rem
-    A_reshaped = A.reshape(n)
-    A_com = A_reshaped[: n - rem]
-    A_com_reshaped = A_com.reshape(n // blocksize, blocksize)
-    absmax[: blocks - has_rem] = torch.abs(A_com_reshaped).max(dim=-1)[0]
-    scaled_A = torch.clamp(A_com_reshaped / absmax[: blocks - has_rem].view(-1, 1), -1, 1)
-    scaled_A = scaled_A.reshape(-1)
-    if has_rem:
-        absmax[-1] = torch.abs(A_reshaped[n - rem :]).max()
-        scaled_A_rem = torch.clamp((A_reshaped[n - rem :] / absmax[-1]), -1, 1)
-        scaled_A = torch.cat([scaled_A, scaled_A_rem], dim=0)
-
-    diff = torch.abs(scaled_A.unsqueeze(-1) - code.to(scaled_A.device))
-    quantized_out = torch.argmin(diff, dim=-1).to(torch.uint8).to(scaled_A.device).reshape(A.shape)
-    quantized_out = quantized_out.reshape(A.shape)
-    return quantized_out, absmax
-
-
 def quantize_blockwise(A: torch.Tensor, code: torch.Tensor, blocksize: int) -> tuple[torch.Tensor, torch.Tensor]:
     torch._check_is_size(blocksize)
     # torch._check(A.dtype == torch.float32, lambda: f"A must be float32 on xpu, got {A.dtype}")
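Note: the hunk above drops quantize_blockwise_torch, the @torch.compile reference path; the benchmark comments being removed record it as slower than the Triton kernel on the referenced test (1.53s vs 1.07s). For orientation, here is a minimal sketch of the same idea, not part of the commit: per-block absmax scaling followed by a nearest-entry lookup into a 256-value codebook. Function and variable names are illustrative, and the linear codebook below only stands in for the dynamic map used by bitsandbytes.

import torch

def blockwise_quantize_sketch(A: torch.Tensor, code: torch.Tensor, blocksize: int = 256):
    # Pad the flattened input so every block has exactly `blocksize` elements.
    n = A.numel()
    blocks = -(n // -blocksize)  # ceiling division
    flat = torch.nn.functional.pad(A.reshape(n), (0, blocks * blocksize - n))
    flat = flat.reshape(blocks, blocksize)

    # Per-block absmax scale, then clamp into the codebook's [-1, 1] range.
    absmax = flat.abs().max(dim=-1).values
    scaled = torch.clamp(flat / absmax.unsqueeze(-1), -1, 1)

    # Nearest-entry lookup into the codebook.
    idx = torch.argmin((scaled.unsqueeze(-1) - code.to(A.device)).abs(), dim=-1)
    return idx.reshape(-1)[:n].reshape(A.shape).to(torch.uint8), absmax

# Example with a hypothetical linear codebook (the real backend uses the
# dynamic 8-bit map from bitsandbytes):
code = torch.linspace(-1.0, 1.0, 256)
q, absmax = blockwise_quantize_sketch(torch.randn(1000), code, blocksize=256)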
@@ -99,33 +66,6 @@ def dequantize_blockwise_inplace(
     )
 
 
-# torch compile
-# 1.01s call tests/test_functional.py::TestQuantize4BitFunctional::test_4bit_quant[64-fp4-fp32-xpu]
-#
-# triton
-# 0.80s call tests/test_functional.py::TestQuantize4BitFunctional::test_4bit_quant[64-fp4-fp32-xpu]
-@torch.compile
-def quantize_4bit_torch(
-    A: torch.Tensor, blocksize: int, quant_type: str, quant_storage: torch.dtype
-) -> tuple[torch.Tensor, torch.Tensor]:
-    # Divide into blocks and normalize
-    blocks = A.reshape(-1, blocksize)
-    absmax = blocks.abs().max(dim=1).values.float()
-    scaled = blocks / absmax.unsqueeze(-1)
-    if quant_type == "fp4":
-        quantized = torch.argmin(torch.abs(scaled.view(-1, 1) - _FP4_QUANT_TABLE), dim=-1, keepdim=True).to(
-            torch.uint8
-        )
-    else:
-        quantized = torch.argmin(torch.abs(scaled.view(-1, 1) - _NF4_QUANT_TABLE), dim=-1, keepdim=True).to(
-            torch.uint8
-        )
-    packed = quantized[::2] << 4 | quantized[1::2]
-    if quant_storage != torch.uint8:
-        packed = packed.squeeze().view(quant_storage).unsqueeze(1)
-    return packed, absmax.float()
-
-
 def quantize_4bit(
     A: torch.Tensor, blocksize: int, quant_type: str, quant_storage: torch.dtype
 ) -> tuple[torch.Tensor, torch.Tensor]:
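Note: the hunk above removes quantize_4bit_torch, the pure-PyTorch 4-bit path. Its core steps are per-block absmax normalization, a nearest-entry lookup in a 16-value FP4/NF4 table, and packing two 4-bit indices per byte. Below is a condensed sketch of those steps, not part of the commit, with a hypothetical table standing in for _FP4_QUANT_TABLE / _NF4_QUANT_TABLE and assuming the element count is an even multiple of blocksize.

import torch

def quantize_4bit_sketch(A: torch.Tensor, table: torch.Tensor, blocksize: int = 64):
    # Normalize each block by its absolute maximum.
    blocks = A.reshape(-1, blocksize)
    absmax = blocks.abs().max(dim=1).values.float()
    scaled = blocks / absmax.unsqueeze(-1)

    # Index of the nearest entry in the 16-element quantization table.
    idx = torch.argmin((scaled.reshape(-1, 1) - table).abs(), dim=-1).to(torch.uint8)

    # Pack two 4-bit indices per byte: even positions go to the high nibble.
    packed = (idx[::2] << 4) | idx[1::2]
    return packed, absmax

# Hypothetical 16-entry table standing in for the FP4/NF4 tables:
table = torch.linspace(-1.0, 1.0, 16)
packed, absmax = quantize_4bit_sketch(torch.randn(4, 64), table)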

bitsandbytes/backends/triton/triton_kernels.py

Lines changed: 1 addition & 71 deletions
@@ -7,41 +7,24 @@
 from .utils import _FP4_QUANT_TABLE, _NF4_QUANT_TABLE
 
 
-# @triton.autotune(
-#     configs=[
-#         triton.Config({'SPLIT_SIZE': 64}),
-#         triton.Config({'SPLIT_SIZE': 128}),
-#         triton.Config({'SPLIT_SIZE': 256}),
-#         triton.Config({'SPLIT_SIZE': 512}),
-#         triton.Config({'SPLIT_SIZE': 1024}),
-#         triton.Config({'SPLIT_SIZE': 2048}),
-#         triton.Config({'SPLIT_SIZE': 4096}),
-#         triton.Config({'SPLIT_SIZE': 8192}),
-#         triton.Config({'SPLIT_SIZE': 16384}),
-#     ],
-#     key=['SPLIT_SIZE'],
-# )
 @triton.jit
 def dequant_8bit_kernel(
     a_ptr,
     c_ptr,
     quant_ptr,
     absmax_ptr,
-    # bias_ptr,
     num_paired_elements,
     QUANT_BLOCK: tl.constexpr,
     SPLIT_SIZE: tl.constexpr,
 ):
-    pid = tl.program_id(axis=0)  # We use a 1D launch grid so axis is 0.
+    pid = tl.program_id(axis=0)
     block_start = pid * SPLIT_SIZE
     offsets = block_start + tl.arange(0, SPLIT_SIZE)
     mask = offsets < num_paired_elements
 
     a = tl.load(a_ptr + offsets, mask)
     a = a.to(tl.uint8, bitcast=True)
 
-    # bias = tl.load(bias_ptr)
-
     # apply conversion
     scaled_int8 = tl.load(quant_ptr + a, mask)
 
@@ -52,7 +35,6 @@ def dequant_8bit_kernel(
     absmax = tl.load(absmax_ptr + abs_offsets, mask_blocked)
     # apply scales
     out_dq = scaled_int8 * absmax
-    # out_dq = out_dq + bias
 
     offs = block_start + tl.arange(0, SPLIT_SIZE)
     mask = offs < num_paired_elements
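Note: the kernel body above looks up each stored byte in the 256-entry code table and multiplies it by the per-block absmax. The same transformation in plain PyTorch, as a hedged sketch outside the commit (names are illustrative, and the element count is assumed to be a multiple of blocksize):

import torch

def dequantize_blockwise_sketch(q: torch.Tensor, absmax: torch.Tensor, code: torch.Tensor, blocksize: int) -> torch.Tensor:
    # Codebook lookup per element, then scale each block by its absmax.
    vals = code.to(absmax.dtype)[q.long()].reshape(-1, blocksize)
    return (vals * absmax.unsqueeze(-1)).reshape(q.shape)

code = torch.linspace(-1.0, 1.0, 256)              # hypothetical codebook
q = torch.randint(0, 256, (1024,), dtype=torch.uint8)
absmax = torch.rand(1024 // 256)                   # one scale per block of 256
out = dequantize_blockwise_sketch(q, absmax, code, blocksize=256)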
@@ -79,19 +61,7 @@ def dequant_int8_blockwise(
 
 @triton.autotune(
     configs=[
-        # triton.Config({'SPLIT_NUM_BLOCKS': 1, 'grf_mode': 'large'}, num_stages=2, num_warps=32),
-        # triton.Config({'SPLIT_NUM_BLOCKS': 1, 'grf_mode': 'auto'}, num_stages=2, num_warps=32),
-        # triton.Config({'SPLIT_NUM_BLOCKS': 1, 'grf_mode': 'large'}, num_stages=4, num_warps=32),
-        #
         triton.Config({"SPLIT_NUM_BLOCKS": 1, "grf_mode": "auto"}, num_stages=4, num_warps=32),
-        #
-        # triton.Config({"SPLIT_NUM_BLOCKS": 2, "grf_mode": "large"}, num_stages=2, num_warps=32),
-        # # triton.Config({'SPLIT_NUM_BLOCKS': 2, 'grf_mode': 'large'}, num_stages=4, num_warps=32),
-        # triton.Config({"SPLIT_NUM_BLOCKS": 2, "grf_mode": "auto"}, num_stages=2, num_warps=32),
-        # triton.Config({"SPLIT_NUM_BLOCKS": 2, "grf_mode": "auto"}, num_stages=4, num_warps=32),
-        # triton.Config({"SPLIT_NUM_BLOCKS": 4, "grf_mode": "large"}, num_stages=2, num_warps=32),
-        # triton.Config({"SPLIT_NUM_BLOCKS": 4, "grf_mode": "large"}, num_stages=4, num_warps=32),
-        # triton.Config({'SPLIT_NUM_BLOCKS': 8, 'grf_mode': 'large'}, num_stages=2, num_warps=32),
     ],
     key=["BLOCK_SIZE"],
 )
@@ -124,9 +94,6 @@ def quantize_blockwise_kernel(
     A_normalized = A_reshaped / absmax[:, None]
     A_normalized = tl.clamp(A_normalized, -1.0, 1.0)
 
-    # This can be fruitful, but compiler should preload it
-    # code = tl.load(code_ptr + tl.arange(0, CODE_SIZE))
-
     lower_pivot = tl.zeros((SPLIT_NUM_BLOCKS, BLOCK_SIZE), dtype=tl.int32)
     upper_pivot = tl.full((SPLIT_NUM_BLOCKS, BLOCK_SIZE), CODE_SIZE - 1, dtype=tl.int32)
 
@@ -176,24 +143,6 @@ def unite_2_int4(x, y):
     return (x & 0xF) | (y << 4)
 
 
-# @triton.autotune(
-#     configs=[
-#         # triton.Config({'SPLIT_NUM_BLOCKS': 1, 'grf_mode': 'large'}, num_stages=2, num_warps=32),
-#         # triton.Config({'SPLIT_NUM_BLOCKS': 1, 'grf_mode': 'auto'}, num_stages=2, num_warps=32),
-#         # triton.Config({'SPLIT_NUM_BLOCKS': 1, 'grf_mode': 'large'}, num_stages=4, num_warps=32),
-#         #
-#         triton.Config({"SPLIT_NUM_BLOCKS": 1, "grf_mode": "auto"}, num_stages=4, num_warps=32),
-#         #
-#         # triton.Config({"SPLIT_NUM_BLOCKS": 2, "grf_mode": "large"}, num_stages=2, num_warps=32),
-#         # # triton.Config({'SPLIT_NUM_BLOCKS': 2, 'grf_mode': 'large'}, num_stages=4, num_warps=32),
-#         # triton.Config({"SPLIT_NUM_BLOCKS": 2, "grf_mode": "auto"}, num_stages=2, num_warps=32),
-#         # triton.Config({"SPLIT_NUM_BLOCKS": 2, "grf_mode": "auto"}, num_stages=4, num_warps=32),
-#         # triton.Config({"SPLIT_NUM_BLOCKS": 4, "grf_mode": "large"}, num_stages=2, num_warps=32),
-#         # triton.Config({"SPLIT_NUM_BLOCKS": 4, "grf_mode": "large"}, num_stages=4, num_warps=32),
-#         # triton.Config({'SPLIT_NUM_BLOCKS': 8, 'grf_mode': 'large'}, num_stages=2, num_warps=32),
-#     ],
-#     key=["BLOCK_SIZE"],
-# )
 @triton.jit
 def quantize_4bit_blockwise_kernel(
     A_ptr,
@@ -261,11 +210,6 @@ def quantize_4bit_blockwise_triton(A, blocksize, code, blocks, absmax, quantized
 
     split_num_blocks = 1
     grid = (triton.cdiv(blocks, split_num_blocks),)
-    # grid = (1, )
-    # grid = lambda META: (triton.cdiv(blocks, META["SPLIT_NUM_BLOCKS"]),)
-    # print(" blocksize, split_num_blocks: ", blocksize, split_num_blocks)
-    # print(" blocksize, split_num_blocks: ", blocksize, split_num_blocks*2)
-    # print("A shape: ", A.shape, " numel: ", n, " blocks: ", blocks)
     quantize_4bit_blockwise_kernel[grid](
         A_ptr=A,
         code_ptr=code,
@@ -280,20 +224,6 @@ def quantize_4bit_blockwise_triton(A, blocksize, code, blocks, absmax, quantized
     return quantized_out, absmax
 
 
-# @triton.autotune(
-#     configs=[
-#         # triton.Config({'SPLIT_SIZE': 64}),
-#         # triton.Config({'SPLIT_SIZE': 128}),
-#         # triton.Config({'SPLIT_SIZE': 256}),
-#         triton.Config({'SPLIT_SIZE': 512}),
-#         # triton.Config({'SPLIT_SIZE': 1024}),
-#         # triton.Config({'SPLIT_SIZE': 2048}),
-#         # triton.Config({'SPLIT_SIZE': 4096}),
-#         # triton.Config({'SPLIT_SIZE': 8192}),
-#         # triton.Config({'SPLIT_SIZE': 16384}),
-#     ],
-#     key=['SPLIT_SIZE'],
-# )
 @triton.jit
 def dequant_4bit_kernel(
     a_ptr, c_ptr, quant_ptr, absmax_ptr, num_paired_elements, QUANT_BLOCK: tl.constexpr, SPLIT_SIZE: tl.constexpr
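Note: this file's changes mostly prune commented-out @triton.autotune search spaces and debug prints, leaving a single configuration per kernel. For readers unfamiliar with the mechanism, the sketch below shows the general autotune pattern used here: a list of candidate configs, a key that triggers re-tuning, and a grid that reads the chosen meta-parameter. The kernel, its arguments, and the SPLIT_SIZE values are hypothetical, and running it needs a Triton-capable device such as the XPU backend targeted by this code.

import torch
import triton
import triton.language as tl

@triton.autotune(
    configs=[
        triton.Config({"SPLIT_SIZE": 512}),
        triton.Config({"SPLIT_SIZE": 2048}),
    ],
    key=["n_elements"],  # re-tune whenever this argument's value changes
)
@triton.jit
def scale_kernel(x_ptr, out_ptr, scale, n_elements, SPLIT_SIZE: tl.constexpr):
    pid = tl.program_id(axis=0)
    offs = pid * SPLIT_SIZE + tl.arange(0, SPLIT_SIZE)
    mask = offs < n_elements
    x = tl.load(x_ptr + offs, mask=mask)
    tl.store(out_ptr + offs, x * scale, mask=mask)

def scale(x: torch.Tensor, s: float) -> torch.Tensor:
    out = torch.empty_like(x)
    n = x.numel()
    # The launch grid consults the config chosen by the autotuner.
    grid = lambda META: (triton.cdiv(n, META["SPLIT_SIZE"]),)
    scale_kernel[grid](x, out, s, n)
    return out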

bitsandbytes/backends/xpu/ops.py

Lines changed: 0 additions & 2 deletions
@@ -12,13 +12,11 @@
 
 if triton_available:
     register_kernel("bitsandbytes::quantize_blockwise", "xpu")(triton_ops.quantize_blockwise)
-    # register_kernel("bitsandbytes::quantize_blockwise", "xpu")(quantize_blockwise_torch)
     register_kernel("bitsandbytes::dequantize_blockwise.out", "xpu")(triton_ops.dequantize_blockwise_inplace)
     register_kernel("bitsandbytes::dequantize_blockwise", "xpu")(triton_ops.dequantize_blockwise)
     register_kernel("bitsandbytes::quantize_4bit", "xpu")(triton_ops.quantize_4bit)
     register_kernel("bitsandbytes::dequantize_4bit.out", "xpu")(triton_ops.dequantize_4bit_inplace)
     register_kernel("bitsandbytes::dequantize_4bit", "xpu")(triton_ops.dequantize_4bit)
     register_kernel("bitsandbytes::gemv_4bit", "xpu")(triton_ops.gemv_4bit)
-    # register_kernel("bitsandbytes::gemv_4bit.out", "xpu")(triton_ops.gemv_4bit_inpalce)
 else:
     warnings.warn("XPU available, but trtion package is missing.")
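Note: with these registrations, the bitsandbytes custom ops dispatch to the Triton implementations when given XPU tensors. A hedged usage sketch, assuming an XPU-enabled PyTorch build, the triton package, and that bitsandbytes.functional.create_dynamic_map supplies the 8-bit codebook (an assumption about the library's helper, not something this diff shows):

import torch
import bitsandbytes  # importing the package is assumed to run the backend registration above

# Assumes an XPU device and the triton package are available.
A = torch.randn(4096, device="xpu", dtype=torch.float32)
code = bitsandbytes.functional.create_dynamic_map().to("xpu")  # assumed codebook helper

# Routed to triton_ops.quantize_blockwise via the "xpu" registration above.
quantized, absmax = torch.ops.bitsandbytes.quantize_blockwise(A, code, 256)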
