
@niyunsheng (Contributor)

Summary

This PR fixes a precision issue in _block_rms_norm_backward_kernel when running in _CASTING_MODE_LLAMA.
It enforces FP32 accumulation for the weight gradient (dW), aligning its behavior with the existing _rms_norm_backward_kernel.

  • In _rms_norm_backward_kernel (Row-wise): The gradient dW_row is initialized as float32. When iterating through elements, the term dY_row * (X_row * rstd_row).to(X_dtype) (which is bfloat16) is added to dW_row. This operation performs an implicit cast to FP32 during the addition dW_row += val, effectively accumulating in high precision.

  • In _block_rms_norm_backward_kernel (Block-wise, the bug): The code uses tl.sum for the reduction:

# dY_row * (...) is bfloat16
dW_row += tl.sum((dY_row * (X_row * rstd_row[:, None]).to(X_dtype)), 0)

Here, tl.sum receives a bfloat16 tensor, so the reduction itself is performed in bfloat16. The precision loss therefore occurs inside the reduction, before the result is ever added to the FP32 dW_row. Because of bfloat16's limited mantissa, this leads to significant numerical errors for both small and large shapes (see the sketches below).
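A minimal sketch of the corresponding fix, assuming it simply casts the bfloat16 product to FP32 before the reduction (the actual diff may differ in detail):

```python
# Sketch only: casting the operand to float32 forces tl.sum to accumulate in
# FP32, matching the implicit FP32 accumulation of the row-wise kernel.
dW_row += tl.sum(
    (dY_row * (X_row * rstd_row[:, None]).to(X_dtype)).to(tl.float32), 0
)
```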

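To see why the reduction dtype matters, here is a small, self-contained PyTorch illustration (not part of this PR) that emulates the two accumulation strategies: a bfloat16 running sum, analogous to reducing in bfloat16, versus an FP32 accumulator that each bfloat16 term is promoted into on addition. Sizes and values are arbitrary.

```python
import torch

# Illustrative only: emulate the two accumulation strategies described above.
# A bfloat16 running sum mirrors tl.sum reducing in bfloat16; a float32 running
# sum mirrors `dW_row += val` with an FP32 accumulator.
torch.manual_seed(0)
x_bf16 = torch.rand(4096).to(torch.bfloat16)

ref = x_bf16.to(torch.float64).sum()  # high-precision reference

acc_bf16 = torch.zeros((), dtype=torch.bfloat16)
acc_fp32 = torch.zeros((), dtype=torch.float32)
for v in x_bf16:
    acc_bf16 = acc_bf16 + v  # every partial sum is rounded back to bfloat16
    acc_fp32 = acc_fp32 + v  # bfloat16 value is promoted to float32 before the add

print("bf16 accumulation error:", abs(acc_bf16.to(torch.float64) - ref).item())
print("fp32 accumulation error:", abs(acc_fp32.to(torch.float64) - ref).item())
```

On a typical run the bfloat16 accumulator drifts far from the reference (once the partial sum grows, small addends are rounded away entirely), while the FP32 accumulator stays close to it; this is the error the PR removes from dW.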
Testing Done

  • Hardware Type: A100-SXM4-80GB
  • run make test to ensure correctness
  • run make checkstyle to ensure code style
  • run make test-convergence to ensure convergence

