Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -779,7 +779,7 @@ RoPEFusionChatGLM::RoPEFusionChatGLM(const bool support_2d_rope) {
static std::shared_ptr<ov::Node> build_ChatGLMHF_interleave_pattern(std::shared_ptr<ov::Node> cos_or_sin) {
auto transpose = pattern::wrap_type<v1::Transpose>({cos_or_sin, pattern::any_input()});
auto reshape = pattern::wrap_type<v1::Reshape>({transpose, pattern::any_input()});
auto multiply = pattern::wrap_type<v1::Multiply>({reshape, pattern::any_input()});
auto multiply = pattern::wrap_type<v1::Multiply, ov::op::util::BroadcastBase>({reshape, pattern::any_input()});
auto gather_nd = pattern::wrap_type<v8::GatherND>({multiply, pattern::any_input()});
auto transpose_1 = pattern::wrap_type<v1::Transpose>({gather_nd, pattern::any_input()});

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
#include "common_test_utils/ov_test_utils.hpp"
#include "openvino/core/node_vector.hpp"
#include "openvino/op/add.hpp"
#include "openvino/op/broadcast.hpp"
#include "openvino/op/concat.hpp"
#include "openvino/op/gather.hpp"
#include "openvino/op/gather_elements.hpp"
Expand Down Expand Up @@ -1140,7 +1141,7 @@ TEST_F(TransformationTestsF, ConvertToROPE_chatGLM_nano_2d_rope) {
}
}

TEST_F(TransformationTestsF, ConvertToROPE_chatGLMHF_2d_rope_GatherND) {
TEST_F(TransformationTestsF, ConvertToROPE_chatGLMHF_2d_rope_GatherND_CPU) {
disable_rt_info_check();
const int seq_len = 7;
const int num_heads = 32;
Expand Down Expand Up @@ -1215,6 +1216,108 @@ TEST_F(TransformationTestsF, ConvertToROPE_chatGLMHF_2d_rope_GatherND) {
}
}

TEST_F(TransformationTestsF, ConvertToROPE_chatGLMHF_2d_rope_GatherND_GPU) {
disable_rt_info_check();
const int seq_len = 7;
const int num_heads = 32;
const int ndims = 128;
const int rotary_ndims = 64;
{
auto input = std::make_shared<ov::op::v0::Parameter>(ov::element::f16, ov::PartialShape{seq_len, 1, 4096});
auto Reshape = makeOP<ov::op::v1::Reshape>({input, {-1, 32, 1, 128}}, {{"special_zero", false}});
auto strided_slice = makeOP<v1::StridedSlice>({Reshape, {0, 0, 0, 0}, {0, 0, 0, 64}, {1, 1, 1, 1}},
{{"begin_mask", {1, 1, 1, 0}},
{"end_mask", {1, 1, 1, 0}},
{"new_axis_mask", {}},
{"shrink_axis_mask", {}},
{"ellipsis_mask", {}}});

auto sin = std::make_shared<ov::opset1::Parameter>(ov::element::f16,
ov::PartialShape{seq_len, 1, 1, (rotary_ndims / 2)});
auto TransposeSin = makeOP<ov::op::v1::Transpose>({sin, {3, 1, 2, 0}});
auto ReshapeSin = makeOP<ov::op::v1::Reshape>({TransposeSin, {32, 1, 1, 1, -1}}, {{"special_zero", false}});
auto BroadcastSinParam = std::make_shared<ov::op::v0::Parameter>(ov::element::i32, ov::PartialShape{5});
auto BroadcastSin = makeOP<ov::op::v3::Broadcast>({ReshapeSin, BroadcastSinParam}, {{"mode", "bidirectional"}});
auto GatherNDSinConstant = makeConst(ov::element::i32, ov::Shape({64, 2}), MOCK_VALUE);
auto GatherNDSin = makeOP<ov::op::v8::GatherND>({BroadcastSin, GatherNDSinConstant}, {{"batch_dims", 0}});
auto TransposeSin0 = makeOP<ov::op::v1::Transpose>({GatherNDSin, {3, 1, 2, 0}});

auto cos = std::make_shared<ov::opset1::Parameter>(ov::element::f16,
ov::PartialShape{seq_len, 1, 1, (rotary_ndims / 2)});
auto TransposeCos = makeOP<ov::op::v1::Transpose>({cos, {3, 1, 2, 0}});
auto ReshapeCos = makeOP<ov::op::v1::Reshape>({TransposeCos, {32, 1, 1, 1, -1}}, {{"special_zero", false}});
auto BroadcastCosParam = std::make_shared<ov::op::v0::Parameter>(ov::element::i32, ov::PartialShape{5});
auto BroadcastCos = makeOP<ov::op::v3::Broadcast>({ReshapeCos, BroadcastCosParam}, {{"mode", "bidirectional"}});
auto GatherNDCosConstant = makeConst(ov::element::i32, ov::Shape({64, 2}), MOCK_VALUE);
auto GatherNDCos = makeOP<ov::op::v8::GatherND>({BroadcastCos, GatherNDCosConstant}, {{"batch_dims", 0}});
auto TransposeCos0 = makeOP<ov::op::v1::Transpose>({GatherNDCos, {3, 1, 2, 0}});

auto Strided_slice0 = makeOP<v1::StridedSlice>({strided_slice, {0, 0, 0, 1}, {0, 0, 0, INT_MAX}, {1, 1, 1, 2}},
{{"begin_mask", {1, 1, 1, 0}},
{"end_mask", {1, 1, 1, 0}},
{"new_axis_mask", {}},
{"shrink_axis_mask", {}},
{"ellipsis_mask", {}}});
auto Constant_75741 = makeConst(ov::element::f16,
ov::Shape({
1,
1,
1,
1,
}),
{-1});
auto Neg_multiply = makeOP<v1::Multiply>({Strided_slice0, Constant_75741}, {{"auto_broadcast", "numpy"}});
auto Unsqueeze = makeOP<v1::Reshape>({Neg_multiply, {-1, 32, 1, 32, 1}}, {{"special_zero", false}});
auto Strided_slice1 = makeOP<v1::StridedSlice>({strided_slice, {0, 0, 0, 0}, {0, 0, 0, INT_MAX}, {1, 1, 1, 2}},
{{"begin_mask", {1, 1, 1, 0}},
{"end_mask", {1, 1, 1, 0}},
{"new_axis_mask", {}},
{"shrink_axis_mask", {}},
{"ellipsis_mask", {}}});
auto Unsqueeze1 = makeOP<ov::op::v1::Reshape>({Strided_slice1, {-1, 32, 1, 32, 1}}, {{"special_zero", false}});
auto Stack_reshape = makeOP<ov::op::v0::Concat>({Unsqueeze, Unsqueeze1}, {{"axis", -1}});
auto Flatten_reshape = makeOP<ov::op::v1::Reshape>({Stack_reshape, {0, 32, 0, 64}}, {{"special_zero", true}});
auto Multiply1 = makeOP<ov::op::v1::Multiply>({Flatten_reshape, TransposeSin0}, {{"auto_broadcast", "numpy"}});

auto Multiply2 = makeOP<ov::op::v1::Multiply>({strided_slice, TransposeCos0}, {{"auto_broadcast", "numpy"}});
auto Add = makeOP<ov::op::v1::Add>({Multiply2, Multiply1}, {{"auto_broadcast", "numpy"}});

auto Strided_slice2 = makeOP<v1::StridedSlice>({Reshape, {0, 0, 0, 64}, {0, 0, 0, INT_MAX}, {1, 1, 1, 1}},
{{"begin_mask", {1, 1, 1, 0}},
{"end_mask", {1, 1, 1, 0}},
{"new_axis_mask", {}},
{"shrink_axis_mask", {}},
{"ellipsis_mask", {}}});
auto Concat = makeOP<v0::Concat>({Add, Strided_slice2}, {{"axis", -1}});
model = std::make_shared<ov::Model>(ov::OutputVector{Concat},
ov::ParameterVector{input, cos, sin, BroadcastSinParam, BroadcastCosParam});
}
manager.register_pass<ov::pass::RoPEFusion>(true);
{
auto input = std::make_shared<ov::opset1::Parameter>(ov::element::f16, ov::PartialShape{seq_len, 1, 4096});
auto cos = std::make_shared<ov::opset1::Parameter>(ov::element::f16,
ov::PartialShape{seq_len, 1, 1, (rotary_ndims / 2)});
auto sin = std::make_shared<ov::opset1::Parameter>(ov::element::f16,
ov::PartialShape{seq_len, 1, 1, (rotary_ndims / 2)});
auto rope = makeOP<ov::op::internal::RoPE>({input, cos, sin},
{{"config.slice_start", 0},
{"config.slice_stop", 0},
{"config.input_trans0213", false},
{"config.output_trans0213", false},
{"config.is_interleaved", false},
{"config.rotary_ndims", rotary_ndims},
{"config.is_chatglm", true},
{"config.support_2d_rope", true},
{"config.support_3d_rope", false},
{"config.is_qwen", false},
{"config.use_rope_cache", false},
{"config.head_cnt", num_heads},
{"config.head_size", ndims},
{"config.gather_position_arg_id", 0}});
model_ref = std::make_shared<ov::Model>(ov::OutputVector{rope}, ov::ParameterVector{input, cos, sin});
}
}

TEST_F(TransformationTestsF, ConvertToROPE_chatGLMHF_2d_rope) {
disable_rt_info_check();
const int seq_len = 7;
Expand Down
Loading