Changes from all commits
88 commits
aa10cbe
add perf
luo-cheng2021 Feb 11, 2025
3a80081
perf for sdpa/pa
luo-cheng2021 Feb 13, 2025
d209f03
add git ignore
luo-cheng2021 Mar 26, 2025
45bf16e
insert if for moe expert
luo-cheng2021 Mar 31, 2025
795e323
add moeexpert support
luo-cheng2021 Apr 2, 2025
d7f2602
fix moexpert precision is always f32
luo-cheng2021 Apr 2, 2025
7f4b901
add moeexpert support for gpu
luo-cheng2021 Apr 2, 2025
76a7d5b
opt: 1, simplify subgraph inside moeexpert; 2, only compute skip flag…
luo-cheng2021 Apr 4, 2025
6079269
opt: remove nonzero->split from subgraph into moeexpert for gpu
luo-cheng2021 Apr 5, 2025
c385c8f
Support Qwen3 rms kernel for input with dynamic padding
riverlijunjie Mar 31, 2025
ada754a
Add test case
riverlijunjie Mar 31, 2025
38ded44
WA: moe_expert wait all inputs ready
luo-cheng2021 Apr 7, 2025
f84303e
fix incorrect output shape computation
luo-cheng2021 Apr 8, 2025
df0ca20
add fast path for expert mask computation if no padding
luo-cheng2021 Apr 8, 2025
00e7d9a
qwen3 moe compile model opt, from 150s to 70s in LNL (#66)
riverlijunjie Apr 9, 2025
019262b
Move FuseMoeExpert2 ahead of CommonOptimizations to decrease compilin…
ceciliapeng2011 Apr 9, 2025
994c094
not use subgraph for moeexpert
luo-cheng2021 Apr 10, 2025
22c93ee
fix scale/zp layout; first expert should not be inplace
luo-cheng2021 Apr 11, 2025
7c872a0
merge all experts into one op
luo-cheng2021 Apr 12, 2025
75b9683
Optimize gather and index_add performance
riverlijunjie Apr 13, 2025
2d8eb4e
fix out_of_resource error on lunarlake
luo-cheng2021 Apr 14, 2025
06e436c
Move weights from usm_host to usm_device memory
riverlijunjie Apr 16, 2025
02f2331
Add ITT for MoE
riverlijunjie Apr 17, 2025
b6b5f1d
Optimize BMG first token due to index_add kernel
riverlijunjie Apr 17, 2025
9383141
opt: merge all experts into one for batch1
luo-cheng2021 Apr 18, 2025
c7ef4ea
opt: cl code for mlp_*
luo-cheng2021 Apr 18, 2025
76e6ed7
change weight back to ba
luo-cheng2021 Apr 19, 2025
8471f6b
small tune for lunarlake
luo-cheng2021 Apr 21, 2025
818ba1b
fuse onehot into moe
luo-cheng2021 Apr 21, 2025
eed40eb
not wait gpu for batch1
luo-cheng2021 Apr 21, 2025
bd8e5f6
optimize mlp 2nd token bandwidth
usstq Apr 22, 2025
b7278d9
minor fix
luo-cheng2021 Apr 22, 2025
caa1f6e
Optimize moe_reduce for BMG
riverlijunjie Apr 24, 2025
e2812a4
add cm support
luo-cheng2021 Apr 23, 2025
0d2e996
moe expert cm kernel
luo-cheng2021 Apr 24, 2025
e08f8af
moe cm group 128 ok
luo-cheng2021 Apr 25, 2025
45618fb
cm moe zp ok(env: CM_MASK=3)
luo-cheng2021 Apr 25, 2025
27cbccf
default enable moe_up, disable moe_down
luo-cheng2021 Apr 25, 2025
100ca20
minor: reduce the parameter number of cm moe_up
luo-cheng2021 Apr 25, 2025
a8bcc42
Add perf
riverlijunjie Apr 24, 2025
e79494d
use i32 for paged_attention
luo-cheng2021 Apr 27, 2025
c409cfb
fuse softmax-topk
peterchen-intel Apr 27, 2025
7a3d3e4
fix bugs in softmax_topk fusion (add env NO_SOFTTOPK)
peterchen-intel Apr 28, 2025
3bca3f3
not alloc mem for cm if CM_MASK==0
luo-cheng2021 Apr 28, 2025
7930a66
Disable perf by default
riverlijunjie Apr 29, 2025
8e67325
Remove some logs
riverlijunjie May 6, 2025
378d56e
Remove cpu moe code
riverlijunjie May 6, 2025
548f1c2
Merge branch 'master' into gpu/qwen3_moe_cm
riverlijunjie May 6, 2025
7c06a5d
revert use i32 for pa
riverlijunjie May 7, 2025
c252879
Cleanup unused code
riverlijunjie May 7, 2025
3aa1d73
Move test case to common
riverlijunjie May 7, 2025
0722290
refine moe_expert_opt
riverlijunjie May 7, 2025
4b10730
simplify transform part code(TODO: accuracy)
luo-cheng2021 May 7, 2025
bf32356
fix gpu unit test(use cpu as reference)
luo-cheng2021 May 8, 2025
21efc98
Solve usm indirect memory access
riverlijunjie May 8, 2025
c7f5dab
more checks for pattern match
luo-cheng2021 May 9, 2025
a9c5a4f
use framework's intermediate buffer mechanism
luo-cheng2021 May 9, 2025
6a2dd26
Use direct memory access replace indirect access mode
riverlijunjie May 9, 2025
f717543
update cm kernel
riverlijunjie May 10, 2025
b50d233
update for CM_MASK
riverlijunjie May 11, 2025
a72b880
remove perf tool
riverlijunjie May 11, 2025
851babf
minor update
riverlijunjie May 11, 2025
210f1d4
add cache for moe_expert(temporarily disable cm due to additional scal…
luo-cheng2021 May 12, 2025
b86c8f3
minor cleanup
luo-cheng2021 May 12, 2025
b9e91da
revert optimization for shared_ops_optimization
riverlijunjie May 12, 2025
a011264
fix ci error
luo-cheng2021 May 12, 2025
0d9354f
fix CI failure
luo-cheng2021 May 12, 2025
8d1f344
Move cm kernel to cm directory
riverlijunjie May 19, 2025
6df53a1
Merge branch 'master' into gpu/qwen3_moe_cm
riverlijunjie May 20, 2025
ff0c74f
apply review comments
luo-cheng2021 May 20, 2025
18312f2
Some reviewer comments
riverlijunjie May 20, 2025
e06f8de
Remove CM kernel and move to the following PR
riverlijunjie May 21, 2025
f43d457
apply review comments
luo-cheng2021 May 21, 2025
0c90c5a
Merge branch 'master' into gpu/qwen3_moe_cm
peterchen-intel May 25, 2025
255992b
Merge remote-tracking branch 'upstream/master' into gpu/qwen2_moe_cm
luo-cheng2021 Jun 5, 2025
4662b8f
apply review comments
luo-cheng2021 Jun 5, 2025
0d0ecca
Fix CI issues
riverlijunjie Jun 6, 2025
41fa911
transformation support more weight datatype
luo-cheng2021 Jun 17, 2025
e7e909a
Merge remote-tracking branch 'upstream/master' into gpu/qwen3_moe_cm
luo-cheng2021 Jun 17, 2025
73071de
apply review comments
luo-cheng2021 Jun 23, 2025
2181617
Merge remote-tracking branch 'upstream/master' into gpu/qwen3_moe_cm
luo-cheng2021 Jun 23, 2025
ed49b4e
fuse router+expert0 first to avoid fusing expert success but router f…
luo-cheng2021 Jun 25, 2025
ad36133
Merge branch 'master' into gpu/qwen3_moe_cm
peterchen-intel Jun 30, 2025
37ca973
solve merge conflicts
luo-cheng2021 Jul 1, 2025
cba1763
Merge branch 'main' into gpu/qwen3_moe_cm
riverlijunjie Jul 11, 2025
1ed2a11
Handle reviewer's comments
riverlijunjie Jul 11, 2025
4807bfd
Merge branch 'releases/2025/3' into gpu/qwen3_moe_cm
WeldonWangwang Sep 11, 2025
9b982af
Fix build error
WeldonWangwang Sep 11, 2025
@@ -0,0 +1,160 @@
.. {#openvino_docs_ops_internal_MOE}

MOE
===


.. meta::
   :description: Learn about MOE - a basic block for the mixture of experts.

**Versioned name**: *MOE*

**Category**: *Sequence processing*

**Short description**: *MOE* partially implements
`Qwen3MoeSparseMoeBlock.forward <https://github.com/huggingface/transformers/blob/1fed6166c00b800330fcda8494f78cbcad8e4e3b/src/transformers/models/qwen3_moe/modeling_qwen3_moe.py#L235-L263>`__,
omitting the ``gate`` (router) linear projection, whose output is provided through the ``router_logits`` input.

**Detailed description**:

*MOE* provides functionality equivalent to the following PyTorch pseudo-code:

.. code-block:: py
   :force:

   def MOE(hidden_states, router_logits, attrs):
       batch_size, sequence_length, hidden_dim = hidden_states.shape
       hidden_states = hidden_states.view(-1, hidden_dim)
       # router_logits: (batch * sequence_length, n_experts)

       routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float)
       routing_weights, selected_experts = torch.topk(routing_weights, attrs.topk, dim=-1)
       routing_weights /= routing_weights.sum(dim=-1, keepdim=True)
       # we cast back to the input dtype
       routing_weights = routing_weights.to(hidden_states.dtype)

       final_hidden_states = torch.zeros(
           (batch_size * sequence_length, hidden_dim), dtype=hidden_states.dtype, device=hidden_states.device
       )

       # One-hot encode the selected experts to create an expert mask.
       # This will be used to easily index which expert is going to be solicited.
       expert_mask = torch.nn.functional.one_hot(selected_experts, num_classes=attrs.expert_num).permute(2, 1, 0)

       # Loop over all available experts in the model and perform the computation on each expert
       for expert_idx in range(attrs.expert_num):
           expert_layer = attrs.experts[expert_idx]
           idx, top_x = torch.where(expert_mask[expert_idx])

           # Index the correct hidden states and compute the expert hidden state for
           # the current expert. We need to make sure to multiply the output hidden
           # states by `routing_weights` on the corresponding tokens (top-1 and top-2)
           current_state = hidden_states[None, top_x].reshape(-1, hidden_dim)
           current_hidden_states = expert_layer(current_state) * routing_weights[top_x, idx, None]

           # However `index_add_` only supports torch tensors for indexing, so we use
           # the `top_x` tensor here.
           final_hidden_states.index_add_(0, top_x, current_hidden_states.to(hidden_states.dtype))
       return final_hidden_states
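
The routing step above keeps only the ``topk`` largest softmax weights per token and renormalizes them so that the selected weights sum to 1. Below is a minimal, self-contained sketch of just that step; the input values are illustrative only and are not part of the operation definition.

.. code-block:: py
   :force:

   import torch
   import torch.nn.functional as F

   # Toy router output: 1 token, 4 experts (expert_num=4), topk=2.
   router_logits = torch.tensor([[0.1, 2.0, 0.3, 1.5]])

   routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float)
   # Keep only the two strongest experts (indices 1 and 3 for these logits).
   routing_weights, selected_experts = torch.topk(routing_weights, 2, dim=-1)
   # Renormalize so that the two selected weights sum to 1.
   routing_weights /= routing_weights.sum(dim=-1, keepdim=True)

   print(selected_experts)  # tensor([[1, 3]])
   print(routing_weights)   # approximately tensor([[0.62, 0.38]])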


**Attributes**

* *topk*

  * **Description**: The number of selected (activated) experts per token. Must be less than or equal to ``expert_num``.
  * **Range of values**: a positive integer number
  * **Type**: ``size_t``
  * **Required**: *yes*

* *expert_num*

  * **Description**: The total number of experts.
  * **Range of values**: a positive integer number
  * **Type**: ``size_t``
  * **Required**: *yes*

* *hidden_size*

  * **Description**: Hidden feature size, extracted from the ``hidden_states`` input.
  * **Range of values**: a positive integer number
  * **Type**: ``size_t``
  * **Required**: *yes*

* *intermediate_size*

  * **Description**: Intermediate (expert MLP) size, extracted from the ``expert_layer`` mentioned in the pseudo-code.
  * **Range of values**: a positive integer number
  * **Type**: ``size_t``
  * **Required**: *yes*

* *group_size*

  * **Description**: Weight-compression group size, extracted from the ``expert_layer`` mentioned in the pseudo-code. ``0`` means no grouping.
  * **Range of values**: a non-negative integer number
  * **Type**: ``size_t``
  * **Required**: *no*

* *weight_type*

  * **Description**: Weight data type, extracted from the ``expert_layer`` mentioned in the pseudo-code.
  * **Range of values**: "f16", "f32", "u8", "u4"
  * **Required**: *yes*

* *scale_type*

  * **Description**: Scale data type, extracted from the ``expert_layer`` mentioned in the pseudo-code.
  * **Range of values**: "f16", "dynamic"
  * **Required**: *no*

* *zp_type*

  * **Description**: Zero-point data type, extracted from the ``expert_layer`` mentioned in the pseudo-code.
  * **Range of values**: "u8", "u4", "dynamic"
  * **Required**: *no*

* *gates/ups/downs*

  * **Description**: Per-expert constants (weight, scale, and zero point for the gate, up, and down projections), extracted from the ``expert_layer`` mentioned in the pseudo-code; see the shape sketch after this list.
  * **Type**: ``v0::Constant``
  * **Required**: *yes*
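
For grouped weight compression, each serialized per-expert constant in the example at the end of this page is a 3D tensor: the weight has shape ``[output_size, input_size / group_size, group_size]`` and the per-group scale and zero point have shape ``[output_size, input_size / group_size, 1]``. The sketch below only illustrates these shapes; the helper function is not part of the operation, and the ``down``-projection layout is an assumption following the same pattern.

.. code-block:: py
   :force:

   # Illustrative only: expected constant shapes for one expert, assuming the
   # grouped-quantization layout used in the example at the end of this page.
   def expert_const_shapes(hidden_size, intermediate_size, group_size):
       def shapes(output_size, input_size):
           weight = (output_size, input_size // group_size, group_size)
           scale = (output_size, input_size // group_size, 1)
           zp = (output_size, input_size // group_size, 1)
           return weight, scale, zp

       return {
           "gate": shapes(intermediate_size, hidden_size),
           "up": shapes(intermediate_size, hidden_size),
           # Assumption: the down projection follows the same pattern with
           # swapped input/output sizes.
           "down": shapes(hidden_size, intermediate_size),
       }

   # For the example below (hidden_size=2048, intermediate_size=768, group_size=128):
   #   gate/up weight -> (768, 16, 128), scale/zp -> (768, 16, 1)
   #   down weight    -> (2048, 6, 128), scale/zp -> (2048, 6, 1)
   print(expert_const_shapes(2048, 768, 128))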

**Inputs**

* **1**: ``hidden_states`` - a 2D tensor of type *T* with the shape ``[batch, hidden_size]``. **Required.**

* **2**: ``router_logits`` - a 2D tensor of type *T* with the shape ``[batch, expert_num]``. **Required.**


**Outputs**

* **1**: Output tensor of the same shape and type as the ``hidden_states`` input tensor.

**Types**

* *T*: any floating point type.

**Example**

.. code-block:: xml
   :force:

   <layer id="5" name="moe_router" type="MOE" version="ie_internal_opset">
       <data config.topk="2" config.expert_num="4" config.hidden_size="2048" config.intermediate_size="768" config.group_size="128" config.fused_router_logic="1" config.weight_type="u4" config.scale_type="f16" config.zp_type="u4" expert0_mlp0.element_type="u4" expert0_mlp0.shape="768, 16, 128" expert0_mlp1.element_type="f16" expert0_mlp1.shape="768, 16, 1" expert0_mlp2.element_type="u4" expert0_mlp2.shape="768, 16, 1" expert1_mlp0.element_type="u4" expert1_mlp0.shape="768, 16, 128" expert1_mlp1.element_type="f16" expert1_mlp1.shape="768, 16, 1" expert1_mlp2.element_type="u4" expert1_mlp2.shape="768, 16, 1" expert2_mlp0.element_type="u4" expert2_mlp0.shape="768, 16, 128" expert2_mlp1.element_type="f16" expert2_mlp1.shape="768, 16, 1" expert2_mlp2.element_type="u4" expert2_mlp2.shape="768, 16, 1" expert3_mlp0.element_type="u4" expert3_mlp0.shape="768, 16, 128" expert3_mlp1.element_type="f16" expert3_mlp1.shape="768, 16, 1" expert3_mlp2.element_type="u4" expert3_mlp2.shape="768, 16, 1" />
       <input>
           <port id="0" precision="FP32">
               <dim>-1</dim>
               <dim>2048</dim>
           </port>
           <port id="1" precision="FP32">
               <dim>-1</dim>
               <dim>4</dim>
           </port>
       </input>
       <output>
           <port id="2" precision="FP32">
               <dim>-1</dim>
               <dim>2048</dim>
           </port>
       </output>
   </layer>
91 changes: 91 additions & 0 deletions src/common/transformations/include/ov_ops/moe.hpp
@@ -0,0 +1,91 @@
// Copyright (C) 2025 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#pragma once

#include <array>
#include <memory>
#include <tuple>
#include <vector>

#include "openvino/core/node.hpp"
#include "openvino/core/type/element_type.hpp"
#include "openvino/op/constant.hpp"
#include "openvino/op/op.hpp"
#include "transformations_visibility.hpp"

namespace ov::op::internal {
///
/// \brief MOE experts
class TRANSFORMATIONS_API MOE : public ov::op::Op {
public:
    OPENVINO_OP("MOE", "ie_internal_opset");

    MOE() = default;

    struct Config {
        size_t topk{};
        size_t expert_num{};
        size_t hidden_size{};
        size_t intermediate_size{};
        size_t group_size{};             // quantized group size, 0 for no group size. same for gate/up/down
        ov::element::Type weight_type{}; // same for gate/up/down
        ov::element::Type scale_type{};  // same for gate/up/down
        ov::element::Type zp_type{};     // same for gate/up/down
        bool operator==(const Config& rhs) const {
            return std::tie(topk,
                            expert_num,
                            hidden_size,
                            intermediate_size,
                            group_size,
                            weight_type,
                            scale_type,
                            zp_type) == std::tie(rhs.topk,
                                                 rhs.expert_num,
                                                 rhs.hidden_size,
                                                 rhs.intermediate_size,
                                                 rhs.group_size,
                                                 rhs.weight_type,
                                                 rhs.scale_type,
                                                 rhs.zp_type);
        }
    };

    // 0: weight, 1: scale, 2: zp
    struct ConstsPerExpert {
        std::array<std::shared_ptr<ov::op::v0::Constant>, 3> gates;
        std::array<std::shared_ptr<ov::op::v0::Constant>, 3> ups;
        std::array<std::shared_ptr<ov::op::v0::Constant>, 3> downs;
    };
    struct Attributes {
        // expert config
        Config config;
        // expert weight/scale/zp
        std::vector<ConstsPerExpert> consts;
    };

    MOE(const OutputVector& args, const Attributes& attrs);

    const Config& get_config() const;
    void set_config(const Config& config);
    const std::vector<ConstsPerExpert>& get_consts() const {
        return m_attrs.consts;
    }

    void add_consts(size_t expert_no, const ConstsPerExpert& consts) {
        OPENVINO_ASSERT(expert_no == m_attrs.consts.size(),
                        "MOE add_consts failed. Expected expert number: ",
                        m_attrs.consts.size(),
                        ", current: ",
                        expert_no);
        m_attrs.consts.push_back(consts);
    }

    bool visit_attributes(AttributeVisitor& visitor) override;
    void validate_and_infer_types() override;
    std::shared_ptr<Node> clone_with_new_inputs(const OutputVector& new_args) const override;

private:
    Attributes m_attrs;
};

} // namespace ov::op::internal
@@ -0,0 +1,36 @@
// Copyright (C) 2018-2025 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#pragma once

#include "openvino/pass/matcher_pass.hpp"
#include "transformations_visibility.hpp"

namespace ov {
namespace pass {

class TRANSFORMATIONS_API FuseMOEExpert;
class TRANSFORMATIONS_API FuseMOERouter;
class TRANSFORMATIONS_API FuseMOE;

} // namespace pass
} // namespace ov

class ov::pass::FuseMOEExpert : public ov::pass::MatcherPass {
public:
    OPENVINO_MATCHER_PASS_RTTI("FuseMOEExpert");
    FuseMOEExpert();
};

class ov::pass::FuseMOERouter : public ov::pass::MatcherPass {
public:
    OPENVINO_MATCHER_PASS_RTTI("FuseMOERouter");
    FuseMOERouter();
};

class ov::pass::FuseMOE : public ov::pass::ModelPass {
public:
    OPENVINO_MODEL_PASS_RTTI("FuseMOE");
    bool run_on_model(const std::shared_ptr<ov::Model>& model) override;
};
67 changes: 67 additions & 0 deletions src/common/transformations/src/ov_ops/moe.cpp
@@ -0,0 +1,67 @@
// Copyright (C) 2018-2025 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#include "ov_ops/moe.hpp"

#include "itt.hpp"

namespace ov {
namespace op {
namespace internal {

MOE::MOE(const OutputVector& args, const Attributes& attrs) : Op(args), m_attrs(attrs) {
    constructor_validate_and_infer_types();
}

const MOE::Config& MOE::get_config() const {
    return m_attrs.config;
}

void MOE::set_config(const Config& config) {
    m_attrs.config = config;
}

std::shared_ptr<ov::Node> MOE::clone_with_new_inputs(const ov::OutputVector& new_args) const {
    INTERNAL_OP_SCOPE(internal_MOE_clone_with_new_inputs);
    check_new_args_count(this, new_args);

    return std::make_shared<MOE>(new_args, m_attrs);
}

void MOE::validate_and_infer_types() {
    INTERNAL_OP_SCOPE(internal_MOE_validate_and_infer_types);
    OPENVINO_ASSERT(get_input_size() == 2, "MOE must have 2 inputs whereas it has ", get_input_size());

    set_output_type(0, get_input_element_type(0), get_input_partial_shape(0));
}

bool MOE::visit_attributes(ov::AttributeVisitor& visitor) {
    INTERNAL_OP_SCOPE(internal_MOE_visit_attributes);
    visitor.start_structure("config");

    visitor.on_attribute("topk", m_attrs.config.topk);
    visitor.on_attribute("expert_num", m_attrs.config.expert_num);
    visitor.on_attribute("hidden_size", m_attrs.config.hidden_size);
    visitor.on_attribute("intermediate_size", m_attrs.config.intermediate_size);
    visitor.on_attribute("group_size", m_attrs.config.group_size);
    visitor.on_attribute("weight_type", m_attrs.config.weight_type);
    visitor.on_attribute("scale_type", m_attrs.config.scale_type);
    visitor.on_attribute("zp_type", m_attrs.config.zp_type);
    visitor.finish_structure();
    m_attrs.consts.resize(m_attrs.config.expert_num);
    for (size_t i = 0; i < m_attrs.config.expert_num; i++) {
        for (size_t j = 0; j < 3; j++) {
            if (m_attrs.consts[i].gates[j]) {
                visitor.start_structure("expert" + std::to_string(i) + "_mlp" + std::to_string(j));
                m_attrs.consts[i].gates[j]->visit_attributes(visitor);
                visitor.finish_structure();
            }
        }
    }
    return true;
}

} // namespace internal
} // namespace op
} // namespace ov