Skip to content

Commit b15065b

Browse files
[mxfp8 moe training] fix torch ref impl of SF blocked layout per group along K
1 parent 3955b6c commit b15065b

File tree

2 files changed

+36
-12
lines changed

2 files changed

+36
-12
lines changed

test/prototype/moe_training/test_kernels.py

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -313,9 +313,6 @@ def test_mxfp8_per_group_blocked_scales_3d(
313313
)
314314

315315

316-
@pytest.mark.skip(
317-
"Temporarily disable and use e2e training numerical tests instead. See: https://github.com/pytorch/ao/pull/2990#discussion_r2354167396"
318-
)
319316
@skip_if_rocm("ROCm enablement in progress")
320317
@pytest.mark.parametrize("m", [256, 512, 1024, 5120])
321318
@pytest.mark.parametrize("total_k", [512, 1024, 2048, 4096, 8192, 16384])

torchao/prototype/moe_training/kernels/mxfp8/quant.py

Lines changed: 36 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -100,11 +100,15 @@ def torch_to_blocked_2d_K_groups(
100100
num_groups = group_offs.shape[0]
101101

102102
# Each group will require a variable amount of padding, so to avoid d2h sync caused by iterating over each group,
103-
# Triton kernel will use an upper bound of adding 4 padding cols to each group.
104-
# (This torch impl is used as a reference for correctness, so we must match the triton kernel's impl).
105103
total_K_padded = total_K + num_groups * 4
106104
blocked_scales = x_scales.new_zeros(padded_M, total_K_padded)
107105

106+
# Flattened view for easier indexing when writing to subregions of memory
107+
blocked_scales_flat = blocked_scales.view(-1)
108+
109+
BLOCK_ROWS, BLOCK_COLS = 128, 4
110+
output_stride_per_block = BLOCK_ROWS * BLOCK_COLS # 512
111+
108112
start_col_after_padding_list = [0]
109113
group_start_idx = 0
110114
for i, group_end_idx in enumerate(group_offs.tolist()):
@@ -119,14 +123,37 @@ def torch_to_blocked_2d_K_groups(
119123
group_scales_blocked = to_blocked(group_scales)
120124
cols_after_padding = ceil_div(group_size, 4) * 4
121125

122-
# Write output to subtensor
123-
blocked_scales[
124-
:,
125-
prev_start_col_after_padding : prev_start_col_after_padding
126-
+ cols_after_padding,
127-
] = group_scales_blocked.reshape(-1, cols_after_padding)
126+
num_row_blocks = ceil_div(M, 128)
127+
num_col_blocks = cols_after_padding // 4
128128

129-
# Calculate the start row after padding
129+
# Reshape blocked scales from flattened format to (num_row_blocks, num_col_blocks, ...)
130+
# so we can write each SF tile to its output buffer individually.
131+
group_scales_reshaped = group_scales_blocked.view(
132+
num_row_blocks, num_col_blocks, -1
133+
)
134+
out_group_base_offset = prev_start_col_after_padding * padded_M
135+
136+
# For each SF tile, write to the output tensor
137+
for row_block in range(num_row_blocks):
138+
for col_block in range(num_col_blocks):
139+
block_data = group_scales_reshaped[row_block, col_block]
140+
141+
stride_per_row_of_blocks_in_group = (
142+
num_col_blocks * output_stride_per_block
143+
)
144+
offset_in_group = (
145+
row_block * stride_per_row_of_blocks_in_group
146+
+ col_block * output_stride_per_block
147+
)
148+
final_offset = out_group_base_offset + offset_in_group
149+
150+
# flattened (512,) for (128,4) sf tile
151+
block_flat = block_data.reshape(-1)
152+
blocked_scales_flat[
153+
final_offset : final_offset + output_stride_per_block
154+
] = block_flat
155+
156+
# Calculate the start col after padding
130157
new_start_col = prev_start_col_after_padding + cols_after_padding
131158
start_col_after_padding_list.append(new_start_col)
132159

0 commit comments

Comments
 (0)