@@ -40,17 +40,21 @@ struct Conv1x1ToMatmulTestParams {
4040 bool with_bias;
4141 bool with_convert;
4242 bool with_param_weight;
43+ bool with_act_new_reshape;
4344 std::string activation_op_type;
4445};
4546
4647std::shared_ptr<ov::Model> gen_model (const Conv1x1ToMatmulTestParams& p) {
47- auto input = std::make_shared<ov::opset1::Parameter>(ov::element::f16 , ov::Shape{1 , 1 , 2 , 10 });
48+ auto input = std::make_shared<ov::opset1::Parameter>(
49+ ov::element::f16 ,
50+ (p.activation_op_type == " Reshape" && p.with_act_new_reshape ) ? ov::Shape{1 , 1 , 2 , 5 } : ov::Shape{1 , 1 , 1 , 10 });
51+
4852 std::shared_ptr<ov::Node> act_node;
4953 if (p.activation_op_type == " Transpose" ) {
5054 auto transpose_const = ov::opset1::Constant::create (ov::element::i32 , ov::Shape{4 }, {0 , 3 , 1 , 2 });
5155 act_node = std::make_shared<ov::opset1::Transpose>(input, transpose_const);
5256 } else {
53- auto reshape_const = ov::opset1::Constant::create (ov::element::i32 , ov::Shape{4 }, {1 , 10 , 1 , 2 });
57+ auto reshape_const = ov::opset1::Constant::create (ov::element::i32 , ov::Shape{4 }, {1 , 10 , 1 , 1 });
5458 act_node = std::make_shared<ov::opset1::Reshape>(input, reshape_const, false );
5559 }
5660
@@ -114,15 +118,17 @@ std::shared_ptr<ov::Model> gen_model(const Conv1x1ToMatmulTestParams& p) {
114118 auto transpose_const = ov::opset1::Constant::create (ov::element::i32 , ov::Shape{4 }, {0 , 2 , 3 , 1 });
115119 out_node = std::make_shared<ov::opset1::Transpose>(current_node, transpose_const);
116120 } else {
117- auto reshape_const = ov::opset1::Constant::create (ov::element::i32 , ov::Shape{4 }, {1 , 1 , 2 , 15 });
121+ auto reshape_const = ov::opset1::Constant::create (ov::element::i32 , ov::Shape{4 }, {1 , 1 , 1 , 15 });
118122 out_node = std::make_shared<ov::opset1::Reshape>(current_node, reshape_const, false );
119123 }
120124
121125 return std::make_shared<ov::Model>(ov::OutputVector{out_node}, params);
122126}
123127
124128std::shared_ptr<ov::Model> gen_model_ref (const Conv1x1ToMatmulTestParams& p) {
125- auto input = std::make_shared<ov::opset1::Parameter>(ov::element::f16 , ov::Shape{1 , 1 , 2 , 10 });
129+ auto input = std::make_shared<ov::opset1::Parameter>(
130+ ov::element::f16 ,
131+ (p.activation_op_type == " Reshape" && p.with_act_new_reshape ) ? ov::Shape{1 , 1 , 2 , 5 } : ov::Shape{1 , 1 , 1 , 10 });
126132
127133 std::shared_ptr<ov::Node> weights_node;
128134 ov::ParameterVector params = {input};
@@ -162,7 +168,12 @@ std::shared_ptr<ov::Model> gen_model_ref(const Conv1x1ToMatmulTestParams& p) {
162168 mul = std::make_shared<ov::opset1::Reshape>(mul, reshape_const, false );
163169 }
164170
165- auto matmul = std::make_shared<ov::op::v0::MatMul>(input, mul, false , true );
171+ std::shared_ptr<ov::Node> act_node = input;
172+ if (p.activation_op_type == " Reshape" && p.with_act_new_reshape ) {
173+ auto reshape_const = ov::opset1::Constant::create (ov::element::i64 , ov::Shape{4 }, {1 , 1 , 1 , 10 });
174+ act_node = std::make_shared<ov::opset1::Reshape>(input, reshape_const, false );
175+ }
176+ auto matmul = std::make_shared<ov::op::v0::MatMul>(act_node, mul, false , true );
166177 current_node = matmul;
167178
168179 if (p.with_bias ) {
@@ -175,7 +186,7 @@ std::shared_ptr<ov::Model> gen_model_ref(const Conv1x1ToMatmulTestParams& p) {
175186
176187 std::shared_ptr<ov::Node> out_node;
177188 if (p.activation_op_type == " Reshape" ) {
178- auto reshape_const = ov::opset1::Constant::create (ov::element::i32 , ov::Shape{4 }, {1 , 1 , 2 , 15 });
189+ auto reshape_const = ov::opset1::Constant::create (ov::element::i32 , ov::Shape{4 }, {1 , 1 , 1 , 15 });
179190 out_node = std::make_shared<ov::opset1::Reshape>(current_node, reshape_const, false );
180191 } else {
181192 out_node = current_node;
@@ -187,33 +198,45 @@ std::shared_ptr<ov::Model> gen_model_ref(const Conv1x1ToMatmulTestParams& p) {
187198
188199class ConvertWeightCompressedConv1x1ToMatmulTest
189200 : public TransformationTestsF,
190- public WithParamInterface<std::tuple<bool , bool , bool , bool , bool , std::string>> {
201+ public WithParamInterface<std::tuple<bool , bool , bool , bool , bool , bool , std::string>> {
191202public:
192203 static std::string get_test_case_name (
193- const testing::TestParamInfo<std::tuple<bool , bool , bool , bool , bool , std::string>>& obj) {
194- const auto & [with_group_quant, with_zp, with_bias, with_convert, with_param_weight, activation_op_type] =
195- obj.param ;
204+ const testing::TestParamInfo<std::tuple<bool , bool , bool , bool , bool , bool , std::string>>& obj) {
205+ const auto & [with_group_quant,
206+ with_zp,
207+ with_bias,
208+ with_convert,
209+ with_param_weight,
210+ with_act_new_reshape,
211+ activation_op_type] = obj.param ;
196212
197213 std::ostringstream result;
198214 result << " with_group_quant=" << with_group_quant << " _" ;
199215 result << " with_zp=" << with_zp << " _" ;
200216 result << " with_bias=" << with_bias << " _" ;
201217 result << " with_convert=" << with_convert << " _" ;
202218 result << " with_param_weight=" << with_param_weight << " _" ;
219+ result << " with_act_new_reshape=" << with_act_new_reshape << " _" ;
203220 result << " activation_op_type=" << activation_op_type;
204221 return result.str ();
205222 }
206223
207224protected:
208225 void SetUp () override {
209226 TransformationTestsF::SetUp ();
210- const auto & [with_group_quant, with_zp, with_bias, with_convert, with_param_weight, activation_op_type] =
211- GetParam ();
227+ const auto & [with_group_quant,
228+ with_zp,
229+ with_bias,
230+ with_convert,
231+ with_param_weight,
232+ with_act_new_reshape,
233+ activation_op_type] = GetParam ();
212234 Conv1x1ToMatmulTestParams params{with_group_quant,
213235 with_zp,
214236 with_bias,
215237 with_convert,
216238 with_param_weight,
239+ with_act_new_reshape,
217240 activation_op_type};
218241 model = gen_model (params);
219242 model_ref = gen_model_ref (params);
@@ -230,6 +253,7 @@ INSTANTIATE_TEST_SUITE_P(TransformationTests,
230253 ::testing::Bool(),
231254 ::testing::Bool(),
232255 ::testing::Bool(),
256+ ::testing::Bool(),
233257 ::testing::Values(" Transpose" , " Reshape" )),
234258 ConvertWeightCompressedConv1x1ToMatmulTest::get_test_case_name);
235259
0 commit comments