
[ROCM 7.0] Improve Softmax accuracy #4174


Open
wants to merge 5 commits into base: release/rocm-rel-7.0

Conversation

causten (Collaborator) commented Jul 27, 2025

  • Attn refactor #4109
  • Fix for MHA in attn refactor #4152
  • Convert past sequence length to int32 before fusing kv-cache attention #4168
  • FP32 softmax #4116
  • New env var doc and new organization #4138

shivadbhavsar and others added 5 commits July 26, 2025 17:39
Update the algorithm in fuse_attention for finding the attention subgraph.

The current implementation was fusing unwanted instructions into the attention subgraph (e.g. pointwise instructions that have multiple outputs).
Implement a find_instructions_between routine that yields only the instructions lying on paths between the start and end nodes (see the sketch after this commit list).
Any additional fusions are already handled by the fuse_mlir pass.
Convert past sequence length to int32 before fusing kv-cache attention (#4168)

Adds an explicit cast to int32 to avoid the case where models use int64 and we convert it to fp32 in eliminate_data_type.
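
A rough sketch of that explicit cast, assuming a MIGraphX module m, an insertion point pos, and an instruction past_seq_len carrying the int64 value (the helper and variable names are hypothetical, not the actual patch):

#include <migraphx/instruction_ref.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/module.hpp>
#include <migraphx/shape.hpp>

// Insert an explicit int64 -> int32 convert ahead of the kv-cache attention
// fusion so that eliminate_data_type never rewrites the length to fp32.
// cast_past_seq_len is a hypothetical helper, not code from this PR.
migraphx::instruction_ref cast_past_seq_len(migraphx::module& m,
                                            migraphx::instruction_ref pos,
                                            migraphx::instruction_ref past_seq_len)
{
    return m.insert_instruction(
        pos,
        migraphx::make_op("convert", {{"target_type", migraphx::shape::int32_type}}),
        past_seq_len);
}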
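
And a minimal sketch of the find_instructions_between idea from the first commit above (not the actual MIGraphX implementation; the node type and graph representation here are assumptions): keep only the nodes reachable from start that also reach end, so consumers that never feed end are excluded from the fused subgraph.

#include <unordered_map>
#include <unordered_set>
#include <vector>

// Hypothetical node type; edges point to consumer nodes.
struct node
{
    std::vector<node*> outputs;
};

// Memoized depth-first search: does ins reach end through its outputs?
static bool reaches(node* ins, node* end, std::unordered_map<node*, bool>& memo)
{
    if(ins == end)
        return true;
    auto it = memo.find(ins);
    if(it != memo.end())
        return it->second;
    bool found = false;
    for(node* out : ins->outputs)
        found = found or reaches(out, end, memo);
    return memo[ins] = found;
}

// Collect every node on a path from start to end (inclusive). A node whose
// outputs never reach end is skipped, so unrelated pointwise instructions
// with extra outputs stay out of the attention subgraph.
std::unordered_set<node*> find_instructions_between(node* start, node* end)
{
    std::unordered_set<node*> result;
    std::unordered_map<node*, bool> memo;
    std::vector<node*> stack{start};
    while(not stack.empty())
    {
        node* ins = stack.back();
        stack.pop_back();
        if(result.count(ins) > 0 or not reaches(ins, end, memo))
            continue;
        result.insert(ins);
        for(node* out : ins->outputs)
            stack.push_back(out);
    }
    return result;
}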
causten requested a review from a team as a code owner July 27, 2025 22:11

Copilot AI left a comment


Pull Request Overview

This PR improves softmax accuracy in ROCm 7.0 by implementing FP32 softmax for better precision and introducing a new attention fusion pass to optimize attention patterns in transformer models.

Key changes:

  • Adds FP32 softmax computation for lower precision types to improve numerical accuracy (see the sketch after this list)
  • Introduces fuse_attention pass to detect and fuse attention patterns (GEMM-softmax-GEMM) into grouped operations
  • Refactors MLIR attention handling to work with the new group-based attention fusion
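
A simplified sketch of that upcast/downcast pattern, assuming a softmax instruction softmax_ins inside module m (an illustration of the idea, not the code in src/rewrite_reduce.cpp):

#include <migraphx/instruction_ref.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/module.hpp>
#include <migraphx/shape.hpp>

// Rewrite softmax_ins to compute in fp32: upcast the input, run softmax,
// then convert the result back to the original element type.
void upcast_softmax(migraphx::module& m, migraphx::instruction_ref softmax_ins)
{
    auto ins       = softmax_ins->inputs().front();
    auto axis      = softmax_ins->get_operator().to_value().at("axis");
    auto orig_type = ins->get_shape().type();
    auto up        = m.insert_instruction(
        softmax_ins,
        migraphx::make_op("convert", {{"target_type", migraphx::shape::float_type}}),
        ins);
    auto sm = m.insert_instruction(
        softmax_ins, migraphx::make_op("softmax", {{"axis", axis}}), up);
    m.replace_instruction(
        softmax_ins, migraphx::make_op("convert", {{"target_type", orig_type}}), sm);
}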

Reviewed Changes

Copilot reviewed 24 out of 25 changed files in this pull request and generated 2 comments.

File Description
src/rewrite_reduce.cpp Implements FP32 upcast/downcast for softmax operations when using lower precision types
src/fuse_attention.cpp New pass that detects attention patterns and fuses them into group operations
test/fuse_attention.cpp Tests for the new attention fusion pass
src/targets/gpu/fuse_mlir.cpp Updates MLIR attention handling to work with group-based fusion
test/gpu/fuse_mlir.cpp Updates tests to use the new group-based attention pattern
docs/reference/MIGraphX-dev-env-vars.rst Documents new MIGRAPHX_DISABLE_FP32_SOFTMAX environment variable
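
MIGraphX conventionally reads such opt-out variables through its env-var helpers, so a plausible (unverified) sketch of how the pass might consult the new variable:

#include <migraphx/env.hpp>

// Follows MIGraphX's existing env-var convention; this is a guess at how the
// pass might gate the upcast, not a quote from the patch.
MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_DISABLE_FP32_SOFTMAX)

bool full_precision = not migraphx::enabled(MIGRAPHX_DISABLE_FP32_SOFTMAX{});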
Comments suppressed due to low confidence (1)

src/targets/gpu/mlir.cpp:1027

  • The parameter syntax is incorrect. It should use double braces for the parameter map: {{"axes", new_reduce_axes}}.
                i, migraphx::make_op(reduce_op_name, {{"axes", new_reduce_axes}}), rsp_ins);

// Decide whether the softmax input needs an fp32 upcast: any element type
// that is not already float or double (e.g. half, bf16) qualifies.
auto input_type = input->get_shape().type();
auto requires_upcast = not contains({shape::float_type, shape::double_type}, input_type);

if(full_precision and requires_upcast)

Copilot AI commented Jul 27, 2025


The logic performs upcast/downcast operations but uses the original gemm1 instead of the upcasted input for the subtract operation on line 1299. This causes the softmax computation to still use the original precision rather than FP32.


migraphx::make_op("sub"), pw_inputs[0], pw_inputs[1]);
return pm->add_instruction(migraphx::make_op("exp"), sub);
});
auto sub = gm->add_instruction(migraphx::make_op("sub"), gemm1, rmax);

Copilot AI commented Jul 27, 2025


In the fused attention test, the subtract operation uses gemm1 instead of the where instruction, which is inconsistent with the manual softmax implementation pattern, in which the pointwise operations are applied before the softmax computation.

Suggested change
auto sub = gm->add_instruction(migraphx::make_op("sub"), gemm1, rmax);
auto sub = gm->add_instruction(migraphx::make_op("sub"), where, rmax);



codecov bot commented Jul 28, 2025

Codecov Report

❌ Patch coverage is 95.65217% with 4 lines in your changes missing coverage. Please review.

Files with missing lines Patch % Lines
src/rewrite_reduce.cpp 84.21% 3 Missing ⚠️
src/fuse_attention.cpp 98.21% 1 Missing ⚠️
Additional details and impacted files
@@                   Coverage Diff                    @@
##           release/rocm-rel-7.0    #4174      +/-   ##
========================================================
+ Coverage                 92.21%   92.23%   +0.02%     
========================================================
  Files                       545      548       +3     
  Lines                     25107    25194      +87     
========================================================
+ Hits                      23152    23237      +85     
- Misses                     1955     1957       +2     
Files with missing lines Coverage Δ
src/include/migraphx/fuse_attention.hpp 100.00% <100.00%> (ø)
src/include/migraphx/instruction.hpp 100.00% <ø> (ø)
src/include/migraphx/match/softmax.hpp 100.00% <100.00%> (ø)
src/include/migraphx/matcher.hpp 94.25% <100.00%> (ø)
src/instruction.cpp 89.05% <100.00%> (+0.66%) ⬆️
src/fuse_attention.cpp 98.21% <98.21%> (ø)
src/rewrite_reduce.cpp 96.25% <84.21%> (-3.75%) ⬇️

... and 1 file with indirect coverage changes


Labels: none yet
Projects: none yet
4 participants