21 | 21 | # testing utilities  | 
22 | 22 | from triton_kernels.testing import assert_close, compute_actual_scale  | 
23 | 23 | # target-specific utilities  | 
24 |  | -from triton_kernels.target_info import is_hip, is_hip_cdna3, is_cuda, is_hip_cdna4  | 
25 |  | - | 
 | 24 | +from triton_kernels.target_info import is_hip, is_xpu, is_hip_cdna3, is_cuda, is_hip_cdna4  | 
 | 25 | +from icecream import ic  | 
26 | 26 | # ---------------  | 
27 | 27 | # initialize data  | 
28 | 28 | # ---------------  | 
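Side note on the new import: `icecream` is a third-party debug-printing library; `ic(expr)` prints both the expression text and its value, then returns the value unchanged, so it can be dropped into existing code without restructuring it. A minimal sketch of the pattern (the `constraints` name mirrors its use later in this diff):

    from icecream import ic

    constraints = {"block_k": 256}
    ic(constraints)  # prints: ic| constraints: {'block_k': 256}
    # ic() returns its argument, so it can also wrap expressions in place:
    block_k = ic(constraints)["block_k"]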
@@ -168,100 +168,12 @@ class Case:  | 
168 | 168 |     ", ".join(f.name for f in fields(Case)),  | 
169 | 169 |     [  | 
170 | 170 |         tuple(getattr(case, f.name) for f in fields(Case)) for case in [  | 
171 |  | -            # Zero-sized args:  | 
172 |  | -            Case(0, 5, 7, "ragged", "float16", "float16"),  | 
173 |  | -            Case(5, 0, 7, "ragged", "float16", "float16"),  | 
174 |  | -            Case(5, 7, 0, "ragged", "float16", "float16"),  | 
175 |  | -            Case(0, 5, 7, "batched", "float16", "float16"),  | 
176 |  | -            Case(5, 0, 7, "batched", "float16", "float16"),  | 
177 |  | -            Case(5, 7, 0, "batched", "float16", "float16"),  | 
178 |  | -            # Non-mx types:  | 
179 |  | -            Case(16, 256, 256, "ragged", "float16", "float16", 128, 4),  | 
180 |  | -            Case(16, 256, 256, "ragged", "float16", "float16", 128, 4, n_expt_shards=2),  | 
181 |  | -            Case(16, 256, 256, "ragged", "float16", "float16", 128, 4, n_expt_shards=4),  | 
182 |  | -            Case(16, 256, 256, "ragged", "float16", "float16", 4, 1, n_expt_shards=2),  | 
183 |  | -            Case(16, 256, 256, "ragged", "float16", "float16", 128, 4, split_k=3),  | 
184 |  | -            Case(16, 256, 256, "ragged", "float16", "float16", 128, 4, split_k=3),  | 
185 |  | -            Case(300, 400, 400, "batched", "float8_e5m2", "float8_e5m2", 5, 1),  | 
186 |  | -            Case(16, 256, 256, "batched", "float16", "float16", 5, 1),  | 
187 |  | -            Case(16, 256, 256, "ragged", "float16", "float16", 3, 1),  | 
188 |  | -            Case(256, 256, 256, "ragged", "float16", "float16", 4, 1),  | 
189 |  | -            Case(256, 256, 256, "ragged", "float16", "float16", 4, 1, split_k=3),  | 
190 |  | -            Case(300, 400, 400, "batched", "float16", "float16", 5, 1),  | 
191 |  | -            Case(300, 400, 400, "ragged", "float16", "float16"),  | 
192 |  | -            Case(300, 400, 400, "ragged", "float8_e5m2", "float8_e5m2"),  | 
193 |  | -            Case(1000, 400, 400, "ragged", "float8_e5m2", "float8_e5m2", 3, 1),  | 
194 |  | -            Case(600, 400, 400, "ragged", "float8_e5m2", "float8_e5m2", 4, 2, epilogue_subtile=1),  | 
195 |  | -            Case(600, 400, 400, "ragged", "float8_e5m2", "float8_e5m2", 4, 2, epilogue_subtile=2),  | 
196 |  | -            Case(600, 400, 400, "ragged", "float8_e5m2", "float8_e5m2", 4, 2, epilogue_subtile=4),  | 
197 |  | -            Case(600, 400, 400, "ragged", "float8_e5m2", "float8_e5m2", 4, 2),  | 
198 |  | -            Case(600, 400, 400, "ragged", "float8_e5m2", "float8_e5m2", 4, 2, n_expt_shards=2),  | 
199 |  | -            Case(600, 400, 400, "ragged", "float8_e5m2", "float8_e5m2", 4, 1, n_expt_shards=2),  | 
200 |  | -            Case(600, 400, 400, "ragged", "float8_e5m2", "float8_e5m2", 4, 2, split_k=2),  | 
201 |  | -            Case(1000, 400, 400, "ragged", "float16", "float16", 3, 1),  | 
202 |  | -            Case(1000, 700, 700, "ragged", "float16", "float16", 8, 2),  | 
203 |  | -            Case(1000, 700, 700, "ragged", "float16", "float16", 8, 2, split_k=9),  | 
204 |  | -            # mx types:  | 
205 |  | -            Case(16, 256, 256, "plain", "bfloat16", "mxfloat4_e2m1", 1, 1),  | 
206 |  | -            Case(16, 256, 256, "plain", "bfloat16", "mxfloat4_e2m1", 1, 1, hbm_swizzling=True),  | 
207 |  | -            Case(16, 256, 256, "ragged", "bfloat16", "mxfloat4_e2m1", 1, 1),  | 
208 |  | -            Case(16, 256, 256, "ragged", "bfloat16", "mxfloat4_e2m1", 1, 1, hbm_swizzling=True),  | 
209 |  | -            Case(1000, 700, 700, "batched", "bfloat16", "mxfloat4_e2m1", 8, 2),  | 
210 |  | -            Case(1000, 700, 700, "batched", "bfloat16", "mxfloat4_e2m1", 8, 2, hbm_swizzling=True),  | 
211 |  | -            Case(1000, 700, 700, "ragged", "bfloat16", "mxfloat4_e2m1", 8, 2, split_k=9),  | 
212 |  | -            Case(1000, 512, 256, "ragged", "bfloat16", "mxfloat4_e2m1", 8, 2, split_k=9, hbm_swizzling=True),  | 
213 |  | -            Case(300, 400, 400, "ragged", "bfloat16", "mxfloat8_e4m3fn", 8, 4),  | 
214 |  | -            Case(300, 400, 400, "ragged", "bfloat16", "mxfloat8_e4m3fn", 8, 4, hbm_swizzling=True),  | 
215 |  | -            Case(300, 400, 400, "batched", "bfloat16", "mxfloat8_e5m2", 32, 4),  | 
216 |  | -            Case(1000, 700, 2, "batched", "bfloat16", "mxfloat4_e2m1", 8, 2),  | 
217 |  | -            Case(1, 2880, 2880, "ragged", "bfloat16", "mxfloat4_e2m1", 128, 4),  | 
218 |  | -            Case(16, 256, 256, "ragged", "float8_e5m2", "mxfloat4_e2m1", 128, 4, hbm_swizzling=True),  | 
219 |  | -            Case(1000, 704, 832, "batched", "float8_e5m2", "mxfloat4_e2m1", 3, 1, hbm_swizzling=True),  | 
220 |  | -            Case(1000, 704, 832, "batched", "float8_e5m2", "mxfloat4_e2m1", 3, 1, hbm_swizzling=True),  | 
221 |  | -            Case(1000, 704, 832, "batched", "float8_e5m2", "mxfloat4_e2m1", 3, 1),  | 
222 |  | -            Case(1000, 704, 800, "ragged", "float8_e5m2", "mxfloat4_e2m1", 8, 2, split_k=9),  | 
223 |  | -            Case(1000, 704, 800, "ragged", "float8_e5m2", "mxfloat4_e2m1", 8, 2, split_k=9, hbm_swizzling=True),  | 
224 |  | -            Case(1000, 704, 800, "ragged", "float8_e5m2", "mxfloat4_e2m1", 8, 2),  | 
225 |  | -            Case(1000, 704, 800, "ragged", "float8_e5m2", "mxfloat4_e2m1", 8, 2, hbm_swizzling=True),  | 
226 |  | -            Case(300, 400, 400, "ragged", "float8_e5m2", "mxfloat8_e4m3fn", 8, 4),  | 
227 |  | -            Case(300, 400, 400, "ragged", "float8_e5m2", "mxfloat8_e4m3fn", 8, 4, hbm_swizzling=True),  | 
228 |  | -            Case(300, 400, 832, "ragged", "float8_e5m2", "mxfloat4_e2m1", 8, 4),  | 
229 |  | -            Case(300, 400, 832, "ragged", "float8_e5m2", "mxfloat4_e2m1", 8, 4, hbm_swizzling=True),  | 
230 |  | -            Case(300, 400, 400, "batched", "float8_e5m2", "mxfloat8_e4m3fn", 32, 4),  | 
231 |  | -            Case(300, 400, 400, "batched", "float8_e5m2", "mxfloat8_e4m3fn", 32, 4, hbm_swizzling=True),  | 
232 |  | -            Case(256, 256, 256, "ragged", "float8_e5m2", "mxfloat4_e2m1", 128, 4, hbm_swizzling=True),  | 
233 |  | -            Case(256, 256, 256, "ragged", "float8_e5m2", "mxfloat4_e2m1", 128, 4, hbm_swizzling=False),  | 
234 |  | -            Case(16, 256, 256, "ragged", "mxfloat8_e4m3fn", "mxfloat4_e2m1", 128, 4, hbm_swizzling=True),  | 
235 |  | -            Case(1000, 704, 800, "batched", "mxfloat8_e4m3fn", "mxfloat4_e2m1", 3, 1, hbm_swizzling=True),  | 
236 |  | -            Case(1000, 704, 800, "batched", "mxfloat8_e4m3fn", "mxfloat4_e2m1", 2, 1),  | 
237 | 171 |             Case(1000, 704, 800, "ragged", "mxfloat8_e4m3fn", "mxfloat4_e2m1", 8, 2, split_k=9),  | 
238 |  | -            Case(1000, 704, 800, "ragged", "mxfloat8_e4m3fn", "mxfloat4_e2m1", 8, 2, split_k=9, hbm_swizzling=True),  | 
239 |  | -            Case(1000, 704, 800, "ragged", "mxfloat8_e4m3fn", "mxfloat4_e2m1", 8, 2),  | 
240 |  | -            Case(1000, 704, 800, "ragged", "mxfloat8_e4m3fn", "mxfloat4_e2m1", 8, 2, hbm_swizzling=True),  | 
241 |  | -            Case(300, 400, 400, "ragged", "mxfloat8_e4m3fn", "mxfloat8_e4m3fn", 8, 4),  | 
242 |  | -            Case(300, 512, 512, "ragged", "mxfloat8_e4m3fn", "mxfloat8_e4m3fn", 8, 4),  | 
243 |  | -            Case(300, 400, 400, "ragged", "mxfloat8_e4m3fn", "mxfloat8_e4m3fn", 8, 4, hbm_swizzling=True),  | 
244 |  | -            Case(300, 400, 800, "ragged", "mxfloat8_e4m3fn", "mxfloat4_e2m1", 8, 4),  | 
245 |  | -            Case(300, 400, 800, "ragged", "mxfloat8_e4m3fn", "mxfloat4_e2m1", 8, 4, hbm_swizzling=True),  | 
246 |  | -            Case(300, 400, 400, "batched", "mxfloat8_e4m3fn", "mxfloat8_e4m3fn", 32, 4),  | 
247 |  | -            Case(300, 400, 400, "batched", "mxfloat8_e4m3fn", "mxfloat8_e4m3fn", 32, 4, hbm_swizzling=True),  | 
248 |  | -            # AMD  | 
249 |  | -            Case(300, 400, 400, "ragged", "float8_e4m3fnuz", "float8_e4m3fnuz"),  | 
250 |  | -            Case(1000, 400, 400, "ragged", "float8_e4m3fnuz", "float8_e4m3fnuz", 3, 1),  | 
251 |  | -            Case(600, 400, 400, "ragged", "float8_e4m3fnuz", "float8_e4m3fnuz", 4, 2),  | 
252 |  | -            Case(600, 400, 400, "ragged", "float8_e4m3fnuz", "float8_e4m3fnuz", 4, 2, n_expt_shards=2),  | 
253 |  | -            Case(600, 400, 400, "ragged", "float8_e4m3fnuz", "float8_e4m3fnuz", 4, 2, split_k=2),  | 
254 |  | -            Case(300, 400, 400, "ragged", "float8_e4m3fn", "float8_e4m3fn"),  | 
255 |  | -            Case(1000, 400, 400, "ragged", "float8_e4m3fn", "float8_e4m3fn", 3, 1),  | 
256 |  | -            Case(600, 400, 400, "ragged", "float8_e4m3fn", "float8_e4m3fn", 4, 2),  | 
257 |  | -            Case(600, 400, 400, "ragged", "float8_e4m3fn", "float8_e4m3fn", 4, 2, n_expt_shards=2),  | 
258 |  | -        ] + [  | 
259 |  | -            Case(320, 400, 400, mode, dtype, dtype, x_transpose=x_transpose, w_transpose=w_transpose, y_transpose=y_transpose)  | 
260 |  | -            for mode in ("batched", "ragged")  | 
261 |  | -            for dtype in ("float16", "float8_e5m2")  | 
262 |  | -            for x_transpose in (False, True)  | 
263 |  | -            for w_transpose in (False, True)  | 
264 |  | -            for y_transpose in (False, True)  | 
 | 172 | +            #Case(1000, 704, 800, "ragged", "mxfloat8_e4m3fn", "mxfloat4_e2m1", 8, 2, split_k=9),  | 
 | 173 | +            Case(1000, 704, 800, "ragged", "bfloat16", "mxfloat4_e2m1", 8, 2, split_k=9),  | 
 | 174 | +            #Case(1000, 704, 800, "ragged", "mxfloat8_e4m3fn", "mxfloat4_e2m1", 8, 2, split_k=9, hbm_swizzling=True),  | 
 | 175 | +  | 
 | 176 | +            #Case(1111, 400, 400, "ragged", "float8_e4m3fn", "float8_e4m3fn", 3, 1),  | 
265 | 177 |         ]  | 
266 | 178 |     ],  | 
267 | 179 | )  | 
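The decorator above expands each `Case` dataclass into a positional tuple for `pytest.mark.parametrize`: the first argument is the comma-joined field names, the second a list with one tuple of field values per case. A self-contained sketch of the same pattern, with a simplified stand-in `Case` (field names and defaults here are illustrative, not the test suite's actual dataclass):

    from dataclasses import dataclass, fields

    import pytest

    @dataclass
    class Case:  # simplified stand-in for the test suite's Case dataclass
        m: int
        n: int
        k: int
        mode: str = "ragged"
        act_dtype_str: str = "bfloat16"

    @pytest.mark.parametrize(
        ", ".join(f.name for f in fields(Case)),
        [tuple(getattr(case, f.name) for f in fields(Case)) for case in [
            Case(1000, 704, 800),  # the one case left enabled in this debug commit
        ]],
    )
    def test_op(m, n, k, mode, act_dtype_str):
        assert (m, n, k) == (1000, 704, 800)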
@@ -355,6 +267,7 @@ def test_op(m, n, k, split_k, do_gather, do_scatter, fused_scatter, has_y_gammas  | 
355 | 267 |             "block_k": 256  | 
356 | 268 |         })  | 
357 | 269 |  | 
 | 270 | +    ic(constraints)  | 
358 | 271 |     opt_flags.update_opt_flags_constraints(constraints)  | 
359 | 272 |  | 
360 | 273 |     weight_mxfp = weight_dtype_str.startswith("mx")  | 
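The inserted `ic(constraints)` traces the constraints dict right before it is applied; with the `block_k` constraint set a few lines above, the flow is roughly as follows (output format per icecream's defaults; the exact effect of `update_opt_flags_constraints` is defined in the `opt_flags` module, not shown here):

    constraints = {"block_k": 256}
    ic(constraints)  # -> ic| constraints: {'block_k': 256}
    opt_flags.update_opt_flags_constraints(constraints)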
@@ -555,16 +468,6 @@ def _make_tensor(shape, dtype, trans):  | 
555 | 468 |         )  | 
556 | 469 |  | 
557 | 470 |  | 
558 |  | -def test_set_idle_sms():  | 
559 |  | -    if not is_cuda():  | 
560 |  | -        pytest.skip("Only supported on CUDA")  | 
561 |  | -    from triton_kernels.matmul_ogs_details.opt_flags import make_opt_flags  | 
562 |  | -    num_idle_sms = 24  | 
563 |  | -    matmul_ogs_set_idle_sms(num_idle_sms)  | 
564 |  | -    flags = make_opt_flags(torch.float32, torch.float32, torch.float32, PrecisionConfig(), \  | 
565 |  | -                           1, 1024, 1024, 1024, None, True, False, 1, False)  | 
566 |  | -    assert flags.idle_sms == num_idle_sms  | 
567 |  | - | 
568 | 471 |  | 
569 | 472 | @pytest.mark.parametrize("m, n, k, mode", [  | 
570 | 473 |     (1200, 704, 608, "ragged"),  | 
 | 
0 commit comments