Skip to content

Commit 82bbf02

Browse files
Revert GPT-OSS ROPE fusion (#33183)
### Details: - *This reverts commit 4312e0a.* - *...* ### Tickets: - *CVS-177675*
1 parent 161cd2c commit 82bbf02

File tree

14 files changed

+20
-218
lines changed

14 files changed

+20
-218
lines changed

src/common/transformations/include/ov_ops/rotary_positional_embeddings.hpp

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,6 @@ class TRANSFORMATIONS_API RoPE : public Op {
3333
bool is_qwen = false; // Qwen is special which overrides other setting
3434
bool use_rope_cache = false; // use precomputed RoPE cache for trigonometric values (cosine and sine)
3535
bool support_3d_rope = false; // use same logic as RoPEFusionGPTNEOX(4), used by gpu plugin
36-
size_t cos_sin_ndims = 0; // last dimension of con/sin table
3736
size_t head_cnt = 0;
3837
size_t head_size = 0;
3938
int gather_position_arg_id =

src/common/transformations/include/transformations/common_optimizations/fuse_rotary_positional_embeddings.hpp

Lines changed: 0 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,6 @@ class TRANSFORMATIONS_API RoPEFusionIOSlicing;
2121
class TRANSFORMATIONS_API RoPEFusionPreprocess;
2222
class TRANSFORMATIONS_API RoPEFusionCosSinPreprocess;
2323
class TRANSFORMATIONS_API RoPEShareCosSin;
24-
class TRANSFORMATIONS_API RoPEFusionGPTOSS;
2524

2625
} // namespace pass
2726
} // namespace ov
@@ -92,12 +91,6 @@ class ov::pass::RoPEShareCosSin : public ov::pass::MatcherPass {
9291
std::vector<std::shared_ptr<ov::Node>> m_shared_inputs{2, nullptr};
9392
};
9493

