Skip to content

Commit eb718d6

Browse files
committed
add reshape_activations_m to convert_weight_compressed_conv1x1_to_matmul pattern to resolve activation reshape from [1,1,num_head,head_dim] to [1,hidden_in,1,1] situation
1 parent b7920af commit eb718d6

File tree

2 files changed

+59
-13
lines changed

2 files changed

+59
-13
lines changed

src/common/transformations/src/transformations/op_conversions/convert_weight_compressed_conv1x1_to_matmul.cpp

Lines changed: 23 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,9 @@ ov::pass::ConvertWeightCompressedConv1x1ToMatmul::ConvertWeightCompressedConv1x1
4040
auto first_input_m = ov::pass::pattern::any_input();
4141
auto a_order_m = ov::pass::pattern::wrap_type<ov::op::v0::Constant>();
4242
auto transpose_activations_m = ov::pass::pattern::wrap_type<ov::op::v1::Transpose>({first_input_m, a_order_m});
43-
auto reshape_activations_m = ov::pass::pattern::wrap_type<ov::op::v1::Reshape>({first_input_m, a_order_m});
43+
auto reshape_activations_m =
44+
ov::pass::pattern::wrap_type<ov::op::v1::Reshape>({first_input_m, a_order_m},
45+
pattern::shape_matches("[?, hidden_in, 1, 1]"));
4446
auto a_m =
4547
std::make_shared<ov::pass::pattern::op::Or>(OutputVector{transpose_activations_m, reshape_activations_m});
4648

@@ -211,6 +213,26 @@ ov::pass::ConvertWeightCompressedConv1x1ToMatmul::ConvertWeightCompressedConv1x1
211213
scaled_weight = final_weight_reshape;
212214
}
213215

216+
// When activation is reshaped to [?, hidden_in, 1, 1], two possible cases:
217+
// 1. reshape from [..., hidden_in]
218+
// direct use it in matmul.
219+
// 2. reshape from [..., num_head, head_dim]
220+
// can't use it directly, need reshape it to [..., hidden_in], then use in matmul.
221+
if (pattern_map.count(reshape_activations_m)) {
222+
auto reshape_activations = pattern_map.at(reshape_activations_m).get_node_shared_ptr();
223+
auto shape_in = reshape_activations->get_input_partial_shape(0);
224+
auto shape_out = reshape_activations->get_output_partial_shape(0);
225+
if (shape_in[-1].is_dynamic() || shape_in[-1].get_length() != shape_out[1].get_length()) {
226+
auto reshape_const =
227+
std::make_shared<ov::op::v0::Constant>(ov::element::i64,
228+
ov::Shape{4},
229+
std::vector<int64_t>{1, 1, -1, shape_out[1].get_length()});
230+
auto reshape_activations_new = std::make_shared<ov::op::v1::Reshape>(activation, reshape_const, false);
231+
ov::copy_runtime_info(reshape_activations, reshape_activations_new);
232+
activation = reshape_activations_new;
233+
}
234+
}
235+
214236
auto matmul = std::make_shared<ov::op::v0::MatMul>(activation, scaled_weight, false, true);
215237
ov::copy_runtime_info(conv1x1, matmul);
216238
std::shared_ptr<Node> matmul_out;

src/common/transformations/tests/op_conversions/convert_weight_compressed_conv1x1_to_matmul_test.cpp

Lines changed: 36 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -40,17 +40,21 @@ struct Conv1x1ToMatmulTestParams {
4040
bool with_bias;
4141
bool with_convert;
4242
bool with_param_weight;
43+
bool with_act_new_reshape;
4344
std::string activation_op_type;
4445
};
4546

4647
std::shared_ptr<ov::Model> gen_model(const Conv1x1ToMatmulTestParams& p) {
47-
auto input = std::make_shared<ov::opset1::Parameter>(ov::element::f16, ov::Shape{1, 1, 2, 10});
48+
auto input = std::make_shared<ov::opset1::Parameter>(
49+
ov::element::f16,
50+
(p.activation_op_type == "Reshape" && p.with_act_new_reshape) ? ov::Shape{1, 1, 2, 5} : ov::Shape{1, 1, 1, 10});
51+
4852
std::shared_ptr<ov::Node> act_node;
4953
if (p.activation_op_type == "Transpose") {
5054
auto transpose_const = ov::opset1::Constant::create(ov::element::i32, ov::Shape{4}, {0, 3, 1, 2});
5155
act_node = std::make_shared<ov::opset1::Transpose>(input, transpose_const);
5256
} else {
53-
auto reshape_const = ov::opset1::Constant::create(ov::element::i32, ov::Shape{4}, {1, 10, 1, 2});
57+
auto reshape_const = ov::opset1::Constant::create(ov::element::i32, ov::Shape{4}, {1, 10, 1, 1});
5458
act_node = std::make_shared<ov::opset1::Reshape>(input, reshape_const, false);
5559
}
5660

@@ -114,15 +118,17 @@ std::shared_ptr<ov::Model> gen_model(const Conv1x1ToMatmulTestParams& p) {
114118
auto transpose_const = ov::opset1::Constant::create(ov::element::i32, ov::Shape{4}, {0, 2, 3, 1});
115119
out_node = std::make_shared<ov::opset1::Transpose>(current_node, transpose_const);
116120
} else {
117-
auto reshape_const = ov::opset1::Constant::create(ov::element::i32, ov::Shape{4}, {1, 1, 2, 15});
121+
auto reshape_const = ov::opset1::Constant::create(ov::element::i32, ov::Shape{4}, {1, 1, 1, 15});
118122
out_node = std::make_shared<ov::opset1::Reshape>(current_node, reshape_const, false);
119123
}
120124

121125
return std::make_shared<ov::Model>(ov::OutputVector{out_node}, params);
122126
}
123127

