@@ -124,6 +124,64 @@ TEST_F(TransformationTestsF, TranposeMatmulFusion4) {
124124 }
125125}
126126
127+ TEST_F (TransformationTestsF, TranposeMatmulFusion5) {
128+ {
129+ auto input_a = std::make_shared<ov::op::v0::Parameter>(ov::element::f32 , ov::PartialShape::dynamic (3 ));
130+ auto input_b = std::make_shared<ov::op::v0::Parameter>(ov::element::f32 , ov::PartialShape::dynamic (3 ));
131+ auto matmul = std::make_shared<ov::op::v0::MatMul>(input_a, input_b);
132+ auto tranpose_c_const = ov::op::v0::Constant::create (ov::element::i64 , ov::Shape{3 }, {0 , 2 , 1 });
133+ auto tranpose_c = std::make_shared<ov::op::v1::Transpose>(matmul, tranpose_c_const);
134+
135+ model = std::make_shared<ov::Model>(ov::NodeVector{ tranpose_c }, ov::ParameterVector{ input_a, input_b });
136+
137+ const auto supports_immad = false ;
138+ manager.register_pass <TransposeFusion>(supports_immad);
139+ }
140+ {
141+ std::vector<int64_t > order_a = {0 , 1 , 2 };
142+ std::vector<int64_t > order_b = {0 , 1 , 2 };
143+ std::vector<int64_t > order_c = {0 , 2 , 1 };
144+ auto input_a = std::make_shared<ov::op::v0::Parameter>(ov::element::f32 , ov::PartialShape::dynamic (3 ));
145+ auto input_b = std::make_shared<ov::op::v0::Parameter>(ov::element::f32 , ov::PartialShape::dynamic (3 ));
146+ auto gemm = std::make_shared<ov::intel_gpu::op::Gemm>(input_a, input_b, order_a, order_b, order_c, ov::element::undefined);
147+
148+ model_ref = std::make_shared<ov::Model>(ov::NodeVector{ gemm }, ov::ParameterVector{ input_a, input_b });
149+ comparator.enable (FunctionsComparator::ATTRIBUTES);
150+ }
151+ }
152+
153+ TEST_F (TransformationTestsF, TranposeMatmulFusion6) {
154+ auto input_a = std::make_shared<ov::op::v0::Parameter>(ov::element::f32 , ov::PartialShape::dynamic (2 ));
155+ auto input_b = std::make_shared<ov::op::v0::Parameter>(ov::element::f32 , ov::PartialShape::dynamic (2 ));
156+ auto matmul = std::make_shared<ov::op::v0::MatMul>(input_a, input_b);
157+ auto tranpose_c_const = ov::op::v0::Constant::create (ov::element::i64 , ov::Shape{2 }, {1 , 0 });
158+ auto tranpose_c = std::make_shared<ov::op::v1::Transpose>(matmul, tranpose_c_const);
159+
160+ model = std::make_shared<ov::Model>(ov::NodeVector{ tranpose_c }, ov::ParameterVector{ input_a, input_b });
161+
162+ const auto supports_immad = false ;
163+ manager.register_pass <TransposeFusion>(supports_immad);
164+
165+ model_ref = model->clone ();
166+ comparator.enable (FunctionsComparator::ATTRIBUTES);
167+ }
168+
169+ TEST_F (TransformationTestsF, TranposeMatmulFusion7) {
170+ auto input_a = std::make_shared<ov::op::v0::Parameter>(ov::element::f32 , ov::PartialShape{2 , 4 });
171+ auto input_b = std::make_shared<ov::op::v0::Parameter>(ov::element::f32 , ov::PartialShape{4 , 2 });
172+ auto matmul = std::make_shared<ov::op::v0::MatMul>(input_a, input_b);
173+ auto tranpose_c_const = ov::op::v0::Constant::create (ov::element::i64 , ov::Shape{2 }, {1 , 0 });
174+ auto tranpose_c = std::make_shared<ov::op::v1::Transpose>(matmul, tranpose_c_const);
175+
176+ model = std::make_shared<ov::Model>(ov::NodeVector{ tranpose_c }, ov::ParameterVector{ input_a, input_b });
177+
178+ const auto supports_immad = false ;
179+ manager.register_pass <TransposeFusion>(supports_immad);
180+
181+ model_ref = model->clone ();
182+ comparator.enable (FunctionsComparator::ATTRIBUTES);
183+ }
184+
127185TEST_F (TransformationTestsF, TranposeMatmulFusion_Illegal_1) {
128186 {
129187 auto input_a = std::make_shared<ov::op::v0::Parameter>(ov::element::f32 , ov::PartialShape{10 , 20 });
0 commit comments