7
7
from .utils import _FP4_QUANT_TABLE , _NF4_QUANT_TABLE
8
8
9
9
10
- # @triton.autotune(
11
- # configs=[
12
- # triton.Config({'SPLIT_SIZE': 64}),
13
- # triton.Config({'SPLIT_SIZE': 128}),
14
- # triton.Config({'SPLIT_SIZE': 256}),
15
- # triton.Config({'SPLIT_SIZE': 512}),
16
- # triton.Config({'SPLIT_SIZE': 1024}),
17
- # triton.Config({'SPLIT_SIZE': 2048}),
18
- # triton.Config({'SPLIT_SIZE': 4096}),
19
- # triton.Config({'SPLIT_SIZE': 8192}),
20
- # triton.Config({'SPLIT_SIZE': 16384}),
21
- # ],
22
- # key=['SPLIT_SIZE'],
23
- # )
24
10
@triton .jit
25
11
def dequant_8bit_kernel (
26
12
a_ptr ,
27
13
c_ptr ,
28
14
quant_ptr ,
29
15
absmax_ptr ,
30
- # bias_ptr,
31
16
num_paired_elements ,
32
17
QUANT_BLOCK : tl .constexpr ,
33
18
SPLIT_SIZE : tl .constexpr ,
34
19
):
35
- pid = tl .program_id (axis = 0 ) # We use a 1D launch grid so axis is 0.
20
+ pid = tl .program_id (axis = 0 )
36
21
block_start = pid * SPLIT_SIZE
37
22
offsets = block_start + tl .arange (0 , SPLIT_SIZE )
38
23
mask = offsets < num_paired_elements
39
24
40
25
a = tl .load (a_ptr + offsets , mask )
41
26
a = a .to (tl .uint8 , bitcast = True )
42
27
43
- # bias = tl.load(bias_ptr)
44
-
45
28
# apply conversion
46
29
scaled_int8 = tl .load (quant_ptr + a , mask )
47
30
@@ -52,7 +35,6 @@ def dequant_8bit_kernel(
52
35
absmax = tl .load (absmax_ptr + abs_offsets , mask_blocked )
53
36
# apply scales
54
37
out_dq = scaled_int8 * absmax
55
- # out_dq = out_dq + bias
56
38
57
39
offs = block_start + tl .arange (0 , SPLIT_SIZE )
58
40
mask = offs < num_paired_elements
@@ -79,19 +61,7 @@ def dequant_int8_blockwise(
79
61
80
62
@triton .autotune (
81
63
configs = [
82
- # triton.Config({'SPLIT_NUM_BLOCKS': 1, 'grf_mode': 'large'}, num_stages=2, num_warps=32),
83
- # triton.Config({'SPLIT_NUM_BLOCKS': 1, 'grf_mode': 'auto'}, num_stages=2, num_warps=32),
84
- # triton.Config({'SPLIT_NUM_BLOCKS': 1, 'grf_mode': 'large'}, num_stages=4, num_warps=32),
85
- #
86
64
triton .Config ({"SPLIT_NUM_BLOCKS" : 1 , "grf_mode" : "auto" }, num_stages = 4 , num_warps = 32 ),
87
- #
88
- # triton.Config({"SPLIT_NUM_BLOCKS": 2, "grf_mode": "large"}, num_stages=2, num_warps=32),
89
- # # triton.Config({'SPLIT_NUM_BLOCKS': 2, 'grf_mode': 'large'}, num_stages=4, num_warps=32),
90
- # triton.Config({"SPLIT_NUM_BLOCKS": 2, "grf_mode": "auto"}, num_stages=2, num_warps=32),
91
- # triton.Config({"SPLIT_NUM_BLOCKS": 2, "grf_mode": "auto"}, num_stages=4, num_warps=32),
92
- # triton.Config({"SPLIT_NUM_BLOCKS": 4, "grf_mode": "large"}, num_stages=2, num_warps=32),
93
- # triton.Config({"SPLIT_NUM_BLOCKS": 4, "grf_mode": "large"}, num_stages=4, num_warps=32),
94
- # triton.Config({'SPLIT_NUM_BLOCKS': 8, 'grf_mode': 'large'}, num_stages=2, num_warps=32),
95
65
],
96
66
key = ["BLOCK_SIZE" ],
97
67
)
@@ -124,9 +94,6 @@ def quantize_blockwise_kernel(
124
94
A_normalized = A_reshaped / absmax [:, None ]
125
95
A_normalized = tl .clamp (A_normalized , - 1.0 , 1.0 )
126
96
127
- # This can be fruitful, but compiler should preload it
128
- # code = tl.load(code_ptr + tl.arange(0, CODE_SIZE))
129
-
130
97
lower_pivot = tl .zeros ((SPLIT_NUM_BLOCKS , BLOCK_SIZE ), dtype = tl .int32 )
131
98
upper_pivot = tl .full ((SPLIT_NUM_BLOCKS , BLOCK_SIZE ), CODE_SIZE - 1 , dtype = tl .int32 )
132
99
@@ -176,24 +143,6 @@ def unite_2_int4(x, y):
176
143
return (x & 0xF ) | (y << 4 )
177
144
178
145
179
- # @triton.autotune(
180
- # configs=[
181
- # # triton.Config({'SPLIT_NUM_BLOCKS': 1, 'grf_mode': 'large'}, num_stages=2, num_warps=32),
182
- # # triton.Config({'SPLIT_NUM_BLOCKS': 1, 'grf_mode': 'auto'}, num_stages=2, num_warps=32),
183
- # # triton.Config({'SPLIT_NUM_BLOCKS': 1, 'grf_mode': 'large'}, num_stages=4, num_warps=32),
184
- # #
185
- # triton.Config({"SPLIT_NUM_BLOCKS": 1, "grf_mode": "auto"}, num_stages=4, num_warps=32),
186
- # #
187
- # # triton.Config({"SPLIT_NUM_BLOCKS": 2, "grf_mode": "large"}, num_stages=2, num_warps=32),
188
- # # # triton.Config({'SPLIT_NUM_BLOCKS': 2, 'grf_mode': 'large'}, num_stages=4, num_warps=32),
189
- # # triton.Config({"SPLIT_NUM_BLOCKS": 2, "grf_mode": "auto"}, num_stages=2, num_warps=32),
190
- # # triton.Config({"SPLIT_NUM_BLOCKS": 2, "grf_mode": "auto"}, num_stages=4, num_warps=32),
191
- # # triton.Config({"SPLIT_NUM_BLOCKS": 4, "grf_mode": "large"}, num_stages=2, num_warps=32),
192
- # # triton.Config({"SPLIT_NUM_BLOCKS": 4, "grf_mode": "large"}, num_stages=4, num_warps=32),
193
- # # triton.Config({'SPLIT_NUM_BLOCKS': 8, 'grf_mode': 'large'}, num_stages=2, num_warps=32),
194
- # ],
195
- # key=["BLOCK_SIZE"],
196
- # )
197
146
@triton .jit
198
147
def quantize_4bit_blockwise_kernel (
199
148
A_ptr ,
@@ -261,11 +210,6 @@ def quantize_4bit_blockwise_triton(A, blocksize, code, blocks, absmax, quantized
261
210
262
211
split_num_blocks = 1
263
212
grid = (triton .cdiv (blocks , split_num_blocks ),)
264
- # grid = (1, )
265
- # grid = lambda META: (triton.cdiv(blocks, META["SPLIT_NUM_BLOCKS"]),)
266
- # print(" blocksize, split_num_blocks: ", blocksize, split_num_blocks)
267
- # print(" blocksize, split_num_blocks: ", blocksize, split_num_blocks*2)
268
- # print("A shape: ", A.shape, " numel: ", n, " blocks: ", blocks)
269
213
quantize_4bit_blockwise_kernel [grid ](
270
214
A_ptr = A ,
271
215
code_ptr = code ,
@@ -280,20 +224,6 @@ def quantize_4bit_blockwise_triton(A, blocksize, code, blocks, absmax, quantized
280
224
return quantized_out , absmax
281
225
282
226
283
- # @triton.autotune(
284
- # configs=[
285
- # # triton.Config({'SPLIT_SIZE': 64}),
286
- # # triton.Config({'SPLIT_SIZE': 128}),
287
- # # triton.Config({'SPLIT_SIZE': 256}),
288
- # triton.Config({'SPLIT_SIZE': 512}),
289
- # # triton.Config({'SPLIT_SIZE': 1024}),
290
- # # triton.Config({'SPLIT_SIZE': 2048}),
291
- # # triton.Config({'SPLIT_SIZE': 4096}),
292
- # # triton.Config({'SPLIT_SIZE': 8192}),
293
- # # triton.Config({'SPLIT_SIZE': 16384}),
294
- # ],
295
- # key=['SPLIT_SIZE'],
296
- # )
297
227
@triton .jit
298
228
def dequant_4bit_kernel (
299
229
a_ptr , c_ptr , quant_ptr , absmax_ptr , num_paired_elements , QUANT_BLOCK : tl .constexpr , SPLIT_SIZE : tl .constexpr
0 commit comments