
Commit 352b348

wdziurdz authored and whitneywhtsang committed
Fix device compatibility assert
Signed-off-by: Witold Dziurdz <[email protected]>
1 parent 6805186 commit 352b348

File tree

5 files changed: +101 additions, -220 deletions


python/triton_kernels/tests/test_matmul.py

Lines changed: 1 addition & 1 deletion
@@ -474,7 +474,7 @@ def test_op(m, n, k, split_k, do_gather, do_scatter, fused_scatter, inner_expt_o
         w_tri = convert_layout(w_tri, w_layout, **w_layout_opts)
         w_scale_tri = convert_layout(w_scale_tri, w_scale_layout, **w_scale_layout_opts)
     else:
-        if torch.cuda.get_device_capability()[0] < 10:
+        if is_cuda() and torch.cuda.get_device_capability()[0] < 10:
             pytest.skip("transposed mxfp weight not supported with cuda capability < 10")
         if block_m == 16:
             pytest.skip("PassManager::run failed from Triton compiler")

python/triton_kernels/triton_kernels/matmul_ogs.py

Lines changed: 2 additions & 2 deletions
@@ -489,7 +489,7 @@ def matmul_ogs(x, w, bias,
         # TODO: remove this code path; using uint8 for mxfp4 weight will bite us when we want to support uint8 for real
         dtype = FP4 if w.dtype == torch.uint8 else w.dtype
         w = wrap_torch_tensor(w, dtype=dtype)
-    if w_has_mx and (torch.cuda.get_device_capability()[0] < 10 or w.storage.layout is not None and not isinstance(w.storage.layout, StridedLayout)):
+    if w_has_mx and is_cuda() and (torch.cuda.get_device_capability()[0] < 10 or w.storage.layout is not None and not isinstance(w.storage.layout, StridedLayout)):
         assert w.stride(-2) == 1, "`w` must be column-major when it has data-type mxfp and (swizzled or not on >=Blackwell)"
     if w_scale is not None and not isinstance(w_scale, Tensor):
         w_scale = Tensor(w_scale)

@@ -537,7 +537,7 @@ def matmul_ogs(x, w, bias,
     )
     has_gather_tma = has_gather and target_info.has_tma_gather()
     # hopper w/ mxfp4 doesn't support TMA
-    can_use_tma = can_use_tma and (torch.cuda.get_device_capability()[0] > 9 or bitwidth(w.dtype) != 4)
+    can_use_tma = can_use_tma and is_cuda() and (torch.cuda.get_device_capability()[0] > 9 or bitwidth(w.dtype) != 4)
     can_use_fused_scatter = has_scatter and (fused_activation.specs.fn is None) and (epilogue.specs.fn is None) and (routing_data.n_expts_act == 1)
     opt_flags = make_opt_flags(out_dtype, x.dtype, w.dtype, precision_config,
                                batch_size, M, N, w.shape[-2], routing_data,
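
A minimal sketch (assumptions, not the library's API) of the evaluation order the second hunk relies on: is_cuda() is checked before the capability probe, so non-CUDA backends never call torch.cuda.get_device_capability() and TMA is simply left disabled there. The function and its parameters below are hypothetical, for illustration only.

import torch

def can_use_tma(requested: bool, backend_is_cuda: bool, weight_bitwidth: int) -> bool:
    # Hopper (compute capability 9.x) cannot use TMA with 4-bit mxfp weights;
    # Blackwell (capability >= 10) can. Non-CUDA backends skip the probe entirely.
    if not (requested and backend_is_cuda):
        return False
    return torch.cuda.get_device_capability()[0] > 9 or weight_bitwidth != 4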
