@@ -194,6 +194,7 @@ class Case:
     x_transpose: bool = False
     w_transpose: bool = False
     y_transpose: bool = False
+    colmajor_mxfp_weight: bool = True
 
 
 @pytest.mark.parametrize(
@@ -267,6 +268,7 @@ class Case:
             Case(1000, 704, 800, "batched", "mxfloat8_e4m3fn", "mxfloat4_e2m1", 2, 1),
             Case(1000, 704, 800, "ragged", "mxfloat8_e4m3fn", "mxfloat4_e2m1", 8, 2, split_k=9),
             Case(1000, 704, 800, "ragged", "mxfloat8_e4m3fn", "mxfloat4_e2m1", 8, 2, split_k=9, hbm_swizzling=True),
+            Case(1000, 704, 800, "ragged", "mxfloat8_e4m3fn", "mxfloat4_e2m1", 8, 2, split_k=9, colmajor_mxfp_weight=False),
             Case(1000, 704, 800, "ragged", "mxfloat8_e4m3fn", "mxfloat4_e2m1", 8, 2),
             Case(1000, 704, 800, "ragged", "mxfloat8_e4m3fn", "mxfloat4_e2m1", 8, 2, hbm_swizzling=True),
             Case(300, 400, 400, "ragged", "mxfloat8_e4m3fn", "mxfloat8_e4m3fn", 8, 4),
@@ -313,7 +315,7 @@ class Case:
 @pytest.mark.parametrize("has_y_gammas", [False, True])
 @pytest.mark.parametrize("is_persistent", [False, True])
 def test_op(m, n, k, split_k, do_gather, do_scatter, fused_scatter, inner_expt_opt, has_y_gammas, is_persistent, n_expts_tot,
-            n_expts_act, mode, act_dtype_str, weight_dtype_str, block_m, hbm_swizzling, epilogue_subtile,
+            n_expts_act, mode, act_dtype_str, weight_dtype_str, block_m, hbm_swizzling, colmajor_mxfp_weight, epilogue_subtile,
             x_transpose, w_transpose, y_transpose,
             device, opt_flags_scope):
     # TODO: remove when Triton FP8 supports proper RTNE
@@ -461,14 +463,72 @@ def test_op(m, n, k, split_k, do_gather, do_scatter, fused_scatter, inner_expt_o
             w_scale_layout, w_scale_layout_opts = layout.make_default_matmul_mxfp4_w_scale_layout(
                 mx_axis=mx_axis, num_warps=8)
         # downcast to mxfp
-        w_tri, w_scale_tri = downcast_to_mxfp(w_tri, weight_dtype, axis=mx_axis)
-        w_ref = upcast_from_mxfp(w_tri, w_scale_tri, torch.bfloat16, axis=mx_axis)
-        w_tri_dtype = FP4 if "float4" in weight_dtype_str else weight_dtype
-        w_tri = wrap_torch_tensor(w_tri, w_tri_dtype)
-        w_scale_tri = wrap_torch_tensor(w_scale_tri)
-        # convert layouts
-        w_tri = convert_layout(w_tri, w_layout, **w_layout_opts)
-        w_scale_tri = convert_layout(w_scale_tri, w_scale_layout, **w_scale_layout_opts)
+        w_tri_orig = w_tri
+        if colmajor_mxfp_weight:
+            w_tri, w_scale_tri = downcast_to_mxfp(w_tri, weight_dtype, axis=mx_axis)
+            w_ref = upcast_from_mxfp(w_tri, w_scale_tri, torch.bfloat16, axis=mx_axis)
+            w_tri_dtype = FP4 if "float4" in weight_dtype_str else weight_dtype
+            w_tri = wrap_torch_tensor(w_tri, w_tri_dtype)
+            w_scale_tri = wrap_torch_tensor(w_scale_tri)
+            # convert layouts
+            w_tri = convert_layout(w_tri, w_layout, **w_layout_opts)
+            w_scale_tri = convert_layout(w_scale_tri, w_scale_layout, **w_scale_layout_opts)
+        else:
+            if is_cuda() and torch.cuda.get_device_capability()[0] < 10:
+                pytest.skip("transposed mxfp weight not supported with cuda capability < 10")
+            if block_m == 16:
+                pytest.skip("PassManager::run failed from Triton compiler")
+            # TODO: swizzling for rowmajor
+
+            # A typical use case: the weight has already been quantized in
+            # col-major layout, and we want to run the matmul with its
+            # transposed row-major view without requantizing.
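+            # (In that scenario the quantized payload is shared between the two
+            # orientations through a transposed view, and only the scales are
+            # rebuilt; the construction below makes both scale sets identical so
+            # the reference and the kernel see the same numerics.)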
+
+            # Put the abs-max of each 32x32 block on its diagonal so that the
+            # MXFP scales of the transposed weight agree with the original ones.
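+            # Why this works: in a full 32x32 tile the diagonal touches every
+            # row and every column, so each 32-element MXFP group, whether it
+            # runs along K or along N, sees the same max magnitude and thus
+            # gets the same shared scale.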
+            w_ndim = w_tri.ndim
+            if w_ndim == 2:
+                w_tri = w_tri.unsqueeze(0)
+            BLOCK_SIZE = int(MXFP_BLOCK_SIZE)
+            for e, i, j in itertools.product(range(w_tri.shape[0]), range(0, w_tri.shape[1], BLOCK_SIZE), range(0, w_tri.shape[2], BLOCK_SIZE)):
+                i_end = min(i + BLOCK_SIZE, w_tri.shape[1])
+                j_end = min(j + BLOCK_SIZE, w_tri.shape[2])
+                block = w_tri[e, i:i_end, j:j_end]
+                m_abs = block.abs().max()
+                i_len = i_end - i
+                j_len = j_end - j
+                min_len = min(i_len, j_len)
+                signs = torch.randint(0, 2, (max(i_len, j_len),), device=w_tri.device) * 2 - 1
+                block.diagonal(dim1=-2, dim2=-1)[:] = signs[:min_len] * m_abs
+                if j_len > i_len:
+                    block[i_len - 1, i_len:] = signs[min_len:] * m_abs
+                elif i_len > j_len:
+                    block[j_len:, j_len - 1] = signs[min_len:] * m_abs
+            if w_ndim == 2:
+                w_tri = w_tri.squeeze(0)
+
+            # The matmul with a row-major weight expects its scale to be
+            # constructed separately (this needs little additional memory).
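+            # (One shared scale per 32 weight elements, i.e. the scale tensor
+            # holds roughly 1/32 as many entries as the weight itself.)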
+            _, w_scale_tri = downcast_to_mxfp(w_tri, weight_dtype, axis=mx_axis)
+            # reuse the quantized values from the col-major quantization
+            w_tri_rowmajor, w_scale_tri_rowmajor = downcast_to_mxfp(w_tri.mT.contiguous(), weight_dtype, axis=mx_axis)
+            w_ref = upcast_from_mxfp(w_tri_rowmajor, w_scale_tri_rowmajor, torch.bfloat16, axis=mx_axis).mT.contiguous()
+            w_tri = w_tri_rowmajor.data.mT
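+            # .mT only swaps the last two strides (no copy), so the row-major
+            # quantized payload is reused as-is for the transposed matmul and
+            # no requantization takes place.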
+
+            def _pad_and_block(x: torch.Tensor) -> torch.Tensor:
+                # pad the last dim up to a multiple of BLOCK_SIZE, then split it
+                # into BLOCK_SIZE-wide blocks
+                x = torch.nn.functional.pad(x, (0, (-x.shape[-1]) % BLOCK_SIZE), mode="replicate")
+                return x.view(*x.shape[:-1], x.shape[-1] // BLOCK_SIZE, BLOCK_SIZE)
+
+            # Check that the generated scales are transpose-invariant, as the
+            # construction above intends.
+            # [cdiv(K, 32), N] -> dedup to [cdiv(K, 32), cdiv(N, 32)]
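+            # Within each 32-wide block along the last dim the scales should
+            # all be equal (every group's amax is the tile abs-max), so keeping
+            # element 0 of each block is enough to dedup; the asserts below
+            # check that constancy and that the two deduped scale grids are
+            # transposes of each other.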
+            w_scale_tri_blocked = _pad_and_block(w_scale_tri)
+            w_scale_tri_sampled = w_scale_tri_blocked[..., 0:1]
+            # [cdiv(N, 32), K] -> dedup to [cdiv(N, 32), cdiv(K, 32)]
+            w_scale_tri_rowmajor_blocked = _pad_and_block(w_scale_tri_rowmajor)
+            w_scale_tri_rowmajor_sampled = w_scale_tri_rowmajor_blocked[..., 0:1]
+            assert torch.equal(w_scale_tri_sampled.expand_as(w_scale_tri_blocked), w_scale_tri_blocked)
+            assert torch.equal(w_scale_tri_rowmajor_sampled.expand_as(w_scale_tri_rowmajor_blocked), w_scale_tri_rowmajor_blocked)
+            assert torch.equal(w_scale_tri_sampled.squeeze(-1), w_scale_tri_rowmajor_sampled.squeeze(-1).mT)
+
         precision_opt.weight_scale = w_scale_tri
     epilogue = None
     if act_mxfp8:
@@ -477,7 +537,7 @@ def test_op(m, n, k, split_k, do_gather, do_scatter, fused_scatter, inner_expt_o
         is_input_batched = x_tri.ndim == 3
         y_shape = x_tri.shape if is_input_batched else (1,) + x_tri.shape
         n_rows = y_shape[1] if gindx is None or mode == "batched" else gindx.dst_indx.shape[0]
-        y_shape = (y_shape[0], n_rows, w_tri.shape[-1])
+        y_shape = (y_shape[0], n_rows, w_tri_orig.shape[-1])
         if sindx is None or mode == "batched":
             if not is_input_batched:
                 y_shape = (y_shape[1], y_shape[2])