@@ -297,11 +297,13 @@ def forward(
297297 wgrad_with_hp: bool = False,
298298 scale_calculation_mode: ScaleCalculationMode = ScaleCalculationMode.RCEIL,
299299 ) -> torch.Tensor:
300- # torchao _quantize_then_scaled_grouped_mm only supports A=2D and B=3D.
301300 assert A.ndim == 2, "A must be 2D"
302301 assert B_t.ndim == 3, "B must be 3D"
303302 assert block_size == 32, "Only block_size=32 is supported"
304303 assert offs is not None, "offs must be provided for 2d-2d and 2d-3d grouped mm"
304+ assert out_dtype in (torch.bfloat16, torch.float32), (
305+ "out_dtype must be bfloat16 or float32"
306+ )
305307
306308 # A_data shape: (M, K)
307309 # A_scale shape: (M, K//block_size)
@@ -682,5 +684,49 @@ def round_up(x, y):
682684
683685
684686 # Aliases for convenience/clarity
685- _to_mxfp8_then_scaled_grouped_mm = _MXFP8GroupedMM.apply
687+ def _to_mxfp8_then_scaled_grouped_mm(
688+ A: torch.Tensor,
689+ B_t: torch.Tensor,
690+ offs: Optional[torch.Tensor] = None,
691+ block_size: int = 32,
692+ out_dtype: Optional[torch.dtype] = torch.bfloat16,
693+ emulated: bool = False,
694+ use_triton_for_dim0_cast: bool = True,
695+ wgrad_with_hp: bool = False,
696+ scale_calculation_mode: ScaleCalculationMode = ScaleCalculationMode.RCEIL,
697+ ) -> torch.Tensor:
698+ """
699+ Differentiable mxfp8 grouped gemm with dynamic mxfp8 quantization.
700+
701+ Args:
702+ - A (bf16/float32 torch.Tensor): The first high-precision input tensor,
703+ which must be a 2D tensor of shape (M * num_groups, K)
704+ and in row-major memory layout.
705+ - B_t (bf16/float32 torch.Tensor): The second high-precision input tensor,
706+ which must be a 3D tensor of shape (G, K, N)
707+ and in per-group column-major memory layout (i.e., strides of (N*K, 1, K)).
708+ - offs (int32 torch.Tensor): Offsets marking the end index of each group along dim0 of A.
709+ - block_size (int): The block size to use for mxfp8 quantization. Currently only 32 is supported.
710+ - out_dtype (Optional[torch.dtype]): The dtype of the output tensor. Default is torch.bfloat16.
711+ - emulated (bool): Whether to use the emulated mxfp8 scaled grouped mm kernel (for testing).
712+ - use_triton_for_dim0_cast (bool): Whether to use Triton for the dim0 cast. Defaults to True; if False, a native PyTorch implementation is used.
713+ - wgrad_with_hp (bool): Whether to compute weight gradients in high precision.
714+ - scale_calculation_mode (ScaleCalculationMode): The mode to use for scale calculation.
715+
716+ Returns:
717+ - out (torch.Tensor): The result of the mxfp8 scaled grouped gemm.
718+ """
719+ return _MXFP8GroupedMM.apply(
720+ A,
721+ B_t,
722+ offs,
723+ block_size,
724+ out_dtype,
725+ emulated,
726+ use_triton_for_dim0_cast,
727+ wgrad_with_hp,
728+ scale_calculation_mode,
729+ )
730+
731+
686732 _to_fp8_rowwise_then_scaled_grouped_mm = _Float8GroupedMM.apply
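# A minimal usage sketch of the new wrapper (not part of the patch; the
# shapes, device, and group sizes below are illustrative assumptions).
import torch

G, K, N = 4, 128, 64    # num_groups, reduction dim, output dim; K % 32 == 0
M_total = 256           # total rows across all groups

A = torch.randn(M_total, K, dtype=torch.bfloat16, device="cuda")

# B_t must be (G, K, N) in per-group column-major layout, i.e. strides
# (N*K, 1, K); transposing a contiguous (G, N, K) tensor yields exactly that.
B_t = torch.randn(G, N, K, dtype=torch.bfloat16, device="cuda").transpose(-2, -1)

# offs marks the end row of each group along dim0 of A (here, 4 equal groups).
offs = torch.arange(64, M_total + 1, 64, dtype=torch.int32, device="cuda")

out = _to_mxfp8_then_scaled_grouped_mm(A, B_t, offs=offs)
assert out.shape == (M_total, N) and out.dtype == torch.bfloat16

# Note: the wrapper forwards positionally to _MXFP8GroupedMM.apply, so unlike
# the bare alias it replaces, callers get a documented, keyword-friendly signature.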