Commit 8d47813

[mxfp8 moe training] _to_mxfp8_then_scaled_grouped_mm wrapper that accepts keyword args (#3561)
1 parent 2319156 commit 8d47813

1 file changed: +48 −2 lines

torchao/prototype/moe_training/scaled_grouped_mm.py

Lines changed: 48 additions & 2 deletions
@@ -297,11 +297,13 @@ def forward(
         wgrad_with_hp: bool = False,
         scale_calculation_mode: ScaleCalculationMode = ScaleCalculationMode.RCEIL,
     ) -> torch.Tensor:
-        # torchao _quantize_then_scaled_grouped_mm only supports A=2D and B=3D.
         assert A.ndim == 2, "A must be 2D"
         assert B_t.ndim == 3, "B must be 3D"
         assert block_size == 32, "Only block_size=32 is supported"
         assert offs is not None, "offs must be provided for 2d-2d and 2d-3d grouped mm"
+        assert out_dtype in (torch.bfloat16, torch.float32), (
+            "out_dtype must be bfloat16 or float32"
+        )
 
         # A_data shape: (M, K)
         # A_scale shape: (M, K//block_size)
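The added assertion tightens the input contract alongside the existing shape, block size, and offsets checks. Below is a minimal sketch of inputs that would pass these checks; the sizes, dtypes, and the transpose trick for the per-group layout are illustrative assumptions, not values taken from the commit.

```python
import torch

# Illustrative sizes: 4 expert groups, 64 tokens per group.
G, tokens_per_group, K, N = 4, 64, 128, 256
M = G * tokens_per_group

# A: 2D, row-major, high precision -> passes `A.ndim == 2`.
A = torch.randn(M, K, dtype=torch.bfloat16)

# B_t: 3D of shape (G, K, N); transposing a contiguous (G, N, K) tensor
# yields a per-group column-major-style layout -> passes `B_t.ndim == 3`.
B_t = torch.randn(G, N, K, dtype=torch.bfloat16).transpose(-2, -1)

# offs: int32 end indices of each group along dim0 of A -> passes `offs is not None`.
offs = torch.arange(tokens_per_group, M + 1, tokens_per_group, dtype=torch.int32)

out_dtype = torch.bfloat16  # passes the new `out_dtype in (torch.bfloat16, torch.float32)` check
block_size = 32             # only supported value, per `block_size == 32`
```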
@@ -682,5 +684,49 @@ def round_up(x, y):
 
 
 # Aliases for convenience/clarity
-_to_mxfp8_then_scaled_grouped_mm = _MXFP8GroupedMM.apply
+def _to_mxfp8_then_scaled_grouped_mm(
+    A: torch.Tensor,
+    B_t: torch.Tensor,
+    offs: Optional[torch.Tensor] = None,
+    block_size: int = 32,
+    out_dtype: Optional[torch.dtype] = torch.bfloat16,
+    emulated: bool = False,
+    use_triton_for_dim0_cast: bool = True,
+    wgrad_with_hp: bool = False,
+    scale_calculation_mode: ScaleCalculationMode = ScaleCalculationMode.RCEIL,
+) -> torch.Tensor:
+    """
+    Differentiable mxfp8 grouped gemm with dynamic mxfp8 quantization.
+
+    Args:
+        - A (bf16/float32 torch.Tensor): The first high-precision input tensor,
+          which must be a 2D tensor of shape (M * num_groups, K)
+          and in row-major memory layout.
+        - B_t (bf16/float32 torch.Tensor): The second high-precision input tensor,
+          which must be a 3D tensor of shape (G, K, N)
+          and in "per group column-major memory" layout (i.e., strides of (N*K, 1, N)).
+        - offs (int32 torch.Tensor): The offsets marking the end index of each group along dim0 of the A tensor.
+        - block_size (int): The block size to use for mxfp8 quantization. Currently only 32 is supported.
+        - out_dtype (Optional[torch.dtype]): The dtype of the output tensor. Default is torch.bfloat16.
+        - emulated (bool): Whether to use the emulated mxfp8 scaled grouped mm kernel (for testing).
+        - use_triton_for_dim0_cast (bool): Whether to use Triton for the dim0 cast. Default is True; if False, the torch-native implementation is used.
+        - wgrad_with_hp (bool): Whether to compute weight gradients in high precision.
+        - scale_calculation_mode (ScaleCalculationMode): The mode to use for scale calculation.
+
+    Returns:
+        - out (torch.Tensor): The result of the mxfp8 scaled grouped gemm.
+    """
+    return _MXFP8GroupedMM.apply(
+        A,
+        B_t,
+        offs,
+        block_size,
+        out_dtype,
+        emulated,
+        use_triton_for_dim0_cast,
+        wgrad_with_hp,
+        scale_calculation_mode,
+    )
+
+
 _to_fp8_rowwise_then_scaled_grouped_mm = _Float8GroupedMM.apply
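Because the alias is now a plain wrapper function rather than the bare `_MXFP8GroupedMM.apply`, call sites can pass keyword arguments, which is the point of this change per the commit title. A usage sketch under assumed conditions (a CUDA device with mxfp8 support or the emulated path; shapes, dtypes, and sizes are placeholders):

```python
import torch
from torchao.prototype.moe_training.scaled_grouped_mm import (
    _to_mxfp8_then_scaled_grouped_mm,
)

G, K, N, M = 4, 128, 256, 256
device = "cuda"

A = torch.randn(M, K, dtype=torch.bfloat16, device=device, requires_grad=True)
B = torch.randn(G, N, K, dtype=torch.bfloat16, device=device, requires_grad=True)
B_t = B.transpose(-2, -1)  # shape (G, K, N), per-group "column-major" layout
offs = torch.tensor([64, 128, 192, 256], dtype=torch.int32, device=device)

# Keyword arguments are now accepted by the wrapper.
out = _to_mxfp8_then_scaled_grouped_mm(
    A,
    B_t,
    offs=offs,
    out_dtype=torch.bfloat16,
    emulated=True,  # emulated kernel path, described in the docstring as "for testing"
)
out.sum().backward()  # differentiable: gradients flow back to A and B
```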