124128
std::shared_ptr<ov::Model> gen_model_ref(const Conv1x1ToMatmulTestParams& p) {
125-
auto input = std::make_shared<ov::opset1::Parameter>(ov::element::f16, ov::Shape{1, 1, 2, 10});
129+
auto input = std::make_shared<ov::opset1::Parameter>(
130+
ov::element::f16,
131+
(p.activation_op_type == "Reshape" && p.with_act_new_reshape) ? ov::Shape{1, 1, 2, 5} : ov::Shape{1, 1, 1, 10});
126132

127133
std::shared_ptr<ov::Node> weights_node;
128134
ov::ParameterVector params = {input};
@@ -162,7 +168,12 @@ std::shared_ptr<ov::Model> gen_model_ref(const Conv1x1ToMatmulTestParams& p) {
162168
mul = std::make_shared<ov::opset1::Reshape>(mul, reshape_const, false);
163169
}
164170

165-
auto matmul = std::make_shared<ov::op::v0::MatMul>(input, mul, false, true);
171+
std::shared_ptr<ov::Node> act_node = input;
172+
if (p.activation_op_type == "Reshape" && p.with_act_new_reshape) {
173+
auto reshape_const = ov::opset1::Constant::create(ov::element::i64, ov::Shape{4}, {1, 1, 1, 10});
174+
act_node = std::make_shared<ov::opset1::Reshape>(input, reshape_const, false);
175+
}
176+
auto matmul = std::make_shared<ov::op::v0::MatMul>(act_node, mul, false, true);
166177
current_node = matmul;
167178

168179
if (p.with_bias) {
@@ -175,7 +186,7 @@ std::shared_ptr<ov::Model> gen_model_ref(const Conv1x1ToMatmulTestParams& p) {
175186

176187
std::shared_ptr<ov::Node> out_node;
177188
if (p.activation_op_type == "Reshape") {
178-
auto reshape_const = ov::opset1::Constant::create(ov::element::i32, ov::Shape{4}, {1, 1, 2, 15});
189+
auto reshape_const = ov::opset1::Constant::create(ov::element::i32, ov::Shape{4}, {1, 1, 1, 15});
179190
out_node = std::make_shared<ov::opset1::Reshape>(current_node, reshape_const, false);
180191
} else {
181192
out_node = current_node;
@@ -187,33 +198,45 @@ std::shared_ptr<ov::Model> gen_model_ref(const Conv1x1ToMatmulTestParams& p) {
187198

188199
class ConvertWeightCompressedConv1x1ToMatmulTest
189200
: public TransformationTestsF,
190-
public WithParamInterface<std::tuple<bool, bool, bool, bool, bool, std::string>> {
201+
public WithParamInterface<std::tuple<bool, bool, bool, bool, bool, bool, std::string>> {
191202
public:
192203
static std::string get_test_case_name(
193-
const testing::TestParamInfo<std::tuple<bool, bool, bool, bool, bool, std::string>>& obj) {
194-
const auto& [with_group_quant, with_zp, with_bias, with_convert, with_param_weight, activation_op_type] =
195-
obj.param;
204+
const testing::TestParamInfo<std::tuple<bool, bool, bool, bool, bool, bool, std::string>>& obj) {
205+
const auto& [with_group_quant,
206+
with_zp,
207+
with_bias,
208+
with_convert,
209+
with_param_weight,
210+
with_act_new_reshape,
211+
activation_op_type] = obj.param;
196212

197213
std::ostringstream result;
198214
result << "with_group_quant=" << with_group_quant << "_";
199215
result << "with_zp=" << with_zp << "_";
200216
result << "with_bias=" << with_bias << "_";
201217
result << "with_convert=" << with_convert << "_";
202218
result << "with_param_weight=" << with_param_weight << "_";
219+
result << "with_act_new_reshape=" << with_act_new_reshape << "_";
203220
result << "activation_op_type=" << activation_op_type;
204221
return result.str();
205222
}
206223

207224
protected:
208225
void SetUp() override {
209226
TransformationTestsF::SetUp();
210-
const auto& [with_group_quant, with_zp, with_bias, with_convert, with_param_weight, activation_op_type] =
211-
GetParam();
227+
const auto& [with_group_quant,
228+
with_zp,
229+
with_bias,
230+
with_convert,
231+
with_param_weight,
232+
with_act_new_reshape,
233+
activation_op_type] = GetParam();
212234
Conv1x1ToMatmulTestParams params{with_group_quant,
213235
with_zp,
214236
with_bias,
215237
with_convert,
216238
with_param_weight,
239+
with_act_new_reshape,
217240
activation_op_type};
218241
model = gen_model(params);
219242
model_ref = gen_model_ref(params);
@@ -230,6 +253,7 @@ INSTANTIATE_TEST_SUITE_P(TransformationTests,
230253
::testing::Bool(),
231254
::testing::Bool(),
232255
::testing::Bool(),
256+
::testing::Bool(),
233257
::testing::Values("Transpose", "Reshape")),
234258
ConvertWeightCompressedConv1x1ToMatmulTest::get_test_case_name);
235259

0 commit comments

Comments
 (0)