
Commit 086108d

[mxfp] Reland remove col-major assert for mx weight (#5285)
Fixes #5269
2 parents: 6620112 + 352b348

File tree: 5 files changed, +172 −229 lines


python/triton_kernels/tests/test_matmul.py

Lines changed: 70 additions & 10 deletions
@@ -194,6 +194,7 @@ class Case:
     x_transpose: bool = False
     w_transpose: bool = False
     y_transpose: bool = False
+    colmajor_mxfp_weight: bool = True


 @pytest.mark.parametrize(
@@ -267,6 +268,7 @@ class Case:
     Case(1000, 704, 800, "batched", "mxfloat8_e4m3fn", "mxfloat4_e2m1", 2, 1),
     Case(1000, 704, 800, "ragged", "mxfloat8_e4m3fn", "mxfloat4_e2m1", 8, 2, split_k=9),
     Case(1000, 704, 800, "ragged", "mxfloat8_e4m3fn", "mxfloat4_e2m1", 8, 2, split_k=9, hbm_swizzling=True),
+    Case(1000, 704, 800, "ragged", "mxfloat8_e4m3fn", "mxfloat4_e2m1", 8, 2, split_k=9, colmajor_mxfp_weight=False),
     Case(1000, 704, 800, "ragged", "mxfloat8_e4m3fn", "mxfloat4_e2m1", 8, 2),
     Case(1000, 704, 800, "ragged", "mxfloat8_e4m3fn", "mxfloat4_e2m1", 8, 2, hbm_swizzling=True),
     Case(300, 400, 400, "ragged", "mxfloat8_e4m3fn", "mxfloat8_e4m3fn", 8, 4),
@@ -313,7 +315,7 @@ class Case:
 @pytest.mark.parametrize("has_y_gammas", [False, True])
 @pytest.mark.parametrize("is_persistent", [False, True])
 def test_op(m, n, k, split_k, do_gather, do_scatter, fused_scatter, inner_expt_opt, has_y_gammas, is_persistent, n_expts_tot,
-            n_expts_act, mode, act_dtype_str, weight_dtype_str, block_m, hbm_swizzling, epilogue_subtile,
+            n_expts_act, mode, act_dtype_str, weight_dtype_str, block_m, hbm_swizzling, colmajor_mxfp_weight, epilogue_subtile,
             x_transpose, w_transpose, y_transpose,
             device, opt_flags_scope):
     # TODO: remove when Triton FP8 supports proper RTNE
@@ -461,14 +463,72 @@ def test_op(m, n, k, split_k, do_gather, do_scatter, fused_scatter, inner_expt_o
         w_scale_layout, w_scale_layout_opts = layout.make_default_matmul_mxfp4_w_scale_layout(
             mx_axis=mx_axis, num_warps=8)
     # downcast to mxfp
-    w_tri, w_scale_tri = downcast_to_mxfp(w_tri, weight_dtype, axis=mx_axis)
-    w_ref = upcast_from_mxfp(w_tri, w_scale_tri, torch.bfloat16, axis=mx_axis)
-    w_tri_dtype = FP4 if "float4" in weight_dtype_str else weight_dtype
-    w_tri = wrap_torch_tensor(w_tri, w_tri_dtype)
-    w_scale_tri = wrap_torch_tensor(w_scale_tri)
-    # convert layouts
-    w_tri = convert_layout(w_tri, w_layout, **w_layout_opts)
-    w_scale_tri = convert_layout(w_scale_tri, w_scale_layout, **w_scale_layout_opts)
+    w_tri_orig = w_tri
+    if colmajor_mxfp_weight:
+        w_tri, w_scale_tri = downcast_to_mxfp(w_tri, weight_dtype, axis=mx_axis)
+        w_ref = upcast_from_mxfp(w_tri, w_scale_tri, torch.bfloat16, axis=mx_axis)
+        w_tri_dtype = FP4 if "float4" in weight_dtype_str else weight_dtype
+        w_tri = wrap_torch_tensor(w_tri, w_tri_dtype)
+        w_scale_tri = wrap_torch_tensor(w_scale_tri)
+        # convert layouts
+        w_tri = convert_layout(w_tri, w_layout, **w_layout_opts)
+        w_scale_tri = convert_layout(w_scale_tri, w_scale_layout, **w_scale_layout_opts)
+    else:
+        if is_cuda() and torch.cuda.get_device_capability()[0] < 10:
+            pytest.skip("transposed mxfp weight not supported with cuda capability < 10")
+        if block_m == 16:
+            pytest.skip("PassManager::run failed from Triton compiler")
+        # TODO: swizzling for rowmajor
+
+        # A typical use case: the weight has already been quantized column-major,
+        # and we want to matmul with its transposed, row-major view without
+        # requantization.
+
+        # Put the abs-max of each 32x32 block on the diagonal so the scales of the
+        # transposed weight agree with the original ones.
+        w_ndim = w_tri.ndim
+        if w_ndim == 2:
+            w_tri = w_tri.unsqueeze(0)
+        BLOCK_SIZE = int(MXFP_BLOCK_SIZE)
+        for e, i, j in itertools.product(range(w_tri.shape[0]), range(0, w_tri.shape[1], BLOCK_SIZE), range(0, w_tri.shape[2], BLOCK_SIZE)):
+            i_end = min(i + BLOCK_SIZE, w_tri.shape[1])
+            j_end = min(j + BLOCK_SIZE, w_tri.shape[2])
+            block = w_tri[e, i:i_end, j:j_end]
+            m_abs = block.abs().max()
+            i_len = i_end - i
+            j_len = j_end - j
+            min_len = min(i_len, j_len)
+            signs = torch.randint(0, 2, (max(i_len, j_len),), device=w_tri.device) * 2 - 1
+            block.diagonal(dim1=-2, dim2=-1)[:] = signs[:min_len] * m_abs
+            if j_len > i_len:
+                block[i_len - 1, i_len:] = signs[min_len:] * m_abs
+            elif i_len > j_len:
+                block[j_len:, j_len - 1] = signs[min_len:] * m_abs
+        if w_ndim == 2:
+            w_tri = w_tri.squeeze(0)
+
+        # matmul with a row-major weight expects the scale to be constructed
+        # separately (little additional memory is needed).
+        _, w_scale_tri = downcast_to_mxfp(w_tri, weight_dtype, axis=mx_axis)
+        # reuse the quantized values from the col-major pass
+        w_tri_rowmajor, w_scale_tri_rowmajor = downcast_to_mxfp(w_tri.mT.contiguous(), weight_dtype, axis=mx_axis)
+        w_ref = upcast_from_mxfp(w_tri_rowmajor, w_scale_tri_rowmajor, torch.bfloat16, axis=mx_axis).mT.contiguous()
+        w_tri = w_tri_rowmajor.data.mT
+
+        def _pad_and_block(x: torch.Tensor) -> torch.Tensor:
+            x = torch.nn.functional.pad(x, (0, x.shape[-1] % BLOCK_SIZE), mode="replicate")
+            return x.view(*x.shape[:-1], x.shape[-1] // BLOCK_SIZE, BLOCK_SIZE)
+
+        # check that the generated scales are transpose-invariant, as intended by the construction above
+        # [cdiv(K, 32), N] -> dedup to [cdiv(K, 32), cdiv(N, 32)]
+        w_scale_tri_blocked = _pad_and_block(w_scale_tri)
+        w_scale_tri_sampled = w_scale_tri_blocked[..., 0:1]
+        # [cdiv(N, 32), K] -> dedup to [cdiv(N, 32), cdiv(K, 32)]
+        w_scale_tri_rowmajor_blocked = _pad_and_block(w_scale_tri_rowmajor)
+        w_scale_tri_rowmajor_sampled = w_scale_tri_rowmajor_blocked[..., 0:1]
+        assert torch.equal(w_scale_tri_sampled.expand_as(w_scale_tri_blocked), w_scale_tri_blocked)
+        assert torch.equal(w_scale_tri_rowmajor_sampled.expand_as(w_scale_tri_rowmajor_blocked), w_scale_tri_rowmajor_blocked)
+        assert torch.equal(w_scale_tri_sampled.squeeze(-1), w_scale_tri_rowmajor_sampled.squeeze(-1).mT)
     precision_opt.weight_scale = w_scale_tri
     epilogue = None
     if act_mxfp8:
@@ -477,7 +537,7 @@ def test_op(m, n, k, split_k, do_gather, do_scatter, fused_scatter, inner_expt_o
     is_input_batched = x_tri.ndim == 3
     y_shape = x_tri.shape if is_input_batched else (1,) + x_tri.shape
     n_rows = y_shape[1] if gindx is None or mode == "batched" else gindx.dst_indx.shape[0]
-    y_shape = (y_shape[0], n_rows, w_tri.shape[-1])
+    y_shape = (y_shape[0], n_rows, w_tri_orig.shape[-1])
     if sindx is None or mode == "batched":
         if not is_input_batched:
             y_shape = (y_shape[1], y_shape[2])
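
For readers skimming the diff, here is a minimal standalone sketch, in plain PyTorch with no triton_kernels imports, of why the diagonal construction above works: the MXFP scale of a 32-value group is derived from the group's abs-max, so if every 32x32 tile's maximum sits on the tile diagonal, quantization along K (column-major) and along N (row-major) sees the same per-tile maximum and therefore agrees on the scales. BLOCK, K, N, and w below are made-up values for illustration only.

import torch

BLOCK = 32              # mirrors MXFP_BLOCK_SIZE in the test; assumed value
K, N = 128, 96          # hypothetical shapes, multiples of BLOCK for simplicity
w = torch.randn(K, N)

# Seed each 32x32 tile's abs-max on the tile diagonal (the same trick the test uses,
# minus the random signs, which do not affect abs-max).
for i in range(0, K, BLOCK):
    for j in range(0, N, BLOCK):
        tile = w[i:i + BLOCK, j:j + BLOCK]          # view into w, modified in place
        tile.diagonal()[:] = tile.abs().max()

# Per-group abs-max over 32 consecutive K values (what a col-major mxfp downcast
# would base its scales on) ...
amax_k = w.view(K // BLOCK, BLOCK, N).abs().amax(dim=1)                    # [K/32, N]
# ... and over 32 consecutive N values of the transpose (the row-major case).
amax_n = w.t().contiguous().view(N // BLOCK, BLOCK, K).abs().amax(dim=1)   # [N/32, K]

# Both are constant within each 32-wide tile and equal the tile's abs-max, so
# deduplicating per tile yields transposed copies of the same [K/32, N/32] grid.
tiles_k = amax_k.view(K // BLOCK, N // BLOCK, BLOCK).amax(dim=-1)
tiles_n = amax_n.view(N // BLOCK, K // BLOCK, BLOCK).amax(dim=-1)
assert torch.equal(tiles_k, tiles_n.t())

The random ±1 signs in the test only flip the sign of the injected maxima, which leaves every abs-max, and therefore every derived scale, unchanged.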

python/triton_kernels/triton_kernels/matmul_ogs.py

Lines changed: 4 additions & 2 deletions
@@ -16,6 +16,7 @@
 from .matmul_ogs_details._p_matmul_ogs import _p_matmul_ogs, get_per_device_per_stream_alloc_fn
 from .matmul_ogs_details._reduce_grouped import _reduce_grouped
 from .numerics_details.mxfp import MXFP_BLOCK_SIZE
+from .tensor_details.layout_details.strided import StridedLayout
 from .matmul_ogs_details.opt_flags import make_opt_flags, update_opt_flags_constraints, InapplicableConstraint
 from .specialize import specialize
 from .tensor import Storage, Tensor, FP4, bitwidth, wrap_torch_tensor, RaggedTensorMetadata
@@ -483,12 +484,13 @@ def matmul_ogs(x, w, bias,
     w_scale = precision_config.weight_scale
     w_has_mx = w_scale is not None
     is_hopper_fp8 = is_cuda() and not target_info.cuda_capability_geq(10, 0) and bitwidth(w.dtype) == 8
-    if w_has_mx: assert w.stride(-2) == 1, "`w` must be column-major when it has data-type mxfp"
     if is_hopper_fp8: assert w.stride(-2) == 1, "`w` must be column-major when it has data-type FP8 on capability < 10"
     if not isinstance(w, Tensor):
         # TODO: remove this code path; using uint8 for mxfp4 weight will bite us when we want to support uint8 for real
         dtype = FP4 if w.dtype == torch.uint8 else w.dtype
         w = wrap_torch_tensor(w, dtype=dtype)
+    if w_has_mx and is_cuda() and (torch.cuda.get_device_capability()[0] < 10 or w.storage.layout is not None and not isinstance(w.storage.layout, StridedLayout)):
+        assert w.stride(-2) == 1, "`w` must be column-major when it has data-type mxfp and (swizzled or not on >=Blackwell)"
     if w_scale is not None and not isinstance(w_scale, Tensor):
         w_scale = Tensor(w_scale)
     if w_scale is not None:
@@ -535,7 +537,7 @@ def matmul_ogs(x, w, bias,
     )
     has_gather_tma = has_gather and target_info.has_tma_gather()
     # hopper w/ mxfp4 doesn't support TMA
-    can_use_tma = can_use_tma and (torch.cuda.get_device_capability()[0] > 9 or bitwidth(w.dtype) != 4)
+    can_use_tma = can_use_tma and is_cuda() and (torch.cuda.get_device_capability()[0] > 9 or bitwidth(w.dtype) != 4)
     can_use_fused_scatter = has_scatter and (fused_activation.specs.fn is None) and (epilogue.specs.fn is None) and (routing_data.n_expts_act == 1)
     opt_flags = make_opt_flags(out_dtype, x.dtype, w.dtype, precision_config,
                                batch_size, M, N, w.shape[-2], routing_data,
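
As a reading aid, the relaxed check introduced above can be restated as a standalone predicate. This is a hedged sketch, not library code: the function name and boolean parameters below are hypothetical stand-ins for what the real check reads from torch.cuda.get_device_capability(), w.storage.layout, and w.stride(-2).

def mxfp_weight_must_be_col_major(w_has_mx: bool, on_cuda: bool,
                                  capability_major: int, is_swizzled: bool) -> bool:
    # Column-major mxfp weights are still required when the weight layout is swizzled
    # (a non-strided layout) or when running on CUDA capability < 10 (pre-Blackwell).
    # Plain strided row-major mxfp weights are now accepted on capability >= 10.
    if not (w_has_mx and on_cuda):
        return False
    return capability_major < 10 or is_swizzled

# Hopper (capability 9): the column-major requirement still applies.
assert mxfp_weight_must_be_col_major(True, True, 9, False)
# Blackwell (capability 10) with an unswizzled, strided weight: the assert is skipped,
# which is what the new colmajor_mxfp_weight=False test case exercises.
assert not mxfp_weight_must_be_col_major(True, True, 10, False)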
