
Commit 84862d8

debug commit
1 parent 313cbb9 commit 84862d8

File tree: 2 files changed, +10 -107 lines

python/triton_kernels/tests/test_matmul.py

Lines changed: 8 additions & 105 deletions
@@ -21,8 +21,8 @@
 # testing utilities
 from triton_kernels.testing import assert_close, compute_actual_scale
 # target-specific utilities
-from triton_kernels.target_info import is_hip, is_hip_cdna3, is_cuda, is_hip_cdna4
-
+from triton_kernels.target_info import is_hip, is_xpu, is_hip_cdna3, is_cuda, is_hip_cdna4
+from icecream import ic
 # ---------------
 # initialize data
 # ---------------
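
Note: the target_info helpers imported above are how these tests gate hardware-specific cases (the removed test_set_idle_sms further down uses the same pattern with is_cuda()). A minimal sketch of that pattern, assuming a hypothetical helper name and an illustrative skip message:

import pytest
from triton_kernels.target_info import is_cuda, is_xpu

def skip_unless_supported():
    # Hypothetical helper: skip when neither a CUDA nor an XPU backend is active.
    if not (is_cuda() or is_xpu()):
        pytest.skip("requires a CUDA or XPU device")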
@@ -168,100 +168,12 @@ class Case:
     ", ".join(f.name for f in fields(Case)),
     [
         tuple(getattr(case, f.name) for f in fields(Case)) for case in [
-            # Zero-sized args:
-            Case(0, 5, 7, "ragged", "float16", "float16"),
-            Case(5, 0, 7, "ragged", "float16", "float16"),
-            Case(5, 7, 0, "ragged", "float16", "float16"),
-            Case(0, 5, 7, "batched", "float16", "float16"),
-            Case(5, 0, 7, "batched", "float16", "float16"),
-            Case(5, 7, 0, "batched", "float16", "float16"),
-            # Non-mx types:
-            Case(16, 256, 256, "ragged", "float16", "float16", 128, 4),
-            Case(16, 256, 256, "ragged", "float16", "float16", 128, 4, n_expt_shards=2),
-            Case(16, 256, 256, "ragged", "float16", "float16", 128, 4, n_expt_shards=4),
-            Case(16, 256, 256, "ragged", "float16", "float16", 4, 1, n_expt_shards=2),
-            Case(16, 256, 256, "ragged", "float16", "float16", 128, 4, split_k=3),
-            Case(16, 256, 256, "ragged", "float16", "float16", 128, 4, split_k=3),
-            Case(300, 400, 400, "batched", "float8_e5m2", "float8_e5m2", 5, 1),
-            Case(16, 256, 256, "batched", "float16", "float16", 5, 1),
-            Case(16, 256, 256, "ragged", "float16", "float16", 3, 1),
-            Case(256, 256, 256, "ragged", "float16", "float16", 4, 1),
-            Case(256, 256, 256, "ragged", "float16", "float16", 4, 1, split_k=3),
-            Case(300, 400, 400, "batched", "float16", "float16", 5, 1),
-            Case(300, 400, 400, "ragged", "float16", "float16"),
-            Case(300, 400, 400, "ragged", "float8_e5m2", "float8_e5m2"),
-            Case(1000, 400, 400, "ragged", "float8_e5m2", "float8_e5m2", 3, 1),
-            Case(600, 400, 400, "ragged", "float8_e5m2", "float8_e5m2", 4, 2, epilogue_subtile=1),
-            Case(600, 400, 400, "ragged", "float8_e5m2", "float8_e5m2", 4, 2, epilogue_subtile=2),
-            Case(600, 400, 400, "ragged", "float8_e5m2", "float8_e5m2", 4, 2, epilogue_subtile=4),
-            Case(600, 400, 400, "ragged", "float8_e5m2", "float8_e5m2", 4, 2),
-            Case(600, 400, 400, "ragged", "float8_e5m2", "float8_e5m2", 4, 2, n_expt_shards=2),
-            Case(600, 400, 400, "ragged", "float8_e5m2", "float8_e5m2", 4, 1, n_expt_shards=2),
-            Case(600, 400, 400, "ragged", "float8_e5m2", "float8_e5m2", 4, 2, split_k=2),
-            Case(1000, 400, 400, "ragged", "float16", "float16", 3, 1),
-            Case(1000, 700, 700, "ragged", "float16", "float16", 8, 2),
-            Case(1000, 700, 700, "ragged", "float16", "float16", 8, 2, split_k=9),
-            # mx types:
-            Case(16, 256, 256, "plain", "bfloat16", "mxfloat4_e2m1", 1, 1),
-            Case(16, 256, 256, "plain", "bfloat16", "mxfloat4_e2m1", 1, 1, hbm_swizzling=True),
-            Case(16, 256, 256, "ragged", "bfloat16", "mxfloat4_e2m1", 1, 1),
-            Case(16, 256, 256, "ragged", "bfloat16", "mxfloat4_e2m1", 1, 1, hbm_swizzling=True),
-            Case(1000, 700, 700, "batched", "bfloat16", "mxfloat4_e2m1", 8, 2),
-            Case(1000, 700, 700, "batched", "bfloat16", "mxfloat4_e2m1", 8, 2, hbm_swizzling=True),
-            Case(1000, 700, 700, "ragged", "bfloat16", "mxfloat4_e2m1", 8, 2, split_k=9),
-            Case(1000, 512, 256, "ragged", "bfloat16", "mxfloat4_e2m1", 8, 2, split_k=9, hbm_swizzling=True),
-            Case(300, 400, 400, "ragged", "bfloat16", "mxfloat8_e4m3fn", 8, 4),
-            Case(300, 400, 400, "ragged", "bfloat16", "mxfloat8_e4m3fn", 8, 4, hbm_swizzling=True),
-            Case(300, 400, 400, "batched", "bfloat16", "mxfloat8_e5m2", 32, 4),
-            Case(1000, 700, 2, "batched", "bfloat16", "mxfloat4_e2m1", 8, 2),
-            Case(1, 2880, 2880, "ragged", "bfloat16", "mxfloat4_e2m1", 128, 4),
-            Case(16, 256, 256, "ragged", "float8_e5m2", "mxfloat4_e2m1", 128, 4, hbm_swizzling=True),
-            Case(1000, 704, 832, "batched", "float8_e5m2", "mxfloat4_e2m1", 3, 1, hbm_swizzling=True),
-            Case(1000, 704, 832, "batched", "float8_e5m2", "mxfloat4_e2m1", 3, 1, hbm_swizzling=True),
-            Case(1000, 704, 832, "batched", "float8_e5m2", "mxfloat4_e2m1", 3, 1),
-            Case(1000, 704, 800, "ragged", "float8_e5m2", "mxfloat4_e2m1", 8, 2, split_k=9),
-            Case(1000, 704, 800, "ragged", "float8_e5m2", "mxfloat4_e2m1", 8, 2, split_k=9, hbm_swizzling=True),
-            Case(1000, 704, 800, "ragged", "float8_e5m2", "mxfloat4_e2m1", 8, 2),
-            Case(1000, 704, 800, "ragged", "float8_e5m2", "mxfloat4_e2m1", 8, 2, hbm_swizzling=True),
-            Case(300, 400, 400, "ragged", "float8_e5m2", "mxfloat8_e4m3fn", 8, 4),
-            Case(300, 400, 400, "ragged", "float8_e5m2", "mxfloat8_e4m3fn", 8, 4, hbm_swizzling=True),
-            Case(300, 400, 832, "ragged", "float8_e5m2", "mxfloat4_e2m1", 8, 4),
-            Case(300, 400, 832, "ragged", "float8_e5m2", "mxfloat4_e2m1", 8, 4, hbm_swizzling=True),
-            Case(300, 400, 400, "batched", "float8_e5m2", "mxfloat8_e4m3fn", 32, 4),
-            Case(300, 400, 400, "batched", "float8_e5m2", "mxfloat8_e4m3fn", 32, 4, hbm_swizzling=True),
-            Case(256, 256, 256, "ragged", "float8_e5m2", "mxfloat4_e2m1", 128, 4, hbm_swizzling=True),
-            Case(256, 256, 256, "ragged", "float8_e5m2", "mxfloat4_e2m1", 128, 4, hbm_swizzling=False),
-            Case(16, 256, 256, "ragged", "mxfloat8_e4m3fn", "mxfloat4_e2m1", 128, 4, hbm_swizzling=True),
-            Case(1000, 704, 800, "batched", "mxfloat8_e4m3fn", "mxfloat4_e2m1", 3, 1, hbm_swizzling=True),
-            Case(1000, 704, 800, "batched", "mxfloat8_e4m3fn", "mxfloat4_e2m1", 2, 1),
             Case(1000, 704, 800, "ragged", "mxfloat8_e4m3fn", "mxfloat4_e2m1", 8, 2, split_k=9),
-            Case(1000, 704, 800, "ragged", "mxfloat8_e4m3fn", "mxfloat4_e2m1", 8, 2, split_k=9, hbm_swizzling=True),
-            Case(1000, 704, 800, "ragged", "mxfloat8_e4m3fn", "mxfloat4_e2m1", 8, 2),
-            Case(1000, 704, 800, "ragged", "mxfloat8_e4m3fn", "mxfloat4_e2m1", 8, 2, hbm_swizzling=True),
-            Case(300, 400, 400, "ragged", "mxfloat8_e4m3fn", "mxfloat8_e4m3fn", 8, 4),
-            Case(300, 512, 512, "ragged", "mxfloat8_e4m3fn", "mxfloat8_e4m3fn", 8, 4),
-            Case(300, 400, 400, "ragged", "mxfloat8_e4m3fn", "mxfloat8_e4m3fn", 8, 4, hbm_swizzling=True),
-            Case(300, 400, 800, "ragged", "mxfloat8_e4m3fn", "mxfloat4_e2m1", 8, 4),
-            Case(300, 400, 800, "ragged", "mxfloat8_e4m3fn", "mxfloat4_e2m1", 8, 4, hbm_swizzling=True),
-            Case(300, 400, 400, "batched", "mxfloat8_e4m3fn", "mxfloat8_e4m3fn", 32, 4),
-            Case(300, 400, 400, "batched", "mxfloat8_e4m3fn", "mxfloat8_e4m3fn", 32, 4, hbm_swizzling=True),
-            # AMD
-            Case(300, 400, 400, "ragged", "float8_e4m3fnuz", "float8_e4m3fnuz"),
-            Case(1000, 400, 400, "ragged", "float8_e4m3fnuz", "float8_e4m3fnuz", 3, 1),
-            Case(600, 400, 400, "ragged", "float8_e4m3fnuz", "float8_e4m3fnuz", 4, 2),
-            Case(600, 400, 400, "ragged", "float8_e4m3fnuz", "float8_e4m3fnuz", 4, 2, n_expt_shards=2),
-            Case(600, 400, 400, "ragged", "float8_e4m3fnuz", "float8_e4m3fnuz", 4, 2, split_k=2),
-            Case(300, 400, 400, "ragged", "float8_e4m3fn", "float8_e4m3fn"),
-            Case(1000, 400, 400, "ragged", "float8_e4m3fn", "float8_e4m3fn", 3, 1),
-            Case(600, 400, 400, "ragged", "float8_e4m3fn", "float8_e4m3fn", 4, 2),
-            Case(600, 400, 400, "ragged", "float8_e4m3fn", "float8_e4m3fn", 4, 2, n_expt_shards=2),
-        ] + [
-            Case(320, 400, 400, mode, dtype, dtype, x_transpose=x_transpose, w_transpose=w_transpose, y_transpose=y_transpose)
-            for mode in ("batched", "ragged")
-            for dtype in ("float16", "float8_e5m2")
-            for x_transpose in (False, True)
-            for w_transpose in (False, True)
-            for y_transpose in (False, True)
+            #Case(1000, 704, 800, "ragged", "mxfloat8_e4m3fn", "mxfloat4_e2m1", 8, 2, split_k=9),
+            Case(1000, 704, 800, "ragged", "bfloat16", "mxfloat4_e2m1", 8, 2, split_k=9),
+            #Case(1000, 704, 800, "ragged", "mxfloat8_e4m3fn", "mxfloat4_e2m1", 8, 2, split_k=9, hbm_swizzling=True),
+
+            #Case(1111, 400, 400, "ragged", "float8_e4m3fn", "float8_e4m3fn", 3, 1),
         ]
     ],
 )
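
Note: the surviving parametrization above flattens each Case dataclass instance into a positional tuple via dataclasses.fields. A self-contained sketch of that pattern, using a simplified stand-in Case (field names and defaults here are illustrative, not the full set defined in test_matmul.py):

from dataclasses import dataclass, fields
import pytest

@dataclass
class Case:
    # Simplified stand-in; the real Case in test_matmul.py carries more fields
    # (dtypes, n_expts, split_k, hbm_swizzling, ...).
    m: int
    n: int
    k: int
    mode: str = "ragged"
    split_k: int = 1

@pytest.mark.parametrize(
    ", ".join(f.name for f in fields(Case)),
    [
        tuple(getattr(case, f.name) for f in fields(Case)) for case in [
            Case(1000, 704, 800, split_k=9),
        ]
    ],
)
def test_case_shapes(m, n, k, mode, split_k):
    # Each Case becomes one parametrized invocation with its fields unpacked.
    assert m > 0 and n > 0 and k > 0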
@@ -355,6 +267,7 @@ def test_op(m, n, k, split_k, do_gather, do_scatter, fused_scatter, has_y_gammas
             "block_k": 256
         })
 
+    ic(constraints)
     opt_flags.update_opt_flags_constraints(constraints)
 
     weight_mxfp = weight_dtype_str.startswith("mx")
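
Note: icecream's ic() prints both the expression text and its value, which is why it is convenient for a throwaway debug commit like this one. A minimal usage sketch; the dict contents are illustrative and the output shown is icecream's default format:

from icecream import ic

constraints = {"block_k": 256}
ic(constraints)
# Prints something like: ic| constraints: {'block_k': 256}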
@@ -555,16 +468,6 @@ def _make_tensor(shape, dtype, trans):
     )
 
 
-def test_set_idle_sms():
-    if not is_cuda():
-        pytest.skip("Only supported on CUDA")
-    from triton_kernels.matmul_ogs_details.opt_flags import make_opt_flags
-    num_idle_sms = 24
-    matmul_ogs_set_idle_sms(num_idle_sms)
-    flags = make_opt_flags(torch.float32, torch.float32, torch.float32, PrecisionConfig(), \
-                           1, 1024, 1024, 1024, None, True, False, 1, False)
-    assert flags.idle_sms == num_idle_sms
-
 
 @pytest.mark.parametrize("m, n, k, mode", [
     (1200, 704, 608, "ragged"),

scripts/test-triton.sh

Lines changed: 2 additions & 2 deletions
@@ -406,7 +406,7 @@ run_core_tests() {
     echo "***************************************************"
     echo "****** Running Triton Core tests ******"
     echo "***************************************************"
-    run_minicore_tests
+    run_triton_kernels_tests
     run_mxfp_tests
     run_scaled_dot_tests
 }
@@ -685,7 +685,7 @@ test_triton() {
         run_core_tests
     else
         if [ "$TEST_MINICORE" = true ]; then
-            run_minicore_tests
+            run_triton_kernels_tests
         fi
         if [ "$TEST_MXFP" = true ]; then
             run_mxfp_tests
