Skip to content

Commit 6a8079b

Browse files
authored
[GPU] Enable SDPA by default (#24757)
### Details:
- Enabled SDPA by default
- Added indirect inputs support (copy of #24665)
- Updated the SDPA decomposition rule to cover only well-checked cases
- Updated functional tests accordingly
1 parent 1d42420 commit 6a8079b

24 files changed (+847, −142 lines)

src/core/include/openvino/op/scaled_dot_product_attention.hpp

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,10 @@ class OPENVINO_API ScaledDotProductAttention : public Op {
5050
return m_causal;
5151
}
5252

53+
/// Sets the causal-attention flag (counterpart to get_causal() above).
void set_causal(bool causal) {
    m_causal = causal;
}
56+
5357
private:
5458
bool m_causal = false;
5559
};
Lines changed: 78 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,78 @@
1+
// Copyright (C) 2024 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#pragma once

#include "intel_gpu/op/sdpa.hpp"
#include "openvino/core/node.hpp"
#include "openvino/core/partial_shape.hpp"
#include "openvino/op/op.hpp"

namespace ov {
namespace intel_gpu {
namespace op {

/// Scaled dot-product attention with an extra indirection (beam-table) input.
/// Extends the plugin's SDPA op with a `beam_table` input and the axis along
/// which the indirect access is performed (presumably into the KV-cache —
/// confirm against the kernel implementation).
class IndirectSDPA : public ov::intel_gpu::op::SDPA {
public:
    OPENVINO_OP("IndirectSDPA", "gpu_opset");

    IndirectSDPA() = default;

    /// Q, K, V and beam_table only (no attention mask, no scale input).
    IndirectSDPA(const ov::Output<Node>& Q,
                 const ov::Output<Node>& K,
                 const ov::Output<Node>& V,
                 const ov::Output<Node>& beam_table,
                 const bool is_causal,
                 const int64_t indirect_axis,
                 const std::vector<int64_t>& order_q,
                 const std::vector<int64_t>& order_k,
                 const std::vector<int64_t>& order_v,
                 const std::vector<int64_t>& order_out,
                 const ov::element::Type output_type = ov::element::undefined);

    /// Q, K, V, attn_mask and beam_table (no scale input).
    IndirectSDPA(const ov::Output<Node>& Q,
                 const ov::Output<Node>& K,
                 const ov::Output<Node>& V,
                 const ov::Output<Node>& attn_mask,
                 const ov::Output<Node>& beam_table,
                 const bool is_causal,
                 const int64_t indirect_axis,
                 const std::vector<int64_t>& order_q,
                 const std::vector<int64_t>& order_k,
                 const std::vector<int64_t>& order_v,
                 const std::vector<int64_t>& order_out,
                 const ov::element::Type output_type = ov::element::undefined);

    /// Q, K, V, attn_mask, scale and beam_table (all optional inputs present).
    IndirectSDPA(const ov::Output<Node>& Q,
                 const ov::Output<Node>& K,
                 const ov::Output<Node>& V,
                 const ov::Output<Node>& attn_mask,
                 const ov::Output<Node>& scale,
                 const ov::Output<Node>& beam_table,
                 const bool is_causal,
                 const int64_t indirect_axis,
                 const std::vector<int64_t>& order_q,
                 const std::vector<int64_t>& order_k,
                 const std::vector<int64_t>& order_v,
                 const std::vector<int64_t>& order_out,
                 const ov::element::Type output_type = ov::element::undefined);

    bool visit_attributes(ov::AttributeVisitor& visitor) override;
    void validate_and_infer_types() override;

    std::shared_ptr<Node> clone_with_new_inputs(const ov::OutputVector& new_args) const override;

    /// Explicitly requested output element type (m_output_type is inherited from SDPA).
    ov::element::Type get_output_type() const { return m_output_type; }

    /// Axis the indirection (beam table) is applied along; -1 when unset.
    int64_t get_indirect_axis() const { return m_indirect_axis; }

    // Re-export the base-class helper for callers using this type directly.
    using ov::intel_gpu::op::SDPA::default_order;

protected:
    int64_t m_indirect_axis = -1;
};

}  // namespace op
}  // namespace intel_gpu
}  // namespace ov

src/plugins/intel_gpu/include/intel_gpu/plugin/primitives_list.hpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -285,3 +285,4 @@ REGISTER_FACTORY(internal, IndirectGemm);
285285
REGISTER_FACTORY(internal, Convolution);
286286
REGISTER_FACTORY(internal, Placeholder);
287287
REGISTER_FACTORY(internal, SDPA);
288+
REGISTER_FACTORY(internal, IndirectSDPA);

src/plugins/intel_gpu/include/intel_gpu/primitives/scaled_dot_product_attention.hpp

Lines changed: 14 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -19,24 +19,31 @@ struct scaled_dot_product_attention : public primitive_base<scaled_dot_product_a
1919
scaled_dot_product_attention(const primitive_id& id,
2020
const std::vector<cldnn::input_info> inputs,
2121
bool is_causal,
22+
int64_t indirect_axis = -1,
2223
const std::vector<int64_t>& input_q_transpose_order = {},
2324
const std::vector<int64_t>& input_k_transpose_order = {},
2425
const std::vector<int64_t>& input_v_transpose_order = {},
2526
const std::vector<int64_t>& output_transpose_order = {},
2627
const padding& output_padding = padding())
2728
: primitive_base(id, inputs, {output_padding})
2829
, is_causal(is_causal)
29-
, has_attn_mask_input(inputs.size() > 3)
30-
, has_scale_input(inputs.size() > 4)
30+
, indirect_axis(indirect_axis)
3131
, input_q_transpose_order(input_q_transpose_order)
3232
, input_k_transpose_order(input_k_transpose_order)
3333
, input_v_transpose_order(input_v_transpose_order)
34-
, output_transpose_order(output_transpose_order) {}
34+
, output_transpose_order(output_transpose_order) {
35+
auto data_inputs_num = inputs.size();
36+
if (indirect_axis != -1)
37+
data_inputs_num--;
3538

39+
has_attn_mask_input = data_inputs_num > 3;
40+
has_scale_input = data_inputs_num > 4;
41+
}
3642

3743
bool is_causal = false;
3844
bool has_attn_mask_input = false;
3945
bool has_scale_input = false;
46+
int64_t indirect_axis = -1;
4047

4148
std::vector<int64_t> input_q_transpose_order;
4249
std::vector<int64_t> input_k_transpose_order;
@@ -48,6 +55,7 @@ struct scaled_dot_product_attention : public primitive_base<scaled_dot_product_a
4855
seed = hash_combine(seed, is_causal);
4956
seed = hash_combine(seed, has_attn_mask_input);
5057
seed = hash_combine(seed, has_scale_input);
58+
seed = hash_combine(seed, indirect_axis);
5159
seed = hash_range(seed, input_q_transpose_order.begin(), input_q_transpose_order.end());
5260
seed = hash_range(seed, input_k_transpose_order.begin(), input_k_transpose_order.end());
5361
seed = hash_range(seed, input_v_transpose_order.begin(), input_v_transpose_order.end());
@@ -64,6 +72,7 @@ struct scaled_dot_product_attention : public primitive_base<scaled_dot_product_a
6472
return is_causal == rhs_casted.is_causal &&
6573
has_attn_mask_input == rhs_casted.has_attn_mask_input &&
6674
has_scale_input == rhs_casted.has_scale_input &&
75+
indirect_axis == rhs_casted.indirect_axis &&
6776
input_q_transpose_order == rhs_casted.input_q_transpose_order &&
6877
input_k_transpose_order == rhs_casted.input_k_transpose_order &&
6978
input_v_transpose_order == rhs_casted.input_v_transpose_order &&
@@ -75,6 +84,7 @@ struct scaled_dot_product_attention : public primitive_base<scaled_dot_product_a
7584
ob << is_causal;
7685
ob << has_attn_mask_input;
7786
ob << has_scale_input;
87+
ob << indirect_axis;
7888
ob << input_q_transpose_order;
7989
ob << input_k_transpose_order;
8090
ob << input_v_transpose_order;
@@ -86,6 +96,7 @@ struct scaled_dot_product_attention : public primitive_base<scaled_dot_product_a
8696
ib >> is_causal;
8797
ib >> has_attn_mask_input;
8898
ib >> has_scale_input;
99+
ib >> indirect_axis;
89100
ib >> input_q_transpose_order;
90101
ib >> input_k_transpose_order;
91102
ib >> input_v_transpose_order;

src/plugins/intel_gpu/include/intel_gpu/runtime/debug_configuration.hpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -129,6 +129,7 @@ class debug_configuration {
129129
std::vector<std::string> forced_impl_types; // Force implementation type either ocl or onednn
130130
int max_kernels_per_batch; // Maximum number of kernels in a batch during compiling kernels
131131
int impls_cache_capacity; // The maximum number of entries in the kernel impl cache
132+
int enable_sdpa; // Allows to control SDPA decomposition
132133
int disable_async_compilation; // Disable async compilation
133134
int disable_winograd_conv; // Disable Winograd conv
134135
int disable_dynamic_impl; // Disable dynamic implementation

0 commit comments

Comments
 (0)