@@ -104,7 +104,7 @@ std::shared_ptr<pattern::op::Block> mlp3_no_bias_swiglu_block(
     auto squeeze_Squeeze_1 =
         wrap_type<ov::op::v0::Squeeze>({select_Gather_1, wrap_type<ov::op::v0::Constant>(pattern::value_matches("0"))});
     // NonZero output_type relaxed to accept both i32 and i64
-    auto ListUnpack_NonZero_1 = wrap_type<ov::op::v3::NonZero>({squeeze_Squeeze_1});
+    auto ListUnpack_NonZero_1 = wrap_type<ov::op::v3::NonZero>({squeeze_Squeeze_1 | select_Gather_1});
     auto ListUnpack_Split_1 = wrap_type<ov::op::v1::Split>(
         {ListUnpack_NonZero_1, wrap_type<ov::op::v0::Constant>(pattern::value_matches("0"))},
         {{"num_splits", 2}});
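Note: the `|` used above is the pattern alternation operator; it wraps both operands in a single `pattern::op::Or` node, so `ListUnpack_NonZero_1` now matches whether or not the intermediate Squeeze exists in the source graph. A minimal sketch of the idiom with illustrative names (not this pass's code):

    // `a | b` builds one pattern::op::Or over the two alternatives; the consumer
    // pattern matches whichever branch the model actually contains.
    auto gather  = pattern::any_input();
    auto squeeze = wrap_type<ov::op::v0::Squeeze>({gather, wrap_type<ov::op::v0::Constant>()});
    auto nonzero = wrap_type<ov::op::v3::NonZero>({squeeze | gather});  // Squeeze becomes optional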
@@ -141,11 +141,11 @@ std::shared_ptr<pattern::op::Block> mlp3_no_bias_swiglu_block(
     auto reshape_Reshape_1 =
         wrap_type<ov::op::v1::Reshape>({reshape_Reshape_1_2, shape_const}, {{"special_zero", true}});
     auto gate_proj_weight = pattern::any_input(pattern::rank_equals(2));
-    auto linear_MatMul_gate = wrap_type<ov::op::v0::MatMul>({reshape_Reshape_1, gate_proj_weight},
+    auto linear_MatMul_gate = wrap_type<ov::op::v0::MatMul>({reshape_Reshape_1 | reshape_Reshape_1_0, gate_proj_weight},
                                                             {{"transpose_a", false}, {"transpose_b", true}});
     auto silu_Swish = wrap_type<ov::op::v4::Swish>({linear_MatMul_gate});
     auto up_proj_weight = pattern::any_input(pattern::rank_equals(2));
-    auto linear_MatMul_up = wrap_type<ov::op::v0::MatMul>({reshape_Reshape_1, up_proj_weight},
+    auto linear_MatMul_up = wrap_type<ov::op::v0::MatMul>({reshape_Reshape_1 | reshape_Reshape_1_0, up_proj_weight},
                                                           {{"transpose_a", false}, {"transpose_b", true}});
     auto mul_Multiply = wrap_type<ov::op::v1::Multiply>({silu_Swish, linear_MatMul_up}, {{"auto_broadcast", "numpy"}});
     auto down_proj_weight = pattern::any_input(pattern::rank_equals(2));
@@ -216,10 +216,18 @@ std::pair<std::shared_ptr<Node>, std::shared_ptr<Node>> create_router_pattern()
     auto one_hot_off = wrap_type<ov::op::v0::Constant>(pattern::value_matches("0"));
     auto transpose_perm = wrap_type<ov::op::v0::Constant>(pattern::value_matches("2, 1, 0"));

-    auto softmax = wrap_type<ov::op::v8::Softmax>({linear_MatMul}, {{"axis", 1}});
-    auto topk = wrap_type<ov::op::v11::TopK>(
+    auto softmax_0 = wrap_type<ov::op::v8::Softmax>({linear_MatMul}, {{"axis", -1}});
+    auto softmax_1 = wrap_type<ov::op::v8::Softmax>({linear_MatMul}, {{"axis", 1}});
+    auto softmax = softmax_0 | softmax_1;
+    auto topk_none = wrap_type<ov::op::v11::TopK>(
         {softmax, num_topk},
         {{"axis", -1}, {"mode", "max"}, {"sort", "none"}, {"index_element_type", "i64"}, {"stable", false}});
+    topk_none->set_output_size(2);
+    auto topk_value = wrap_type<ov::op::v11::TopK>(
+        {softmax, num_topk},
+        {{"axis", -1}, {"mode", "max"}, {"sort", "value"}, {"index_element_type", "i64"}, {"stable", false}});
+    topk_value->set_output_size(2);
+    auto topk = topk_none | topk_value;
     topk->set_output_size(2);
     auto one_hot = wrap_type<ov::op::v1::OneHot>({topk->output(1), expert_num, one_hot_on, one_hot_off}, {{"axis", 2}});
     auto permute = wrap_type<ov::op::v1::Transpose>({one_hot, transpose_perm});
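Note: for the 2-D `[tokens, experts]` router input, Softmax axis 1 and axis -1 are the same axis, and the two TopK variants differ only in whether the k results are ordered; both carry the same value/index pairs, so accepting all combinations covers exporters that emit any of them. For reference, a plain C++ sketch of the routing math this subgraph computes (illustrative, not the plugin's code):

    #include <algorithm>
    #include <cmath>
    #include <numeric>
    #include <vector>

    // Softmax over expert logits, then indices of the k largest probabilities.
    // Softmax is monotonic, so the selected experts are the same whether the
    // TopK results are sorted by value or left unsorted.
    std::vector<size_t> route_token(const std::vector<float>& logits, size_t k) {
        float max_v = *std::max_element(logits.begin(), logits.end());
        std::vector<float> probs(logits.size());
        float total = 0.f;
        for (size_t i = 0; i < logits.size(); ++i) {
            probs[i] = std::exp(logits[i] - max_v);
            total += probs[i];
        }
        for (auto& p : probs) p /= total;
        std::vector<size_t> idx(probs.size());
        std::iota(idx.begin(), idx.end(), 0);
        std::partial_sort(idx.begin(), idx.begin() + k, idx.end(),
                          [&](size_t a, size_t b) { return probs[a] > probs[b]; });
        idx.resize(k);
        return idx;  // expert ids of the k largest probabilities
    }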
@@ -253,7 +261,7 @@ std::tuple<std::shared_ptr<Node>, std::shared_ptr<Node>> create_routing_weights_
     auto sum_reduce = wrap_type<ov::op::v1::ReduceSum>({topk->output(0), reduce_neg1}, {{"keep_dims", true}});
     auto normalized = wrap_type<ov::op::v1::Divide>({topk->output(0), sum_reduce},
                                                     {{"auto_broadcast", "numpy"}, {"m_pythondiv", true}});
-    auto unsqueeze = wrap_type<ov::op::v0::Unsqueeze>({normalized, axes.axis2});
+    auto unsqueeze = wrap_type<ov::op::v0::Unsqueeze>({normalized | topk->output(0), axes.axis2});
     auto shape_of = wrap_type<ov::op::v3::ShapeOf>({unsqueeze}, {{"output_type", "i32"}});
     auto split = wrap_type<ov::op::v1::Split>({shape_of, axes.axis0}, {{"num_splits", 3}});
     split->set_output_size(3);
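Note: the alternation lets the Unsqueeze consume either the renormalized weights (the Divide output) or the raw top-k softmax values, since some MoE variants skip the renormalization step. A small numeric illustration (hypothetical values):

    #include <cstdio>

    int main() {
        float w[2] = {0.30f, 0.20f};  // top-2 softmax values for one token
        float sum = w[0] + w[1];      // 0.50, from the ReduceSum
        std::printf("%.2f %.2f\n", w[0] / sum, w[1] / sum);  // 0.60 0.40: the Divide branch
        return 0;  // without it, {0.30, 0.20} feed the Unsqueeze as-is
    }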
@@ -302,11 +310,13 @@ ov::pass::FuseMOEExperts::FuseMOEExperts() : MultiMatcher("FuseMOEExperts") {
     auto num_last_add = matches.at(last_add).size();

     // Collect expert data from all matched patterns
+    std::cout << "Fuse Moe Pattern For DeepSeek|" << num_last_add << std::endl;
     std::vector<expert_data> all_experts;
     all_experts.reserve(matches.at(expert_scatter).size());
     for (const auto& pm : matches.at(expert_scatter)) {
         auto slice_end_anchor = expert_scatter->get_anchor("slice_end_const", pm);
         if (!slice_end_anchor.has_value() || !is_slice_to_end(slice_end_anchor.value().get_node_shared_ptr())) {
+            std::cout << "Fuse Moe Pattern For DeepSeek|" << "311" << std::endl;
             return false;
         }
         auto gate_proj_node = expert_scatter->get_anchor("gate_proj_weight", pm).value().get_node_shared_ptr();
@@ -328,7 +338,7 @@ ov::pass::FuseMOEExperts::FuseMOEExperts() : MultiMatcher("FuseMOEExperts") {
     for (const auto& expert : all_experts) {
         experts_by_permute[expert.permute_node.get()].push_back(expert);
     }
-
+    std::cout << "Fuse Moe Pattern For DeepSeek|" << "332" << std::endl;
     // Create shared constants (used across all MoE layers)
     auto const_0 = ov::op::v0::Constant::create(element::i64, Shape{1}, {0});
     auto const_1 = ov::op::v0::Constant::create(element::i64, Shape{1}, {1});
@@ -515,7 +525,7 @@ ov::pass::FuseMOEExperts::FuseMOEExperts() : MultiMatcher("FuseMOEExperts") {

         ov::replace_node(last_add_node, final_add);
     }
-
+    std::cout << "Fuse Moe Pattern For DeepSeek|" << "519" << std::endl;
     return true;
 };

@@ -197,7 +197,6 @@ ov::pass::FuseVectorizedMOE3GEMM::FuseVectorizedMOE3GEMM() {
         moe->set_friendly_name(m.get_match_root()->get_friendly_name());
         ov::copy_runtime_info(m.get_matched_nodes(), moe);
         ov::replace_node(m.get_match_root(), moe);
-
         register_new_node(moe);
         return true;
     };
@@ -30,6 +30,7 @@ class MOECompressed : public ov::op::internal::MOE {
         size_t has_batch_dim = 0;
         bool has_zp = false;
         ov::element::Type out_type = ov::element::dynamic;
+        bool top_k_reduce = true;
         Config() = default;
         Config(const MOE::Config& moe_config) : MOE::Config(moe_config) {}
     };
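Defaulting to `true` preserves the previous behavior (weights renormalized over the top-k) for any path that never sets the flag. A hedged sketch of how the flag might be populated, with illustrative surroundings (only the field itself comes from this diff):

    // Copy the base MOE config, then record whether the source graph
    // renormalized its top-k routing weights.
    MOECompressed::Config cfg(moe->get_config());  // `moe` is an existing internal::MOE node
    cfg.top_k_reduce = false;                      // e.g. no ReduceSum/Divide chain was matched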
@@ -344,6 +344,11 @@ class MoE3GemmSwigluSoftMaxTopK : public KernelGenerator {
         auto jit = KernelGenerator::get_jit_constants(params);
         auto desc = params.typed_desc<moe_3gemm_fused_compressed>();
         jit.make("SOFTMAX_TOPK_ENABLE", 1);
+        if (desc->_config.top_k_reduce) {
+            jit.make("TOPK_REDUCE", 1);
+        } else {
+            jit.make("TOPK_REDUCE", 0);
+        }
         jit.make("TOP_K", desc->_config.top_k);
         jit.make("VALUE_NUM", desc->_config.num_expert);
         jit.make("MOE_DTYPE", params.get_input_layout(0).data_type == ov::element::f16 ? "half" : "float");
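Note: the branch could equally be written as `jit.make("TOPK_REDUCE", desc->_config.top_k_reduce ? 1 : 0);`. For a hypothetical top-8-of-64 f16 model, the resulting kernel defines would look like:

    // Illustrative JIT constants handed to the OpenCL kernel:
    //   SOFTMAX_TOPK_ENABLE=1
    //   TOPK_REDUCE=1   // 1: denominator sums only the top-k terms
    //   TOP_K=8         // 0: denominator keeps all VALUE_NUM terms
    //   VALUE_NUM=64
    //   MOE_DTYPE=half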
@@ -53,8 +53,15 @@ KERNEL(softmax_topk)(
     local_output[0] = 1;
     for(uint i = 1; i < TOP_K; i++) {
         local_output[i] = native_exp(local_output[i] - max_v);
+#if TOPK_REDUCE
         softmax_total += local_output[i];
+#endif
     }
+#if !(TOPK_REDUCE)
+    for(uint i = 1; i < VALUE_NUM; i++) {
+        softmax_total += native_exp(local_input[i] - max_v);
+    }
+#endif
     output_index += batch * TOP_K;
     output += batch * TOP_K;

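Note: the two compile-time paths compute different softmax denominators. With `TOPK_REDUCE`, the selected values are renormalized over just the TOP_K exponentials; without it, the kernel keeps the full softmax denominator over all VALUE_NUM router scores. A plain C++ sketch of the distinction (illustrative; the kernel itself is OpenCL C):

    #include <cmath>
    #include <cstddef>
    #include <vector>

    // values: all VALUE_NUM router scores; top: indices of the TOP_K largest;
    // max_v: the running maximum, subtracted for numerical stability.
    float softmax_denominator(const std::vector<float>& values,
                              const std::vector<std::size_t>& top,
                              float max_v, bool topk_reduce) {
        float total = 0.f;
        if (topk_reduce) {
            for (std::size_t i : top) total += std::exp(values[i] - max_v);  // top-k only
        } else {
            for (float v : values) total += std::exp(v - max_v);  // full softmax
        }
        return total;
    }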
@@ -83,7 +90,6 @@ KERNEL (gather_2d_ref)(
     dst_tok += k * HIDDEN_SIZE;

     if (off >= HIDDEN_SIZE) {
-        printf("Warning off >= HIDDEN_SIZE: k = %d, off = %d, HIDDEN_SIZE = %d\n", k, off, HIDDEN_SIZE);
         return;
     }

@@ -173,7 +173,6 @@ ConvertMOEToMOECompressed::ConvertMOEToMOECompressed(bool is_pa) {

     ov::matcher_pass_callback callback = [OV_CAPTURE_CPY_AND_THIS](ov::pass::pattern::Matcher& m) {
         const auto& pattern_map = m.get_pattern_value_map();
-
         auto moe = ov::as_type_ptr<ov::op::internal::MOE>(pattern_map.at(moe_root).get_node_shared_ptr());
         if (!moe || transformation_callback(moe)) {
             return false;
@@ -297,7 +296,6 @@ ConvertMOEToMOECompressed::ConvertMOEToMOECompressed(bool is_pa) {
         } else {
             OPENVINO_THROW("Unsupported MOE expert type in ConvertMOEToMOECompressed");
         }
-
         return true;
     };

@@ -55,7 +55,8 @@ FuseMOE3GemmCompressed::FuseMOE3GemmCompressed() {
     auto concat_m = wrap_type<ov::op::v0::Concat>({unsqueeze_m, unsqueeze_const_m}, consumers_count(1));
     auto concat1_m = wrap_type<ov::op::v0::Concat>({unsqueeze_const_m, unsqueeze_m, any_input()}, consumers_count(1));
     auto bc_m = wrap_type<ov::op::v3::Broadcast>({any_input(), concat_m}, consumers_count(1));
-    auto scatter_m = wrap_type<ov::op::v12::ScatterElementsUpdate>({bc_m->output(0), topk_m->output(1), norm_m->output(0), any_input()}, consumers_count(1));
+    auto topk_values = std::make_shared<ov::pass::pattern::op::Or>(OutputVector{norm_m, topk_m});
+    auto scatter_m = wrap_type<ov::op::v12::ScatterElementsUpdate>({bc_m->output(0), topk_m->output(1), topk_values->output(0), any_input()}, consumers_count(1));
     auto transpose_m = wrap_type<ov::op::v1::Transpose>({scatter_m, any_input()}, consumers_count(1));
     auto reshape_m = wrap_type<ov::op::v1::Reshape>({transpose_m, concat1_m}, consumers_count(1));
     auto unsqueeze_moe_m = wrap_type<ov::op::v0::Unsqueeze>({reshape_m, any_input()}, consumers_count(1));
@@ -86,7 +87,6 @@ FuseMOE3GemmCompressed::FuseMOE3GemmCompressed() {

     ov::matcher_pass_callback callback = [OV_CAPTURE_CPY_AND_THIS](ov::pass::pattern::Matcher& m) {
         const auto& pattern_map = m.get_pattern_value_map();
-
         auto moe_compressed = ov::as_type_ptr<ov::intel_gpu::op::MOECompressed>(pattern_map.at(moe_compressed_m).get_node_shared_ptr());
         if (!moe_compressed || transformation_callback(moe_compressed)) {
             return false;
@@ -103,12 +103,12 @@
         args[8] = pattern_map.at(down_wei_m);
         args[9] = pattern_map.at(down_scale_m);
         args[10] = pattern_map.at(down_zp_m);
-
-        auto moe_3gemm_fused_compressed = std::make_shared<ov::intel_gpu::op::MOE3GemmFusedCompressed>(args, moe_compressed->get_config());
+        auto config = moe_compressed->get_config();
+        config.top_k_reduce = pattern_map.count(reduce_sum_m);
+        auto moe_3gemm_fused_compressed = std::make_shared<ov::intel_gpu::op::MOE3GemmFusedCompressed>(args, config);
         moe_3gemm_fused_compressed->set_friendly_name(moe_compressed->get_friendly_name());
         ov::copy_runtime_info(moe_compressed, moe_3gemm_fused_compressed);
         ov::replace_node(moe_compressed, moe_3gemm_fused_compressed);
-
         return true;
     };

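Note: `pattern_map.count(reduce_sum_m)` is 0 or 1 depending on whether the optional renormalization branch matched, so the assignment doubles as a bool. An equivalent, more explicit spelling (same semantics):

    // top_k_reduce is true exactly when the ReduceSum/Divide renormalization
    // chain was present in the matched subgraph.
    config.top_k_reduce = (pattern_map.count(reduce_sum_m) != 0);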