[ROCM 7.0] Improve Softmax accuracy #4174
base: release/rocm-rel-7.0
Conversation
Update the algorithm in fuse_attention for finding the attention subgraph:
- The current implementation fused unwanted instructions into the attention subgraph (e.g. pointwise instructions that have multiple outputs).
- Implement a find_instructions_between routine that only returns instructions directly connected to both the start and end nodes; a sketch of the idea is shown below.
- Any additional fusions are already handled by the fuse_mlir pass.
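A minimal sketch of the idea behind such a routine, assuming a simple graph representation (the node type and function names here are illustrative, not the actual MIGraphX implementation): an instruction belongs to the subgraph only if it is reachable from the start node through its outputs and can reach the end node through its inputs.

```cpp
#include <unordered_set>
#include <vector>

struct node
{
    std::vector<node*> inputs;  // producers of this node
    std::vector<node*> outputs; // consumers of this node
};

// Depth-first walk following either the inputs or the outputs edge list.
static void collect(node* n, std::vector<node*> node::*edges, std::unordered_set<node*>& seen)
{
    if(not seen.insert(n).second)
        return;
    for(node* next : n->*edges)
        collect(next, edges, seen);
}

// Only the instructions connected to both `start` (via its outputs) and
// `end` (via its inputs) lie between them.
static std::unordered_set<node*> find_instructions_between(node* start, node* end)
{
    std::unordered_set<node*> from_start;
    std::unordered_set<node*> to_end;
    collect(start, &node::outputs, from_start);
    collect(end, &node::inputs, to_end);

    std::unordered_set<node*> between;
    for(node* n : from_start)
    {
        if(to_end.count(n) > 0)
            between.insert(n);
    }
    return between;
}
```

This excludes instructions that are reachable only through side outputs which never feed back into the end node, which is the failure mode described above.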
(#4168) Adds an explicit cast to int32 to avoid the case where models use int64 and we convert it to fp32 in eliminate_data_type.
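For reference, a hedged sketch of what such an explicit cast looks like in MIGraphX IR (function and variable names are illustrative; this is not the PR's exact code):

```cpp
#include <migraphx/instruction_ref.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/module.hpp>
#include <migraphx/shape.hpp>

// Insert an explicit convert to int32 so an int64 value (e.g. a past sequence
// length) is not later widened to fp32 by eliminate_data_type.
migraphx::instruction_ref cast_to_int32(migraphx::module& m,
                                        migraphx::instruction_ref pos,
                                        migraphx::instruction_ref value)
{
    return m.insert_instruction(
        pos,
        migraphx::make_op("convert", {{"target_type", migraphx::shape::int32_type}}),
        value);
}
```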
Pull Request Overview
This PR improves softmax accuracy in ROCm 7.0 by implementing FP32 softmax for better precision and introducing a new attention fusion pass to optimize attention patterns in transformer models.
Key changes:
- Adds FP32 softmax computation for lower precision types to improve numerical accuracy
- Introduces a fuse_attention pass to detect and fuse attention patterns (GEMM-softmax-GEMM) into grouped operations
- Refactors MLIR attention handling to work with the new group-based attention fusion
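For the FP32 softmax change, the general shape of the fix is upcast, compute, downcast. A hedged sketch using the high-level softmax op follows (the PR applies the same idea inside the reduce-based softmax handling in rewrite_reduce.cpp; names here are illustrative):

```cpp
#include <migraphx/instruction.hpp>
#include <migraphx/instruction_ref.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/module.hpp>
#include <migraphx/shape.hpp>

// Compute softmax in fp32 for lower-precision inputs, then convert the result
// back to the original type.
migraphx::instruction_ref fp32_softmax(migraphx::module& m,
                                       migraphx::instruction_ref input,
                                       int axis)
{
    auto original_type = input->get_shape().type();
    auto x = m.add_instruction(
        migraphx::make_op("convert", {{"target_type", migraphx::shape::float_type}}), input);
    auto sm = m.add_instruction(migraphx::make_op("softmax", {{"axis", axis}}), x);
    return m.add_instruction(
        migraphx::make_op("convert", {{"target_type", original_type}}), sm);
}
```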
Reviewed Changes
Copilot reviewed 24 out of 25 changed files in this pull request and generated 2 comments.
File | Description
---|---
src/rewrite_reduce.cpp | Implements FP32 upcast/downcast for softmax operations when using lower precision types
src/fuse_attention.cpp | New pass that detects attention patterns and fuses them into group operations
test/fuse_attention.cpp | Tests for the new attention fusion pass
src/targets/gpu/fuse_mlir.cpp | Updates MLIR attention handling to work with group-based fusion
test/gpu/fuse_mlir.cpp | Updates tests to use the new group-based attention pattern
docs/reference/MIGraphX-dev-env-vars.rst | Documents the new MIGRAPHX_DISABLE_FP32_SOFTMAX environment variable
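A hedged sketch of how such a toggle is typically consumed inside MIGraphX, assuming the project's usual environment-variable helpers (not necessarily the PR's exact code):

```cpp
#include <migraphx/env.hpp>

// Declares an environment variable that can be set to opt out of the new
// fp32 softmax upcast.
MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_DISABLE_FP32_SOFTMAX)

static bool use_fp32_softmax()
{
    // Upcast softmax to fp32 unless the user explicitly disables it.
    return not migraphx::enabled(MIGRAPHX_DISABLE_FP32_SOFTMAX{});
}
```

With this, setting MIGRAPHX_DISABLE_FP32_SOFTMAX=1 in the environment falls back to the original-precision softmax.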
Comments suppressed due to low confidence (1)
src/targets/gpu/mlir.cpp:1027
- The parameter syntax is incorrect. It should use double braces for the parameter map: `{{"axes", new_reduce_axes}}`.
i, migraphx::make_op(reduce_op_name, {{"axes", new_reduce_axes}}), rsp_ins);
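For context on the fix: `make_op` takes the operator name plus an attribute object, where the outer braces build the object and each inner brace pair is one key/value entry, so even a single attribute needs two brace levels. An illustrative, self-contained snippet (not the PR's call site):

```cpp
#include <cstdint>
#include <vector>
#include <migraphx/make_op.hpp>

// Outer braces: the attribute object; inner braces: the {"axes", value} entry.
std::vector<std::int64_t> new_reduce_axes = {-1};
auto reduce_max_op = migraphx::make_op("reduce_max", {{"axes", new_reduce_axes}});
```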
auto input_type = input->get_shape().type();
auto requires_upcast = not contains({shape::float_type, shape::double_type}, input_type);

if(full_precision and requires_upcast)
The logic performs upcast/downcast operations but uses the original `gemm1` instead of the upcasted input for the subtract operation on line 1299. This causes the softmax computation to still use the original precision rather than FP32.
migraphx::make_op("sub"), pw_inputs[0], pw_inputs[1]); | ||
return pm->add_instruction(migraphx::make_op("exp"), sub); | ||
}); | ||
auto sub = gm->add_instruction(migraphx::make_op("sub"), gemm1, rmax); |
In the fused attention test, the subtract operation uses `gemm1` instead of `where`, which is inconsistent with the manual softmax implementation pattern, where the pointwise operations should be applied before the softmax computation.
Suggested change:
- auto sub = gm->add_instruction(migraphx::make_op("sub"), gemm1, rmax);
+ auto sub = gm->add_instruction(migraphx::make_op("sub"), where, rmax);
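For context, the decomposed, numerically stable softmax that these tests build by hand looks roughly like the following (a hedged sketch; variable names and broadcasting details are illustrative, not the exact test code):

```cpp
#include <migraphx/instruction.hpp>
#include <migraphx/instruction_ref.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/module.hpp>

// reduce_max -> sub -> exp -> reduce_sum -> div over the last axis.
migraphx::instruction_ref decomposed_softmax(migraphx::module& m,
                                             migraphx::instruction_ref scores)
{
    auto lens = scores->get_shape().lens();
    auto rmax = m.add_instruction(migraphx::make_op("reduce_max", {{"axes", {-1}}}), scores);
    auto rmax_b =
        m.add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", lens}}), rmax);
    auto sub  = m.add_instruction(migraphx::make_op("sub"), scores, rmax_b);
    auto ex   = m.add_instruction(migraphx::make_op("exp"), sub);
    auto rsum = m.add_instruction(migraphx::make_op("reduce_sum", {{"axes", {-1}}}), ex);
    auto rsum_b =
        m.add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", lens}}), rsum);
    return m.add_instruction(migraphx::make_op("div"), ex, rsum_b);
}
```

The reviewer's point is that the input to the first sub here should be the output of the pointwise ops preceding the softmax (the `where` result in the test), not the raw `gemm1` output.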
Codecov Report
❌ Patch coverage is
Additional details and impacted files
@@ Coverage Diff @@
## release/rocm-rel-7.0 #4174 +/- ##
========================================================
+ Coverage 92.21% 92.23% +0.02%
========================================================
Files 545 548 +3
Lines 25107 25194 +87
========================================================
+ Hits 23152 23237 +85
- Misses 1955 1957 +2
- Attn refactor (#4109)
- Fix for MHA in attn refactor (#4152)
- Convert past sequence length to int32 before fusing kv-cache attention (#4168)
- FP32 softmax (#4116)
- New env var doc and new organization (#4138)