-
Notifications
You must be signed in to change notification settings - Fork 3k
SDPAFusion with sinks #31838
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
SDPAFusion with sinks #31838
Conversation
|
Could you add a decompose pass too ? |
src/common/transformations/src/transformations/common_optimizations/sdpa_fusion.cpp
Outdated
Show resolved
Hide resolved
…tions/sdpa_fusion.cpp Co-authored-by: Roman Kazantsev <[email protected]>
| SDPAReshapeFusion(); | ||
| }; | ||
|
|
||
| class TRANSFORMATIONS_API SDPAFusionMatcherSinks : public ov::pass::MatcherPass { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
can you please give a picture of the pattern to be fused
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
There're a lot of SDPA variations to fuse, so listing all of them is not feasible. Added as screenshot to the PR
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Try to share some particular case for it and give a note what other variations can be. Will be much helpful.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Added ascii graph
src/common/transformations/src/transformations/common_optimizations/sdpa_fusion.cpp
Outdated
Show resolved
Hide resolved
…tions/sdpa_fusion.cpp Co-authored-by: Roman Kazantsev <[email protected]>
…ino into sdpa_fusion_gpt_oss
...formations/src/transformations/op_conversions/scaled_dot_product_attention_decomposition.cpp
Outdated
Show resolved
Hide resolved
…scaled_dot_product_attention_decomposition.cpp
src/common/transformations/tests/common_optimizations/sdpa_fusion_test.cpp
Outdated
Show resolved
Hide resolved
mitruska
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The patterns LGTM. Would be good to refactor for reusage of the common code part with existing SDPAFusion (not a blocker, possible as follow up improvement).
Minor comment for graph visualization in the tests.
src/common/transformations/src/transformations/common_optimizations/sdpa_fusion.cpp
Show resolved
Hide resolved
src/common/transformations/tests/op_conversions/scaled_dot_product_decomposition_test.cpp
Outdated
Show resolved
Hide resolved
src/common/transformations/tests/op_conversions/scaled_dot_product_decomposition_test.cpp
Outdated
Show resolved
Hide resolved
src/common/transformations/tests/op_conversions/scaled_dot_product_decomposition_test.cpp
Outdated
Show resolved
Hide resolved
src/common/transformations/tests/op_conversions/scaled_dot_product_decomposition_test.cpp
Outdated
Show resolved
Hide resolved
| // Pattern 1: axis = -1 (last axis) | ||
| // Pattern 2: axis = rank size - 1 (also means last axis for static rank inputs) | ||
| auto axis_predicate = ([](const ov::Output<ov::Node>& node) { | ||
| auto softmax = std::dynamic_pointer_cast<ov::op::v8::Softmax>(node.get_node_shared_ptr()); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
afaik, we need to use ov::as_type_ptr. This is important for OV.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
From what I heard, there were new findings that using std::dynamic_pointer_cast is also safe now, but I'll change to ov::as_type_ptr, no problem
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Who said it?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
To be honest, don't remember right now :)
| auto k_node = pm.at(k); | ||
| auto v_node = pm.at(v); | ||
|
|
||
| if (pm.at(mask).get_partial_shape().size() > 4) { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
not clear. Is it a rank you want to compare? Is it static?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
We're comparing the rank here, changed.
Yeah, it's static, we check this above using the pattern predicate: auto mask = any_input(has_static_rank());
…ino into sdpa_fusion_gpt_oss
5d822fd
Details:
Add SDPAFusionSinks transformation. The pattern is hard-coded and added as a separate transformation for the gpt-oss model because of high demand and issues with integrating it to the existing SDPAFusion pattern
gpt-oss SDPA pattern to fuse:

Tickets: