Fix bug: ensure FP32 accumulation for dW in Llama-mode RMSNorm backward #950
Summary
This PR fixes a precision issue in `_block_rms_norm_backward_kernel` when running in `_CASTING_MODE_LLAMA`. It enforces FP32 accumulation for the weight gradient (`dW`), aligning its behavior with the existing `_rms_norm_backward_kernel`.

In `_rms_norm_backward_kernel` (row-wise): the gradient `dW_row` is initialized as `float32`. When iterating through elements, the term `dY_row * (X_row * rstd_row).to(X_dtype)` (which is `bfloat16`) is added to `dW_row`. The addition `dW_row += val` implicitly upcasts the `bfloat16` operand to FP32, so the accumulation effectively happens in high precision.
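As a sanity check on that claim, here is a minimal PyTorch sketch of the row-wise pattern; the shapes and the `terms` tensor are illustrative stand-ins, not values taken from the kernel:

```python
import torch

torch.manual_seed(0)
rows, cols = 4096, 256

# Illustrative stand-ins for the per-row bf16 term dY_row * (X_row * rstd_row).
terms = torch.randn(rows, cols).to(torch.bfloat16)

# Row-wise kernel pattern: dW_row starts as float32; each += promotes the
# bfloat16 operand to float32, so the running sum is accumulated in FP32.
dW_row = torch.zeros(cols, dtype=torch.float32)
for r in range(rows):
    dW_row += terms[r]

# Float64 reference for comparison.
ref = terms.double().sum(dim=0)
print("fp32-accumulated max error:", (dW_row.double() - ref).abs().max().item())
```

Because every partial sum lives in `float32`, the only losses are the `bfloat16` rounding of the individual terms and ordinary FP32 accumulation error.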
In `_block_rms_norm_backward_kernel` (block-wise, the bug): the code uses `tl.sum` for the reduction. Here, `tl.sum` receives a `bfloat16` tensor, so the reduction itself is performed in `bfloat16`. The precision loss happens inside the reduction, before the result is added to the FP32 `dW_row`. Because of `bfloat16`'s limited mantissa, this leads to significant numerical errors for both small and large shapes.
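Continuing the same illustrative sketch, emulating the reduction in `bfloat16` reproduces the loss, while upcasting the terms to `float32` before reducing restores the row-wise behavior. This is the kind of change the summary describes (forcing `dW` to accumulate in FP32); the exact kernel diff is not reproduced here.

```python
import torch

torch.manual_seed(0)
rows, cols = 4096, 256
terms = torch.randn(rows, cols).to(torch.bfloat16)  # illustrative bf16 terms

# Emulate a reduction carried out in bfloat16: every partial sum is rounded
# back to bfloat16, discarding low-order bits of the accumulator.
dW_bf16 = torch.zeros(cols, dtype=torch.bfloat16)
for r in range(rows):
    dW_bf16 += terms[r]

# Fix pattern: upcast the bfloat16 terms to float32 before reducing, so the
# reduction (and the later add into the FP32 dW_row) stays in full precision.
dW_fp32 = terms.to(torch.float32).sum(dim=0)

ref = terms.double().sum(dim=0)  # float64 reference
print("bf16 reduction max error:", (dW_bf16.double() - ref).abs().max().item())
print("fp32 reduction max error:", (dW_fp32.double() - ref).abs().max().item())
```

In Triton terms, the analogous change is to cast the `bfloat16` product to `tl.float32` before calling `tl.sum`, which is one straightforward way to enforce the FP32 accumulation described above.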
Testing Done

- `make test` to ensure correctness
- `make checkstyle` to ensure code style
- `make test-convergence` to ensure convergence