Conversation

@ruro ruro commented Nov 21, 2025

Changes

  • added ONNXAttentionMetatype for the opset 23 Attention ONNX node
  • fixed scaled_dot_product_attention quantization in torch2 for the case where Q, K and V arrive as parallel edges from the same input node

Reason for changes

See #3750

Related tickets

Fixes #3750

Tests

  • tests/onnx/quantization/test_graphs.py::test_synthetic_models_graph[AttentionModel] (reference graph: attention_model.dot)

  • tests/torch2/function_hook/quantization/test_quantized_graphs.py::test_quantized_graphs[unbind_scaled_dot_product_attention_model] (reference graph: unbind_scaled_dot_product_attention_model.dot)

@ruro ruro force-pushed the fix_onnx_attention_torch_sdpa_handling branch from 734ba64 to 791f962 on November 21, 2025 09:26
@github-actions github-actions bot added the NNCF ONNX Pull requests that updates NNCF ONNX label Nov 21, 2025
@ruro ruro marked this pull request as ready for review November 21, 2025 09:36
@ruro ruro requested a review from a team as a code owner November 21, 2025 09:36
@ruro (Author) commented Nov 21, 2025

Hm, IIRC ONNX added support for opset 23 in version 1.18.0, so the new test is currently failing in CI due to:

```
onnx==1.17.0; python_version < '3.13'
onnx==1.18.0; python_version >= '3.13'
```

Do you have a preference for whether I should mark this test as

```python
import onnx
import pytest
from packaging import version

@pytest.mark.skipif(
    version.parse(onnx.__version__) < version.parse("1.18.0"),
    reason="Opset 23 was added in onnx 1.18.0",
)
```

or bump the version or something else?

@andrey-churkin andrey-churkin self-assigned this Nov 25, 2025
@andrey-churkin (Contributor) commented:
> Hm. iirc onnx added support for opset 23 in version 1.18.0. So the new test is currently failing in CI due to
>
> ```
> onnx==1.17.0; python_version < '3.13'
> onnx==1.18.0; python_version >= '3.13'
> ```
>
> Do you have any preferences if I should mark this test as
>
> ```python
> @pytest.mark.skipif(
>     version.parse(onnx.__version__) < version.parse("1.18.0"),
>     reason="Opset 23 was added in onnx 1.18.0",
> )
> ```
>
> or bump the version or something else?

Hi @ruro, thanks for your contribution. We currently support multiple versions of ONNX, and the Attention operator was added in opset 23, which corresponds to ONNX 1.18.0. I believe we should run this test only for ONNX versions >= 1.18.0.

@ruro ruro force-pushed the fix_onnx_attention_torch_sdpa_handling branch from 791f962 to 9af3bfb on November 27, 2025 11:41
Comment on lines +239 to +242
```python
input_port_ids = [input_edge.input_port_id] + input_edge.parallel_input_port_ids
node_name = nncf_node.node_name
for input_port_id in input_port_ids:
    allowed_pre_hook_insertion_points.append(PreHookInsertionPoint(node_name, input_port_id))
```
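The loop in this diff can be illustrated with a self-contained sketch; `Edge` and the helper below are simplified stand-ins for the real NNCF types, not the actual implementation:

```python
from dataclasses import dataclass, field

# Hypothetical, simplified stand-ins for the NNCF types in the diff above:
# an edge can carry extra parallel input ports when the same producer node
# feeds several inputs of one consumer (e.g. unbind -> sdpa).
@dataclass
class Edge:
    input_port_id: int
    parallel_input_port_ids: list[int] = field(default_factory=list)

@dataclass(frozen=True)
class PreHookInsertionPoint:
    node_name: str
    input_port_id: int

def pre_hook_points(node_name: str, input_edge: Edge) -> list[PreHookInsertionPoint]:
    # One insertion point per port, including the parallel ports that a
    # non-multigraph representation would otherwise hide.
    ports = [input_edge.input_port_id] + input_edge.parallel_input_port_ids
    return [PreHookInsertionPoint(node_name, p) for p in ports]

# unbind -> sdpa: Q on port 0, with K and V as parallel edges on ports 1 and 2
edge = Edge(input_port_id=0, parallel_input_port_ids=[1, 2])
print([p.input_port_id for p in pre_hook_points("sdpa", edge)])  # [0, 1, 2]
```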
Contributor:
@ruro Could you please briefly explain why these changes are necessary?

@daniil-lyakhov Please take a look

@ruro (Author) commented Nov 28, 2025:

I've outlined my reasoning in the last two comments in #3750. The short version is that parallel edges aren't directly represented in the PTNNCFGraph (because it isn't a multigraph and doesn't allow repeated edges); they are instead stored in the parallel_input_port_ids property.

In this case, unbind has 3 outputs that are passed as the q, k and v inputs of the sdpa node. Each of these 3 edges should be considered separately for quantizer insertion/propagation, but the previous logic only added insertion points for "real" edges, ignoring any extra parallel edges.

Let me know if anything is unclear.
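The pattern being discussed can be reproduced with a minimal sketch (the module name `UnbindSDPA` and the shapes are illustrative, not the actual test model):

```python
import torch
import torch.nn.functional as F

class UnbindSDPA(torch.nn.Module):
    # Hypothetical minimal module for the pattern under discussion: a single
    # stacked tensor is unbound into Q, K and V, so all three SDPA inputs are
    # parallel edges coming from the same producer node in the traced graph.
    def forward(self, qkv: torch.Tensor) -> torch.Tensor:
        q, k, v = qkv.unbind(dim=0)
        return F.scaled_dot_product_attention(q, k, v)

qkv = torch.randn(3, 2, 4, 8)  # stacked Q/K/V: (3, batch, seq_len, head_dim)
out = UnbindSDPA()(qkv)
print(out.shape)  # torch.Size([2, 4, 8])
```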

@daniil-lyakhov (Collaborator) commented Nov 28, 2025:

Great contribution! Could you please share a Netron/NNCF graph visualization of the newly supported subgraph? (NNCF graph visualization API: https://github.com/openvinotoolkit/nncf/blob/develop/src/nncf/common/graph/graph.py#L611-L613)

This part of the code is the core logic of NNCF; we need to figure out all possible side effects of this change.

@ruro (Author) replied:

I am not sure what you mean by "netron/nncf graph". The second image in the PR body is the expected graph for unbind+sdpa after applying quantization. Does that work?

@ruro (Author) commented Nov 28, 2025:

Also, here are the before and after graphs, obtained by performing a torch.onnx.export (with opset_version=23) of a quantized timm.layers.attention.Attention module:

[before / after graph screenshots]

(The edge without the q/dq nodes is the V input of Attention as expected)


Linked issue (may be closed by merging): MULTIHEAD_ATTENTION_OUTPUT ignored patterns don't match "proper" SDPA / Attention