95-
class ov::pass::RoPEFusionGPTOSS : public ov::pass::MatcherPass {
96-
public:
97-
OPENVINO_MATCHER_PASS_RTTI("RoPEFusionGPTOSS");
98-
RoPEFusionGPTOSS();
99-
};
100-
10194
/**
10295
* @ingroup ov_transformation_common_api
10396
* @brief Fuses special sub-graph into an internal Rotary Positional Embedding operation

src/common/transformations/src/transformations/common_optimizations/fuse_rotary_positional_embeddings.cpp

Lines changed: 0 additions & 73 deletions
Original file line numberDiff line numberDiff line change
@@ -61,7 +61,6 @@ bool ov::pass::RoPEFusion::run_on_model(const std::shared_ptr<ov::Model>& model)
6161
symbolic_ctx_manager->register_pass<ov::pass::RoPEFusionGPTNEOX>(4);
6262
symbolic_ctx_manager->register_pass<ov::pass::RoPEFusionGPTNEOX>(3);
6363
symbolic_ctx_manager->register_pass<ov::pass::RoPEFusionGPTJ>();
64-
symbolic_ctx_manager->register_pass<ov::pass::RoPEFusionGPTOSS>();
6564
// optional heads & tails are fused in separate matcher pass,
6665
// after RoPENode has been created.
6766
symbolic_ctx_manager->register_pass<ov::pass::RoPEFusionCosSinPreprocess>();
@@ -1112,75 +1111,3 @@ ov::pass::RoPEShareCosSin::RoPEShareCosSin() {
11121111
auto m = std::make_shared<pattern::Matcher>(result, matcher_name);
11131112
this->register_matcher(m, callback);
11141113
}
1115-
1116-
ov::pass::RoPEFusionGPTOSS::RoPEFusionGPTOSS() {
1117-
using namespace ov::op::util;
1118-
MATCHER_SCOPE(RoPEFusionGPTOSS);
1119-
1120-
// gpt-oss style
1121-
// first_half, second_half = torch.chunk(x, 2, dim=-1)
1122-
// first_ = first_half * cos - second_half * sin
1123-
// second_ = second_half * cos + first_half * sin
1124-
// return torch.cat((first_, second_), dim=-1)
1125-
auto x = pattern::any_input(pattern::rank_equals(4));
1126-
auto t_cos = pattern::any_input(pattern::shape_matches("[?, 1, ?, half_ndims]"));
1127-
auto t_sin = pattern::any_input(pattern::shape_matches("[?, 1, ?, half_ndims]"));
1128-
1129-
auto vsplit_out0 = pattern::wrap_type<op::v1::VariadicSplit>(
1130-
{x, -1, {"half_ndims", "?"}},
1131-
pattern::output_index_matches(0) && pattern::shape_matches("[?, ?, ?, half_ndims]"));
1132-
auto vsplit_out1 = pattern::wrap_type<op::v1::VariadicSplit>(
1133-
{x, -1, {"half_ndims", "?"}},
1134-
pattern::output_index_matches(1) && pattern::shape_matches("[?, ?, ?, half_ndims]"));
1135-
auto first_half_mul_cos = pattern::wrap_type<v1::Multiply>({vsplit_out0, t_cos}, {{"auto_broadcast", "numpy"}});
1136-
auto second_half_mul_sin = pattern::wrap_type<v1::Multiply>({vsplit_out1, t_sin}, {{"auto_broadcast", "numpy"}});
1137-
auto neg = pattern::wrap_type<v1::Multiply>({second_half_mul_sin, -1.0f}, {{"auto_broadcast", "numpy"}});
1138-
auto sub_Subtract = pattern::wrap_type<v1::Add>({first_half_mul_cos, neg}, {{"auto_broadcast", "numpy"}});
1139-
1140-
auto second_half_mul_cos = pattern::wrap_type<v1::Multiply>({vsplit_out1, t_cos}, {{"auto_broadcast", "numpy"}});
1141-
auto first_half_mul_sin = pattern::wrap_type<v1::Multiply>({vsplit_out0, t_sin}, {{"auto_broadcast", "numpy"}});
1142-
auto add_Add =
1143-
pattern::wrap_type<v1::Add>({second_half_mul_cos, first_half_mul_sin}, {{"auto_broadcast", "numpy"}});
1144-
auto concat_result = pattern::wrap_type<opset1::Concat>({sub_Subtract, add_Add}, {{"axis", -1}});
1145-
1146-
auto result = concat_result;
1147-
1148-
matcher_pass_callback callback = [OV_CAPTURE_CPY_AND_THIS](ov::pass::pattern::Matcher& m) {
1149-
const auto& pattern_map = m.get_pattern_value_map();
1150-
auto root = m.get_match_root();
1151-
const auto& x_val = pattern_map.at(x);
1152-
const auto& v_cos = pattern_map.at(t_cos);
1153-
1154-
auto symbols = m.get_symbols();
1155-
const auto& half_ndims = symbols["half_ndims"];
1156-
if (!half_ndims.is_integer()) {
1157-
return false;
1158-
}
1159-
1160-
op::internal::RoPE::Config config;
1161-
OutputVector new_args;
1162-
config.rotary_ndims = 2ul * static_cast<size_t>(half_ndims.i());
1163-
config.cos_sin_ndims = static_cast<size_t>(half_ndims.i());
1164-
1165-
new_args.push_back(x_val);
1166-
new_args.push_back(v_cos);
1167-
new_args.push_back(pattern_map.at(t_sin));
1168-
auto new_node = std::make_shared<internal::RoPE>(new_args, config);
1169-
new_node->set_friendly_name(root->get_friendly_name());
1170-
ov::copy_runtime_info({pattern_map.at(neg).get_node_shared_ptr(),
1171-
pattern_map.at(sub_Subtract).get_node_shared_ptr(),
1172-
pattern_map.at(first_half_mul_cos).get_node_shared_ptr(),
1173-
pattern_map.at(first_half_mul_sin).get_node_shared_ptr(),
1174-
pattern_map.at(second_half_mul_cos).get_node_shared_ptr(),
1175-
pattern_map.at(second_half_mul_sin).get_node_shared_ptr(),
1176-
pattern_map.at(add_Add).get_node_shared_ptr(),
1177-
pattern_map.at(result).get_node_shared_ptr()},
1178-
new_node);
1179-
ov::replace_node(root, new_node);
1180-
register_new_node(new_node);
1181-
return true;
1182-
};
1183-
1184-
auto m = std::make_shared<ov::pass::pattern::Matcher>(result, matcher_name);
1185-
this->register_matcher(m, callback);
1186-
}

src/plugins/intel_cpu/src/nodes/kernels/x64/rope_kernel.cpp

Lines changed: 4 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -89,16 +89,10 @@ void jit_rotary_kernel<isa>::rotary_half(size_t step) {
8989
vfmsub231ps(vmm_dst0, vmm_cos, vmm_src0);
9090
store(reg_dst, vmm_dst0, m_jcp.dst_prc, step);
9191

92-
// cos[i + sin_cos_offset]
93-
// if con/sin table is not same size with input, it is reused for both halves.
94-
const bool shift_cos_sin = m_jcp.cos_sin_ndims != half_rotary_ndims;
95-
if (shift_cos_sin) {
96-
load(vmm_cos, reg_cos, ov::element::f32, step, false, half_rotary_ndims * sizeof(float));
97-
}
98-
// sin[i + sin_cos_offset]
99-
if (shift_cos_sin) {
100-
load(vmm_sin, reg_sin, ov::element::f32, step, false, half_rotary_ndims * sizeof(float));
101-
}
92+
// cos[i + halfRotaryNdims]
93+
load(vmm_cos, reg_cos, ov::element::f32, step, false, half_rotary_ndims * sizeof(float));
94+
// sin[i + halfRotaryNdims]
95+
load(vmm_sin, reg_sin, ov::element::f32, step, false, half_rotary_ndims * sizeof(float));
10296
// cos[i + half_rotary_dims] * src1
10397
uni_vmulps(vmm_dst0, vmm_cos, vmm_src1);
10498
// cos[i + half_rotary_dims] * src1 + sin[i + half_rotary_dims] * src0

src/plugins/intel_cpu/src/nodes/kernels/x64/rope_kernel.hpp

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,6 @@ struct jit_rotary_compile_params {
2727
ov::element::Type src_prc;
2828
ov::element::Type dst_prc;
2929
size_t rotary_ndims = 0UL;
30-
size_t cos_sin_ndims = 0UL;
3130
bool interleave = false;
3231
bool mix_cos_sin = false;
3332
};

src/plugins/intel_cpu/src/nodes/paged_attn.cpp

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
#include "onednn/iml_type_mapper.h"
2525
#include "openvino/core/except.hpp"
2626
#include "openvino/core/node.hpp"
27+
#include "openvino/core/shape.hpp"
2728
#include "openvino/core/type/element_type.hpp"
2829
#include "openvino/runtime/system_conf.hpp"
2930
#include "shape_inference/shape_inference_internal_dyn.hpp"
@@ -175,7 +176,7 @@ void PagedAttention::initSupportedPrimitiveDescriptors() {
175176
// sinks, float, [1, H, 1, 1]
176177
config.inConfs[PagedAttentionExecutor::ID_SINKS].setMemDesc(
177178
creatorsMap.at(LayoutType::ncsp)
178-
->createSharedDesc(ov::element::f32, getInputShapeAtPort(PagedAttentionExecutor::ID_SINKS)));
179+
->createSharedDesc(rtPrecision, getInputShapeAtPort(PagedAttentionExecutor::ID_SINKS)));
179180

180181
supportedPrimitiveDescriptors.emplace_back(config, impl_desc_type::ref_any);
181182
}
@@ -300,6 +301,10 @@ bool PagedAttention::isSupportedOperation(const std::shared_ptr<const ov::Node>&
300301
return false;
301302
}
302303
}
304+
if (ov::shape_size(op->get_input_shape(PagedAttentionExecutor::ID_SINKS)) != 0) {
305+
errorMessage = "PageAttn sinks input is not supported yet";
306+
return false;
307+
}
303308
auto orgInput = static_cast<int>(op->get_input_size());
304309
if (op->get_type_name() == std::string("PagedAttentionExtension") &&
305310
orgInput == PagedAttentionExecutor::ID_SLIDING_WINDOW + 1) {

src/plugins/intel_cpu/src/nodes/rope.cpp

Lines changed: 7 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -111,7 +111,6 @@ struct RoPE::RoPEExecutorRotateHalf : public RoPE::Executor {
111111
jcp.dst_prc = precision_of<T>::value;
112112
jcp.rotary_ndims = config.rotary_ndims;
113113
jcp.interleave = false;
114-
jcp.cos_sin_ndims = config.cos_sin_ndims;
115114
m_rotaryKernel = createJitKernel(jcp);
116115
}
117116

@@ -154,12 +153,11 @@ struct RoPE::RoPEExecutorRotateHalf : public RoPE::Executor {
154153
t_sin = t_sin.reshape({1, t_sin.size(0), t_sin.size(1), t_sin.size(2)});
155154
}
156155

157-
const auto batch_size = t_src.size(0);
158-
const auto head_cnt = t_src.size(1);
159-
const auto seq_len = t_src.size(2);
160-
const auto feature_size = t_src.size(3);
161-
const auto half_rotary_dims = rotary_dims / 2;
162-
const size_t cos_sin_offset = (m_config.cos_sin_ndims == half_rotary_dims) ? 0 : half_rotary_dims;
156+
auto batch_size = t_src.size(0);
157+
auto head_cnt = t_src.size(1);
158+
auto seq_len = t_src.size(2);
159+
auto feature_size = t_src.size(3);
160+
163161
parallel_for3d(batch_size, head_cnt, seq_len, [&](size_t b, size_t h, size_t p) {
164162
auto cos_pos = p;
165163
if (gather) {
@@ -177,12 +175,13 @@ struct RoPE::RoPEExecutorRotateHalf : public RoPE::Executor {
177175
if (m_rotaryKernel) {
178176
execJitKernel(m_rotaryKernel, src, dst, cos, sin);
179177
} else {
178+
auto half_rotary_dims = rotary_dims / 2;
180179
size_t i = 0;
181180
for (; i < half_rotary_dims; i++) {
182181
auto src0 = src[i];
183182
auto src1 = src[i + half_rotary_dims];
184183
dst[i] = cos[i] * src0 - sin[i] * src1;
185-
dst[i + half_rotary_dims] = cos[i + cos_sin_offset] * src1 + sin[i + cos_sin_offset] * src0;
184+
dst[i + half_rotary_dims] = cos[i + half_rotary_dims] * src1 + sin[i + half_rotary_dims] * src0;
186185
}
187186
}
188187
if (!can_inplace) {

src/plugins/intel_cpu/tests/functional/shared_tests_instances/subgraph_tests/rotary_pos_emb.cpp

Lines changed: 0 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -79,12 +79,5 @@ INSTANTIATE_TEST_SUITE_P(smoke_RoPETestQwenVL,
7979
::testing::ValuesIn(vit_param)),
8080
RoPETestQwenVL::getTestCaseName);
8181

82-
INSTANTIATE_TEST_SUITE_P(smoke_RoPETestGPTOSS,
83-
RoPETestGPTOSS,
84-
::testing::Combine(
85-
::testing::Values(ov::element::f32),
86-
::testing::Values(ov::test::utils::DEVICE_CPU)),
87-
RoPETestGPTOSS::getTestCaseName);
88-
8982
} // namespace test
9083
} // namespace ov

src/plugins/intel_gpu/src/graph/impls/ocl_v2/rope_opt.cl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -411,14 +411,14 @@ uint cos_sin_p = p;
411411
output[output_idx + r] = cos[cos_idx + r] * in1 - sin[sin_idx + r] * in2;
412412

413413
output[output_idx + HALF_ROTARY_NDIMS + r] =
414-
cos[cos_idx + COS_SIN_TABLE_OFFSET + r] * in2 + sin[sin_idx + COS_SIN_TABLE_OFFSET + r] * in1;
414+
cos[cos_idx + HALF_ROTARY_NDIMS + r] * in2 + sin[sin_idx + HALF_ROTARY_NDIMS + r] * in1;
415415
#else
416416
INPUT_VEC_TYPE in1 = *(INPUT_VEC_TYPE*)(input + input_idx + r);
417417
INPUT_VEC_TYPE in2 = *(INPUT_VEC_TYPE*)(input + input_idx + HALF_ROTARY_NDIMS + r);
418418
INPUT_VEC_TYPE cos1 = *(INPUT_VEC_TYPE*)(cos + cos_idx + r);
419-
INPUT_VEC_TYPE cos2 = *(INPUT_VEC_TYPE*)(cos + cos_idx + COS_SIN_TABLE_OFFSET + r);
419+
INPUT_VEC_TYPE cos2 = *(INPUT_VEC_TYPE*)(cos + cos_idx + HALF_ROTARY_NDIMS + r);
420420
INPUT_VEC_TYPE sin1 = *(INPUT_VEC_TYPE*)(sin + sin_idx + r);
421-
INPUT_VEC_TYPE sin2 = *(INPUT_VEC_TYPE*)(sin + sin_idx + COS_SIN_TABLE_OFFSET + r);
421+
INPUT_VEC_TYPE sin2 = *(INPUT_VEC_TYPE*)(sin + sin_idx + HALF_ROTARY_NDIMS + r);
422422

423423
OUTPUT_VEC_TYPE out1 = cos1 * in1 - sin1 * in2;
424424
OUTPUT_VEC_TYPE out2 = cos2 * in2 + sin2 * in1;

src/plugins/intel_gpu/src/graph/impls/ocl_v2/rope_opt.cpp

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,6 @@ class RopeGenerator : public KernelGenerator {
5454
jit.make("HEAD_SIZE", desc->config.head_size);
5555
jit.make("ROTARY_NDIMS", desc->config.rotary_ndims);
5656
jit.make("HALF_ROTARY_NDIMS", desc->config.rotary_ndims / 2);
57-
jit.make("COS_SIN_TABLE_OFFSET", (desc->config.cos_sin_ndims == (desc->config.rotary_ndims / 2)) ? 0 : desc->config.rotary_ndims / 2);
5857
jit.make("HEAD_COUNT", desc->config.head_cnt);
5958

6059
if (desc->config.head_size > desc->config.rotary_ndims) {

0 commit comments

Comments
 (0)