@@ -100,11 +100,15 @@ def torch_to_blocked_2d_K_groups(
100100 num_groups = group_offs .shape [0 ]
101101
102102 # Each group will require a variable amount of padding, so to avoid d2h sync causing by iterating over each group,
103- # Triton kernel will use an upper bound of adding 4 padding cols to each group.
104- # (This torch impl is used as a reference for correctness, so we must match the triton kernel's impl).
105103 total_K_padded = total_K + num_groups * 4
106104 blocked_scales = x_scales .new_zeros (padded_M , total_K_padded )
107105
106+ # Flattened view for easier indexing when writing to subregions of memory
107+ blocked_scales_flat = blocked_scales .view (- 1 )
108+
109+ BLOCK_ROWS , BLOCK_COLS = 128 , 4
110+ output_stride_per_block = BLOCK_ROWS * BLOCK_COLS # 512
111+
108112 start_col_after_padding_list = [0 ]
109113 group_start_idx = 0
110114 for i , group_end_idx in enumerate (group_offs .tolist ()):
@@ -119,14 +123,37 @@ def torch_to_blocked_2d_K_groups(
119123 group_scales_blocked = to_blocked (group_scales )
120124 cols_after_padding = ceil_div (group_size , 4 ) * 4
121125
122- # Write output to subtensor
123- blocked_scales [
124- :,
125- prev_start_col_after_padding : prev_start_col_after_padding
126- + cols_after_padding ,
127- ] = group_scales_blocked .reshape (- 1 , cols_after_padding )
126+ num_row_blocks = ceil_div (M , 128 )
127+ num_col_blocks = cols_after_padding // 4
128128
129- # Calculate the start row after padding
129+ # Reshape blocked scales from flattened format to (num_row_blocks, num_col_blocks, ...)
130+ # so we can write each SF tile to its output buffer individually.
131+ group_scales_reshaped = group_scales_blocked .view (
132+ num_row_blocks , num_col_blocks , - 1
133+ )
134+ out_group_base_offset = prev_start_col_after_padding * padded_M
135+
136+ # For each SF tile, write to the output tensor
137+ for row_block in range (num_row_blocks ):
138+ for col_block in range (num_col_blocks ):
139+ block_data = group_scales_reshaped [row_block , col_block ]
140+
141+ stride_per_row_of_blocks_in_group = (
142+ num_col_blocks * output_stride_per_block
143+ )
144+ offset_in_group = (
145+ row_block * stride_per_row_of_blocks_in_group
146+ + col_block * output_stride_per_block
147+ )
148+ final_offset = out_group_base_offset + offset_in_group
149+
150+ # flattened (512,) for (128,4) sf tile
151+ block_flat = block_data .reshape (- 1 )
152+ blocked_scales_flat [
153+ final_offset : final_offset + output_stride_per_block
154+ ] = block_flat
155+
156+ # Calculate the start col after padding
130157 new_start_col = prev_start_col_after_padding + cols_after_padding
131158 start_col_after_padding_list .append (new_start_col )
132159
0 commit comments