From b6e4798d2f17c04184939002ac4d2fdb8b126f9e Mon Sep 17 00:00:00 2001
From: Donghak PARK
Date: Tue, 2 Dec 2025 14:20:45 +0900
Subject: [PATCH 1/6] [Application][CausalLM] Implement Ernie 4.5 MoE Model

Implement Ernie 4.5 MoE Model
- ernie's first layer is dense
- ernie has a shared expert at each MoE layer

**Self evaluation:**
1. Build test: [X]Passed [ ]Failed [ ]Skipped
2. Run test: [X]Passed [ ]Failed [ ]Skipped

Signed-off-by: Donghak PARK
---
 Applications/CausalLM/ernie_causallm.cpp | 156 +++++++++++++++++++++++
 Applications/CausalLM/ernie_causallm.h   |  85 ++++++++++++
 2 files changed, 241 insertions(+)
 create mode 100644 Applications/CausalLM/ernie_causallm.cpp
 create mode 100644 Applications/CausalLM/ernie_causallm.h

diff --git a/Applications/CausalLM/ernie_causallm.cpp b/Applications/CausalLM/ernie_causallm.cpp
new file mode 100644
index 0000000000..a43eb2b426
--- /dev/null
+++ b/Applications/CausalLM/ernie_causallm.cpp
@@ -0,0 +1,156 @@
+// SPDX-License-Identifier: Apache-2.0
+/**
+ *
+ * @file ernie_causallm.h
+ * @brief ernie 4.5 causallm header
+ * @date 02 December 2025
+ * @see https://github.com/nnstreamer/nntrainer
+ * @author Donghak Park
+ * @bug No known bugs except for NYI items
+ */
+
+#include
+#include
+#include
+#include
+#include
+
+namespace causallm {
+
+std::vector<LayerHandle>
+Ernie4_5_MoeForCausalLM::createMlp(const int layer_id, int dim, int hidden_dim,
+                                   std::string input_name) {
+  std::vector<LayerHandle> layers;
+  if (layer_id == 0) {
+    int ffn_hidden_dim = INTERMEDIATE_SIZE; // Ernie's first layer is dense
+
+    layers.push_back(createLayer(
+      "fully_connected",
+      {withKey("name", "layer" + std::to_string(layer_id) + "_ffn_up"),
+       withKey("unit", ffn_hidden_dim), withKey("disable_bias", "true"),
+       withKey("input_layers", input_name),
+       withKey("weight_initializer", "ones")}));
+
+    layers.push_back(createLayer(
+      "fully_connected",
+      {withKey("name", "layer" + std::to_string(layer_id) + "_ffn_gate"),
+       withKey("unit", ffn_hidden_dim), withKey("disable_bias", "true"),
+       withKey("input_layers", input_name),
+       withKey("weight_initializer", "ones")}));
+
+    layers.push_back(createLayer(
+      "swiglu",
+      {withKey("name", "layer" + std::to_string(layer_id) + "_ffn_swiglu"),
+       withKey("input_layers", "layer" + std::to_string(layer_id) + "_ffn_up," +
+                                 "layer" + std::to_string(layer_id) +
+                                 "_ffn_gate")}));
+
+    layers.push_back(createLayer(
+      "fully_connected",
+      {withKey("name", "layer" + std::to_string(layer_id) + "_ffn_down"),
+       withKey("unit", dim), withKey("disable_bias", "true"),
+       withKey("input_layers",
+               "layer" + std::to_string(layer_id) + "_ffn_swiglu"),
+       withKey("weight_initializer", "ones")}));
+
+  } else {
+    layers.push_back(createLayer(
+      "ernie_moe",
+      {withKey("name", "layer" + std::to_string(layer_id) + "_ffn_down"),
+       withKey("input_layers", input_name),
+       withKey("unit", MOE_INTERMEDIATE_SIZE),
+       withKey("num_experts", NUM_EXPERTS),
+       withKey("num_shared_experts", NUM_SHARED_EXPERTS),
+       withKey("num_experts_per_token", NUM_EXPERTS_PER_TOK),
+       withKey("moe_norm_min", std::to_string(MOE_NORM_MIN)),
+       withKey("moe_activation", "swish")}));
+  }
+  return layers;
+}
+
+std::vector<LayerHandle> Ernie4_5_MoeForCausalLM::createAttention(
+  const int layer_id, int seq_len, int n_heads, int head_dim,
+  std::string query_name, std::string key_name, std::string value_name) {
+
+  std::vector<LayerHandle> layers;
+  auto Q = "layer" + std::to_string(layer_id) + "_wq";
+  auto K = "layer" + std::to_string(layer_id) + "_wk";
+  auto V = "layer" + std::to_string(layer_id) + "_wv";
+  auto A = "layer" + std::to_string(layer_id) + "_attention";
+  auto O = "layer" + std::to_string(layer_id) + "_attention_out";
+
+  // V layer
+  std::vector<std::string> v_params = {
+    withKey("name", V), withKey("unit", head_dim * n_heads / GQA_SIZE),
+    withKey("disable_bias", "true"), withKey("input_layers", value_name),
+    withKey("weight_initializer", "ones")};
+  layers.push_back(createLayer("fully_connected", v_params));
+
+  // K layer
+  std::vector<std::string> k_params = {
+    withKey("name", K), withKey("unit", head_dim * n_heads / GQA_SIZE),
+    withKey("disable_bias", "true"), withKey("input_layers", key_name),
+    withKey("weight_initializer", "ones")};
+  layers.push_back(createLayer("fully_connected", k_params));
+
+  // Q layer
+  std::vector<std::string> q_params = {
+    withKey("name", Q), withKey("unit", head_dim * n_heads),
+    withKey("disable_bias", "true"), withKey("input_layers", query_name),
+    withKey("weight_initializer", "ones")};
+  layers.push_back(createLayer("fully_connected", q_params));
+
+  // Attention core layer
+  std::vector<std::string> a_params = {
+    withKey("name", A),
+    withKey("num_heads", n_heads),
+    withKey("num_heads_kv", n_heads / GQA_SIZE),
+    withKey("max_timestep", std::to_string(INIT_SEQ_LEN + NUM_TO_GENERATE)),
+    withKey("sliding_window", SLIDING_WINDOW),
+    withKey("rope_theta", ROPE_THETA),
+    withKey("max_position_embeddings", MAX_POSITION_EMBEDDINGS),
+    withKey("max_new_tokens", std::to_string(NUM_TO_GENERATE)),
+    withKey("input_layers", {Q, K, V})};
+  layers.push_back(createLayer("mha_core", a_params));
+
+  // O layer
+  std::vector<std::string> o_params = {
+    withKey("name", O), withKey("unit", DIM), withKey("disable_bias", "true"),
+    withKey("input_layers", A), withKey("weight_initializer", "ones")};
+  layers.push_back(createLayer("fully_connected", o_params));
+
+  return layers;
+}
+
+void Ernie4_5_MoeForCausalLM::setupParameters(json &cfg, json &generation_cfg,
+                                              json &nntr_cfg) {
+
+  try {
+    NUM_EXPERTS = cfg["moe_num_experts"].get<unsigned int>();
+    NUM_EXPERTS_PER_TOK = cfg["num_experts_per_tok"].get<unsigned int>();
+    MOE_INTERMEDIATE_SIZE = cfg["moe_intermediate_size"].get<unsigned int>();
+    INTERMEDIATE_SIZE = cfg["intermediate_size"].get<unsigned int>();
+    NUM_SHARED_EXPERTS = cfg["moe_num_shared_experts"].get<unsigned int>();
+    MOE_NORM_MIN =
+      cfg.contains("moe_norm_min") ? cfg["moe_norm_min"].get<float>() : 1e-12f;
+
+  } catch (const std::exception &e) {
+    throw std::runtime_error("Ernie Causallm: config parsing error");
+  }
+}
+
+void Ernie4_5_MoeForCausalLM::registerCustomLayers() {
+  CausalLM::registerCustomLayers();
+  auto &ct_engine = nntrainer::Engine::Global();
+  auto app_context =
+    static_cast<nntrainer::AppContext *>(ct_engine.getRegisteredContext("cpu"));
+
+  try {
+    app_context->registerFactory(
+      nntrainer::createLayer<causallm::ErnieMoELayer>);
+  } catch (std::invalid_argument &e) {
+    std::cerr << "failed to register factory, reason: " << e.what()
+              << std::endl;
+  }
+}
+
+} // namespace causallm
\ No newline at end of file
diff --git a/Applications/CausalLM/ernie_causallm.h b/Applications/CausalLM/ernie_causallm.h
new file mode 100644
index 0000000000..5a7b33a993
--- /dev/null
+++ b/Applications/CausalLM/ernie_causallm.h
@@ -0,0 +1,85 @@
+// SPDX-License-Identifier: Apache-2.0
+/**
+ *
+ * @file ernie_causallm.h
+ * @brief ernie 4.5 causallm header
+ * @date 02 December 2025
+ * @see https://github.com/nnstreamer/nntrainer
+ * @author Donghak Park
+ * @bug No known bugs except for NYI items
+ */
+
+#ifndef NNTRAINER_ERNIE_CAUSALLM_H
+#define NNTRAINER_ERNIE_CAUSALLM_H
+#include
+
+namespace causallm {
+
+class Ernie4_5_MoeForCausalLM : public CausalLM {
+public:
+  static constexpr const char *architecture = "Ernie4_5_MoeForCausalLM";
+
+  Ernie4_5_MoeForCausalLM(json &cfg, json &generation_cfg, json &nntr_cfg) :
+    CausalLM(cfg, generation_cfg, nntr_cfg) {
+    setupParameters(cfg, generation_cfg, nntr_cfg);
+  }
+
+  virtual ~Ernie4_5_MoeForCausalLM() = default;
+
+  /**
+   * @brief MoE layer
+   */
+  /**
+   * @brief Create MLP layer
+   * @param layer_id Layer ID
+   * @param dim Dimension
+   * @param hidden_dim Hidden dimension
+   * @param input_name Input name
+   * @return std::vector<LayerHandle> Vector of layer handles
+   */
+  std::vector<LayerHandle> createMlp(const int layer_id, int dim,
+                                     int hidden_dim,
+                                     std::string input_name) override;
+
+  /**
+   * @brief Create Attention layer
+   * @param layer_id Layer ID
+   * @param seq_len Sequence length
+   * @param n_heads Number of heads
+   * @param head_dim Head dimension
+   * @param query_name Query name
+   * @param key_name Key name
+   * @param value_name Value name
+   * @return std::vector<LayerHandle> Vector of layer handles
+   */
+  std::vector<LayerHandle> createAttention(int layer_id, int seq_len,
+                                           int n_heads, int head_dim,
+                                           std::string query_name,
+                                           std::string key_name,
+                                           std::string value_name) override;
+  /**
+   * @brief Setup parameters for the model
+   * @param cfg Configuration json
+   * @param generation_cfg Generation configuration json
+   * @param nntr_cfg NNtrainer configuration json
+   */
+  void setupParameters(json &cfg, json &generation_cfg,
+                       json &nntr_cfg) override;
+
+  /**
+   * @brief Register custom layers
+   */
+  void registerCustomLayers() override;
+
+private:
+  unsigned int NUM_EXPERTS;           /**< Number of experts */
+  unsigned int NUM_EXPERTS_PER_TOK;   /**< Number of experts per token */
+  unsigned int NUM_SHARED_EXPERTS;    /**< Number of shared experts */
+  unsigned int MOE_INTERMEDIATE_SIZE; /**< MoE intermediate size */
+  float MOE_NORM_MIN; /**< MoE normalization minimum */
+
+  std::vector<std::string> LAYER_TYPES; /**< Layer types */
+  float ATTENTION_ROPE_SCALING_FACTOR; /**< Attention RoPE scaling factor */
+};
+
+} // namespace causallm
+#endif // NNTRAINER_ERNIE_CAUSALLM_H
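Patch 1 wires the dense layer-0 FFN as four layers: up, gate, a swiglu node, and down. A minimal sketch of the computation those layers implement, assuming the standard silu gating; plain C++ for illustration only, not part of the patch:

    // y = W_down * (silu(W_gate * x) ⊙ (W_up * x)), biases disabled
    #include <cmath>
    #include <vector>

    static float silu(float x) { return x / (1.0f + std::exp(-x)); }

    std::vector<float> swiglu_ffn(const std::vector<std::vector<float>> &w_up,
                                  const std::vector<std::vector<float>> &w_gate,
                                  const std::vector<std::vector<float>> &w_down,
                                  const std::vector<float> &x) {
      const size_t hidden = w_up.size(); // ffn_hidden_dim (INTERMEDIATE_SIZE)
      const size_t dim = w_down.size();  // model dim
      std::vector<float> h(hidden), y(dim, 0.0f);
      for (size_t i = 0; i < hidden; ++i) {
        float up = 0.0f, gate = 0.0f;
        for (size_t j = 0; j < x.size(); ++j) {
          up += w_up[i][j] * x[j];
          gate += w_gate[i][j] * x[j];
        }
        h[i] = silu(gate) * up; // element-wise gating
      }
      for (size_t i = 0; i < dim; ++i)
        for (size_t j = 0; j < hidden; ++j)
          y[i] += w_down[i][j] * h[j];
      return y;
    }
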
From 23de9ff38ff79b16e97868afae961bce571bfc77 Mon Sep 17 00:00:00 2001
From: Donghak PARK
Date: Tue, 2 Dec 2025 14:24:21 +0900
Subject: [PATCH 2/6] [CausalLM] Add causallm common properties

Add causallm common properties
- num_shared_experts
- moe_norm_min

**Self evaluation:**
1. Build test: [X]Passed [ ]Failed [ ]Skipped
2. Run test: [X]Passed [ ]Failed [ ]Skipped

Signed-off-by: Donghak PARK
---
 .../CausalLM/layers/causallm_common_properties.h | 18 ++++++++++++++++++
 1 file changed, 18 insertions(+)

diff --git a/Applications/CausalLM/layers/causallm_common_properties.h b/Applications/CausalLM/layers/causallm_common_properties.h
index 80ad260a54..5c0a9e174a 100644
--- a/Applications/CausalLM/layers/causallm_common_properties.h
+++ b/Applications/CausalLM/layers/causallm_common_properties.h
@@ -66,6 +66,24 @@ class NumExpertsPerToken : public nntrainer::PositiveIntegerProperty {
   using prop_tag = nntrainer::uint_prop_tag; /**< property type */
 };
 
+/**
+ * @brief NumSharedExperts, Number of shared experts property
+ */
+class NumSharedExperts : public nntrainer::Property<unsigned int> {
+public:
+  static constexpr const char *key = "num_shared_experts";
+  using prop_tag = nntrainer::uint_prop_tag;
+};
+
+/**
+ * @brief MoENormMin, Minimum value for MoE normalization
+ */
+class MoENormMin : public nntrainer::Property<float> {
+public:
+  static constexpr const char *key = "moe_norm_min";
+  using prop_tag = nntrainer::float_prop_tag;
+};
+
 /**
  * @brief unit property, unit is used to measure how many weights are there
  *
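The new properties reach the layer as "key=value" strings, like every other property in this application. A hedged usage sketch against the ccapi; values are illustrative, not Ernie's real configuration, and it assumes "ernie_moe" has already been registered via registerCustomLayers:

    #include <layer.h> // ml::train::createLayer (ccapi)
    #include <memory>

    std::unique_ptr<ml::train::Layer> make_ernie_moe() {
      // moe_norm_min falls back to 1e-12 in the layer when it is left at 0
      return ml::train::createLayer(
        "ernie_moe", {"name=layer1_ffn_down", "unit=1536", "num_experts=64",
                      "num_experts_per_token=6", "num_shared_experts=2",
                      "moe_norm_min=1e-12", "moe_activation=swish"});
    }
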
From 9ebcf24811b13d85cbbcf50c182c88a85cfdf368 Mon Sep 17 00:00:00 2001
From: Donghak PARK
Date: Tue, 2 Dec 2025 15:19:23 +0900
Subject: [PATCH 3/6] [CausalLM] Implement Ernie MoE Layer

Implement Ernie MoE Layer
- Shared expert accumulation
- Static bias add

**Self evaluation:**
1. Build test: [X]Passed [ ]Failed [ ]Skipped
2. Run test: [X]Passed [ ]Failed [ ]Skipped

Signed-off-by: Donghak PARK
---
 Applications/CausalLM/ernie_causallm.cpp |   4 +-
 Applications/CausalLM/ernie_causallm.h   |   2 +-
 .../CausalLM/layers/ernie_moe_layer.cpp   | 520 ++++++++++++++++++
 .../CausalLM/layers/ernie_moe_layer.h     | 154 ++++++
 4 files changed, 677 insertions(+), 3 deletions(-)
 create mode 100644 Applications/CausalLM/layers/ernie_moe_layer.cpp
 create mode 100644 Applications/CausalLM/layers/ernie_moe_layer.h

diff --git a/Applications/CausalLM/ernie_causallm.cpp b/Applications/CausalLM/ernie_causallm.cpp
index a43eb2b426..84fcb62165 100644
--- a/Applications/CausalLM/ernie_causallm.cpp
+++ b/Applications/CausalLM/ernie_causallm.cpp
@@ -1,11 +1,11 @@
 // SPDX-License-Identifier: Apache-2.0
 /**
  *
- * @file ernie_causallm.h
+ * @file ernie_causallm.cpp
  * @brief ernie 4.5 causallm header
  * @date 02 December 2025
  * @see https://github.com/nnstreamer/nntrainer
- * @author Donghak Park
+ * @author Donghak Park
  * @bug No known bugs except for NYI items
  */
 
diff --git a/Applications/CausalLM/ernie_causallm.h b/Applications/CausalLM/ernie_causallm.h
index 5a7b33a993..f54109c364 100644
--- a/Applications/CausalLM/ernie_causallm.h
+++ b/Applications/CausalLM/ernie_causallm.h
@@ -5,7 +5,7 @@
  * @brief ernie 4.5 causallm header
  * @date 02 December 2025
  * @see https://github.com/nnstreamer/nntrainer
- * @author Donghak Park
+ * @author Donghak Park
  * @bug No known bugs except for NYI items
  */
 
diff --git a/Applications/CausalLM/layers/ernie_moe_layer.cpp b/Applications/CausalLM/layers/ernie_moe_layer.cpp
new file mode 100644
index 0000000000..e024c1e87e
--- /dev/null
+++ b/Applications/CausalLM/layers/ernie_moe_layer.cpp
@@ -0,0 +1,520 @@
+// SPDX-License-Identifier: Apache-2.0
+/**
+ *
+ * @file ernie_moe_layer.cpp
+ * @brief ernie 4.5 moe layer implementation
+ * @date 02 December 2025
+ * @see https://github.com/nnstreamer/nntrainer
+ * @author Donghak Park
+ * @bug No known bugs except for NYI items
+ */
+
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+
+namespace causallm {
+
+static constexpr size_t SINGLE_INOUT_IDX = 0;
+
+ErnieMoELayer::ErnieMoELayer() :
+  LayerImpl(),
+  num_experts(0),
+  num_shared_experts(0),
+  topk(0),
+  moe_props(props::NumExperts(), props::NumExpertsPerToken(),
+            nntrainer::props::Unit(), props::MoEActivation(),
+            props::NumSharedExperts(), props::MoENormMin()),
+  expert_gate_proj_indices({}),
+  expert_up_proj_indices({}),
+  expert_down_proj_indices({}),
+  loaded_expert_deque({}),
+  need_load({}),
+  gate_idx(std::numeric_limits<unsigned int>::max()),
+  e_score_correction_bias_idx(std::numeric_limits<unsigned int>::max()),
+  router_logits_idx(std::numeric_limits<unsigned int>::max()),
+  expert_mask_idx(std::numeric_limits<unsigned int>::max()) {}
+
+void ErnieMoELayer::finalize(nntrainer::InitLayerContext &context) {
+  // 1. Validate input/output dimensions
+  NNTR_THROW_IF(context.getNumInputs() != 1, std::invalid_argument)
+    << "MoE layer only supports single input";
+
+  auto &weight_regularizer =
+    std::get<nntrainer::props::WeightRegularizer>(*layer_impl_props);
+  auto &weight_regularizer_constant =
+    std::get<nntrainer::props::WeightRegularizerConstant>(*layer_impl_props);
+  auto &weight_initializer =
+    std::get<nntrainer::props::WeightInitializer>(*layer_impl_props);
+  auto &weight_decay =
+    std::get<nntrainer::props::WeightDecay>(*layer_impl_props);
+
+  // 2. Set output dimensions (same as input)
+  const auto &in_dim = context.getInputDimensions()[SINGLE_INOUT_IDX];
+  const bool is_nchw = context.getFormat() == nntrainer::Tformat::NCHW;
+  std::vector<nntrainer::TensorDim> output_dims(1);
+  output_dims[SINGLE_INOUT_IDX] = in_dim;
+  context.setOutputDimensions(output_dims);
+
+  // 3. Get MoE properties
+  num_experts = std::get<props::NumExperts>(moe_props).get();
+  num_shared_experts = std::get<props::NumSharedExperts>(moe_props).get();
+  topk = std::get<props::NumExpertsPerToken>(moe_props).get();
+  float moe_norm_min = std::get<props::MoENormMin>(moe_props).get();
+  if (moe_norm_min == 0.0f) {
+    moe_norm_min = 1e-12f; // Default value if not set
+    std::get<props::MoENormMin>(moe_props).set(moe_norm_min);
+  }
+  const unsigned int intermediate_size =
+    std::get<nntrainer::props::Unit>(moe_props).get();
+  const unsigned int hidden_size = in_dim.width(); // Feature dimension
+
+  // activation function
+  if (std::get<props::MoEActivation>(moe_props).empty()) {
+    throw std::runtime_error("Activation type is not set for MoE layer");
+  }
+  switch (context.getActivationDataType()) {
+  case ml::train::TensorDim::DataType::FP32:
+    acti_func.setActiFunc(std::get<props::MoEActivation>(moe_props).get());
+    break;
+  default:
+    throw std::runtime_error("Unsupported activation data type for MoE layer");
+  }
+
+  // 4. Initialize gate layer (router)
+  nntrainer::TensorDim gate_dim(
+    1, is_nchw ? 1 : num_experts, is_nchw ? hidden_size : 1,
+    is_nchw ? num_experts : hidden_size,
+    nntrainer::TensorDim::TensorType(context.getFormat(),
+                                     nntrainer::TensorDim::DataType::FP32),
+    is_nchw ? 0b0011 : 0b0101);
+
+  gate_idx = context.requestWeight(
+    gate_dim, weight_initializer, weight_regularizer,
+    weight_regularizer_constant, weight_decay, "gate", true);
+
+  // 4-1. Initialize e_score_correction_bias
+  nntrainer::TensorDim e_score_correction_bias_dim(
+    1, 1, 1, num_experts,
+    nntrainer::TensorDim::TensorType(context.getFormat(),
+                                     nntrainer::TensorDim::DataType::FP32),
+    is_nchw ? 0b0011 : 0b0101);
+
+  e_score_correction_bias_idx = context.requestWeight(
+    e_score_correction_bias_dim, weight_initializer, weight_regularizer,
+    weight_regularizer_constant, weight_decay, "moe_statics", false);
+
+  // 5. Initialize expert weights
+  expert_gate_proj_indices.reserve(num_experts);
+  expert_up_proj_indices.reserve(num_experts);
+  expert_down_proj_indices.reserve(num_experts);
+
+  if (num_shared_experts > 0) {
+    nntrainer::TensorDim shared_expert_gate_dim(
+      1, is_nchw ? 1 : num_shared_experts * intermediate_size,
+      is_nchw ? hidden_size : 1,
+      is_nchw ? num_shared_experts * intermediate_size : hidden_size,
+      nntrainer::TensorDim::TensorType(context.getFormat(),
+                                       context.getWeightDataType()),
+      is_nchw ? 0b0011 : 0b0101);
+
+    nntrainer::TensorDim shared_expert_down_dim(
+      1, is_nchw ? 1 : hidden_size,
+      is_nchw ? num_shared_experts * intermediate_size : 1,
+      is_nchw ? hidden_size : num_shared_experts * intermediate_size,
+      nntrainer::TensorDim::TensorType(context.getFormat(),
+                                       context.getWeightDataType()),
+      is_nchw ? 0b0011 : 0b0101);
+
+    shared_up_proj_idx = context.requestWeight(
+      shared_expert_gate_dim, weight_initializer, weight_regularizer,
+      weight_regularizer_constant, weight_decay, "shared_experts_up", false);
+
+    shared_gate_proj_idx = context.requestWeight(
+      shared_expert_gate_dim, weight_initializer, weight_regularizer,
+      weight_regularizer_constant, weight_decay, "shared_experts_gate", false);
+
+    shared_down_proj_idx = context.requestWeight(
+      shared_expert_down_dim, weight_initializer, weight_regularizer,
+      weight_regularizer_constant, weight_decay, "shared_experts_down", false);
+  }
+
+  nntrainer::TensorDim expert_gate_dim(
+    1, is_nchw ? 1 : intermediate_size, is_nchw ? hidden_size : 1,
+    is_nchw ? intermediate_size : hidden_size,
+    nntrainer::TensorDim::TensorType(context.getFormat(),
+                                     context.getWeightDataType()),
+    is_nchw ? 0b0011 : 0b0101);
+
+  nntrainer::TensorDim expert_down_dim(
+    1, is_nchw ? 1 : hidden_size, is_nchw ? intermediate_size : 1,
+    is_nchw ? hidden_size : intermediate_size,
+    nntrainer::TensorDim::TensorType(context.getFormat(),
+                                     context.getWeightDataType()),
+    is_nchw ? 0b0011 : 0b0101);
+
+  for (unsigned int i = 0; i < num_experts; ++i) {
+    expert_up_proj_indices.push_back(context.requestWeight(
+      expert_gate_dim, // Same dimensions as gate projection
+      weight_initializer, weight_regularizer, weight_regularizer_constant,
+      weight_decay, "expert_up_" + std::to_string(i), false, true));
+
+    expert_gate_proj_indices.push_back(context.requestWeight(
+      expert_gate_dim, weight_initializer, weight_regularizer,
+      weight_regularizer_constant, weight_decay,
+      "expert_gate_" + std::to_string(i), false, true));
+
+    expert_down_proj_indices.push_back(context.requestWeight(
+      expert_down_dim, weight_initializer, weight_regularizer,
+      weight_regularizer_constant, weight_decay,
+      "expert_down_" + std::to_string(i), false, true));
+
+    need_load.push_back(true);
+  }
+
+  // 5-1. Shared expert weights were initialized above (see step 5)
+
+  // 6. Request intermediate tensors
+  const unsigned batch_size = in_dim.batch();
+  const unsigned seq_len = in_dim.height();
+  const unsigned total_tokens = batch_size * seq_len;
+
+  // Router logits : [batch * seq, num_experts]
+  router_logits_idx =
+    context.requestTensor({total_tokens, 1, 1, num_experts}, "router_logits",
+                          nntrainer::Initializer::NONE, false,
+                          nntrainer::TensorLifespan::FORWARD_FUNC_LIFESPAN);
+
+  // Expert mask: [num_experts, batch*seq]
+  expert_mask_idx =
+    context.requestTensor({num_experts, 1, topk, total_tokens}, "expert_mask",
+                          nntrainer::Initializer::ZEROS, false,
+                          nntrainer::TensorLifespan::FORWARD_FUNC_LIFESPAN);
+}
+
+void ErnieMoELayer::forwarding(nntrainer::RunLayerContext &context,
+                               bool training) {}
+
+inline void ErnieMoELayer::compute_expert_forward(
+  const nntrainer::Tensor &input, nntrainer::Tensor &output,
+  const std::vector<std::pair<unsigned int, float>> &token_assignments,
+  const nntrainer::Tensor &gate_proj, const nntrainer::Tensor &up_proj,
+  const nntrainer::Tensor &down_proj, unsigned int hidden_size) {
+
+  const unsigned intermediate_size = gate_proj.width();
+  const unsigned num_tokens = token_assignments.size();
+
+  if (num_tokens == 0)
+    return;
+
+  // Create tensor dimensions for single token processing
+  nntrainer::TensorDim token_input_dim({1, 1, num_tokens, hidden_size},
+                                       input.getTensorType());
+  nntrainer::TensorDim intermediate_dim({1, 1, num_tokens, intermediate_size},
+                                        input.getTensorType());
+  nntrainer::TensorDim token_output_dim({1, 1, num_tokens, hidden_size},
+                                        input.getTensorType());
+  nntrainer::TensorDim out_step_dim({1, 1, 1, hidden_size},
+                                    input.getTensorType());
+  nntrainer::TensorDim step_dim({1, 1, 1, intermediate_size},
+                                input.getTensorType());
+  // Create intermediate tensors for this token
+  nntrainer::Tensor gate_out(intermediate_dim);
+  nntrainer::Tensor acti_out(intermediate_dim);
+  nntrainer::Tensor up_out(intermediate_dim);
+  nntrainer::Tensor token_input(token_input_dim);
+  // Down projection using optimized dot operation
+  nntrainer::Tensor token_expert_output(token_output_dim);
+
+  unsigned token_idx = token_assignments[0].first;
+  float weight = token_assignments[0].second;
+
+  if (num_tokens > 1) {
+    /** if prefill, copy data to make a batch */
+#pragma omp parallel for schedule(static) if (num_tokens > 4)
+    for (size_t i = 0; i < num_tokens; ++i) {
+      const unsigned token_idx = token_assignments[i].first;
+      // Use tensor's optimized copy operation
+      nntrainer::Tensor src_view = input.getSharedDataTensor(
+        {1, 1, 1, hidden_size}, token_idx * hidden_size, true);
+      nntrainer::Tensor dst_view = token_input.getSharedDataTensor(
+        {1, 1, 1, hidden_size}, i * hidden_size, true);
+      dst_view.copyData(src_view);
+    }
+  } else {
+    /** if token generation, do not copy but get the shared tensor */
+    // Create shared tensor for input token (no memory copy)
+    size_t token_offset = token_idx * hidden_size;
+    token_input =
+      input.getSharedDataTensor(token_input_dim, token_offset, true);
+  }
+
+  // Gate projection using optimized dot operation
+  token_input.dot(gate_proj, gate_out);
+
+  // Up projection using optimized dot operation
+  token_input.dot(up_proj, up_out);
+
+  if (num_tokens == 1) {
+    // Apply activation (silu)
+    acti_func.run_fn(gate_out, acti_out);
+    // Element-wise multiply: silu(gate_out) * up_out
+    acti_out.multiply_i(up_out);
+  } else {
+#pragma omp parallel for schedule(static) if (num_tokens > 4)
+    for (size_t i = 0; i < num_tokens; ++i) {
+      const unsigned offset = acti_out.getIndex(0, 0, i, 0);
+      nntrainer::swiglu(acti_out.width(), acti_out.getData() + offset,
+                        gate_out.getData() + offset,
+                        up_out.getData() + offset);
+    }
+  }
+
+  acti_out.dot(down_proj, token_expert_output);
+
+  // accumulate to output
+  for (size_t i = 0; i < num_tokens; ++i) {
+    token_idx = token_assignments[i].first;
+    weight = token_assignments[i].second;
+    size_t output_offset = token_idx * hidden_size;
+    nntrainer::Tensor token_output =
+      output.getSharedDataTensor(out_step_dim, output_offset, true);
+    nntrainer::Tensor target = token_expert_output.getSharedDataTensor(
+      out_step_dim, i * hidden_size, true);
+    target.multiply_i(weight);
+    token_output.add(target, token_output);
+  }
+}
+
+void ErnieMoELayer::incremental_forwarding(nntrainer::RunLayerContext &context,
+                                           unsigned int from, unsigned int to,
+                                           bool training) {
+
+  nntrainer::Tensor &input_ = context.getInput(SINGLE_INOUT_IDX);
+  nntrainer::Tensor &output_ = context.getOutput(SINGLE_INOUT_IDX);
+
+  nntrainer::Tensor &router_logits_ = context.getTensor(router_logits_idx);
+
+  nntrainer::TensorDim input_step_dim = input_.getDim();
+  nntrainer::TensorDim output_step_dim = output_.getDim();
+  nntrainer::TensorDim router_logits_step_dim = router_logits_.getDim();
+
+  nntrainer::Tensor &shared_gate_proj = context.getWeight(shared_gate_proj_idx);
+  nntrainer::Tensor &shared_up_proj = context.getWeight(shared_up_proj_idx);
+  nntrainer::Tensor &shared_down_proj = context.getWeight(shared_down_proj_idx);
+
+  input_step_dim.batch(1);
+  output_step_dim.batch(1);
+  router_logits_step_dim.batch(to - from);
+
+  input_step_dim.height(to - from);
+  output_step_dim.height(to - from);
+
+  for (unsigned int b = 0; b < input_.batch(); ++b) {
+    auto input = input_.getSharedDataTensor(
+      input_step_dim, b * input_step_dim.getFeatureLen(), true);
+    auto output = output_.getSharedDataTensor(
+      output_step_dim, b * output_step_dim.getFeatureLen(), true);
+    auto router_logits =
+      router_logits_.getSharedDataTensor(router_logits_step_dim, 0, true);
+
+    const unsigned batch_size = input.batch();
+    const unsigned seq_len = input.height();
+    const unsigned hidden_size = input.width();
+    const unsigned total_tokens = batch_size * seq_len;
+
+    // reshape input: [B,1,S,H] -> [B*S,1,1,H]
+    input.reshape({total_tokens, 1, 1, hidden_size});
+
+    // reshape output: [B,1,S,H] -> [B*S,1,1,H]
+    output.reshape({total_tokens, 1, 1, hidden_size});
+    output.setZero();
+
+    // Compute shared experts
+    if (num_shared_experts > 0) {
+
+      nntrainer::Tensor shared_output(total_tokens, 1, 1, hidden_size,
+                                      output.getTensorType());
+      const unsigned int intermediate_size =
+        std::get<nntrainer::props::Unit>(moe_props).get();
+
+      nntrainer::TensorDim intermediate_dim(
+        {total_tokens, 1, 1, num_shared_experts * intermediate_size},
+        input.getTensorType());
+
+      nntrainer::Tensor gate_out(intermediate_dim);
+      nntrainer::Tensor acti_out(intermediate_dim);
+      nntrainer::Tensor up_out(intermediate_dim);
+
+      input.dot(shared_gate_proj, gate_out);
+      input.dot(shared_up_proj, up_out);
+
+      acti_func.run_fn(gate_out, acti_out);
+      acti_out.multiply_i(up_out);
+
+      acti_out.dot(shared_down_proj, shared_output);
+
+      // Add shared expert output to final output
+      output.add_i(shared_output);
+    }
+
+    // routing
+    nntrainer::Tensor &gate_weights = context.getWeight(gate_idx);
+    input.dot(gate_weights, router_logits);
+
+    router_logits.apply(nntrainer::ActiFunc::softmax<float>, router_logits);
+
+    // Add e_score_correction_bias
+    nntrainer::Tensor &e_score_correction_bias =
+      context.getWeight(e_score_correction_bias_idx);
+
+    nntrainer::Tensor biased_router_logits(router_logits.getDim());
+    biased_router_logits.copyData(router_logits);
+    biased_router_logits.add_i(e_score_correction_bias);
+
+    auto topk_result = biased_router_logits.topK(topk);
+    auto topk_values = std::get<0>(topk_result);
+    auto topk_indices = std::get<1>(topk_result);
+
+    const uint32_t *indices_data = topk_indices.getData<uint32_t>();
+    std::vector<std::vector<std::pair<unsigned int, float>>> expert_assignments(
+      num_experts);
+    // Set expert mask
+    for (int i = 0; i < static_cast<int>(total_tokens); ++i) {
+      float sum_prob = 0.0f;
+      for (int k = 0; k < static_cast<int>(topk); ++k) {
+        unsigned expert_idx = indices_data[i * topk + k];
+        float weight = router_logits.getValue(i, 0, 0, expert_idx);
+        sum_prob += weight;
+      }
+
+      for (int k = 0; k < static_cast<int>(topk); ++k) {
+        unsigned expert_idx = indices_data[i * topk + k];
+        float weight = router_logits.getValue(i, 0, 0, expert_idx);
+        weight /=
+          std::max(sum_prob, std::get<props::MoENormMin>(moe_props).get());
+        expert_assignments[expert_idx].emplace_back(i, weight);
+      }
+    }
+
+    // Parallel processing for multiple tokens with many active experts
+    std::vector<nntrainer::Tensor> expert_outputs(num_experts);
+#pragma omp parallel for schedule(static)
+    for (int expert_idx = 0; expert_idx < static_cast<int>(num_experts);
+         ++expert_idx) {
+      if (!expert_assignments[expert_idx].empty()) {
+        expert_outputs[expert_idx] = nntrainer::Tensor(
+          total_tokens, 1, 1, hidden_size, output.getTensorType());
+        expert_outputs[expert_idx].setZero();
+      }
+    }
+    std::vector<int> target_idx_vector;
+
+    for (int expert_idx = 0; expert_idx < static_cast<int>(num_experts);
+         ++expert_idx) {
+      const auto &assignments = expert_assignments[expert_idx];
+      if (assignments.empty())
+        continue;
+
+      target_idx_vector.push_back(expert_idx);
+    }
+
+#pragma omp parallel for schedule(dynamic)
+    for (int expert_idx : target_idx_vector) {
+      const auto &assignments = expert_assignments[expert_idx];
+      if (need_load[expert_idx]) {
+
+        context.getWeight(expert_gate_proj_indices[expert_idx]).activate();
+        context.getWeight(expert_up_proj_indices[expert_idx]).activate();
+        context.getWeight(expert_down_proj_indices[expert_idx]).activate();
+
+        {
+          std::lock_guard<std::mutex> lock(cache_mutex);
+          loaded_expert_deque.push_back(expert_idx);
+          iteration_map[expert_idx] = --loaded_expert_deque.end();
+          need_load[expert_idx] = false;
+        }
+
+        compute_expert_forward(
+          input, expert_outputs[expert_idx], assignments,
+          context.getWeight(expert_gate_proj_indices[expert_idx]),
+          context.getWeight(expert_up_proj_indices[expert_idx]),
+          context.getWeight(expert_down_proj_indices[expert_idx]), hidden_size);
+
+      } else {
+
+        compute_expert_forward(
+          input, expert_outputs[expert_idx], assignments,
+          context.getWeight(expert_gate_proj_indices[expert_idx]),
+          context.getWeight(expert_up_proj_indices[expert_idx]),
+          context.getWeight(expert_down_proj_indices[expert_idx]), hidden_size);
+      }
+    }
+
+// Evict experts
+#pragma omp parallel
+    while (loaded_expert_deque.size() > 16) {
+      int target_idx;
+      {
+        std::lock_guard<std::mutex> lock(cache_mutex);
+        target_idx = loaded_expert_deque.front();
+        loaded_expert_deque.pop_front();
+        iteration_map.erase(target_idx);
+        need_load[target_idx] = true;
+      }
+      context.getWeight(expert_gate_proj_indices[target_idx]).deactivate();
+      context.getWeight(expert_up_proj_indices[target_idx]).deactivate();
+      context.getWeight(expert_down_proj_indices[target_idx]).deactivate();
+    }
+
+    // Combine expert outputs
+    for (int expert_idx : target_idx_vector) {
+      output.add_i(expert_outputs[expert_idx]);
+    }
+
+    // reshape output: [B*S,1,1,H] -> [B,1,S,H]
+    output.reshape({batch_size, 1, seq_len, hidden_size});
+  }
+}
+
+void ErnieMoELayer::setProperty(const std::vector<std::string> &values) {
+  auto remain_props = loadProperties(values, moe_props);
+  nntrainer::LayerImpl::setProperty(remain_props);
+}
+
+void ErnieMoELayer::calcDerivative(nntrainer::RunLayerContext &context) {
+  // MoE layer does not support derivative calculation
+  throw std::runtime_error("MoE layer does not support derivative calculation");
+}
+
+void ErnieMoELayer::calcGradient(nntrainer::RunLayerContext &context) {
+  // MoE layer does not support gradient calculation
+  throw std::runtime_error("MoE layer does not support gradient calculation");
+}
+
+void ErnieMoELayer::exportTo(nntrainer::Exporter &exporter,
+                             const ml::train::ExportMethods &method) const {
+  nntrainer::LayerImpl::exportTo(exporter, method);
+  exporter.saveResult(moe_props, method, this); // Save MoE specific properties
+}
+
+void ErnieMoELayer::updateTensorsByInputDimensions(
+  nntrainer::RunLayerContext &context,
+  std::vector<ml::train::TensorDim> input_dimensions) {
+  ml::train::TensorDim input_dim = context.getInput(SINGLE_INOUT_IDX).getDim();
+  ml::train::TensorDim output_dim =
+    context.getOutput(SINGLE_INOUT_IDX).getDim();
+
+  input_dim.height(input_dimensions[0].height());
+  output_dim.height(input_dimensions[0].height());
+
+  context.updateInput(SINGLE_INOUT_IDX, input_dim);
+  context.updateOutput(SINGLE_INOUT_IDX, output_dim);
+}
+
+} // namespace causallm
\ No newline at end of file
diff --git a/Applications/CausalLM/layers/ernie_moe_layer.h b/Applications/CausalLM/layers/ernie_moe_layer.h
new file mode 100644
index 0000000000..6d2a076f3d
--- /dev/null
+++ b/Applications/CausalLM/layers/ernie_moe_layer.h
@@ -0,0 +1,154 @@
+// SPDX-License-Identifier: Apache-2.0
+/**
+ *
+ * @file ernie_moe_layer.h
+ * @brief ernie 4.5 moe layer header
+ * @date 02 December 2025
+ * @see https://github.com/nnstreamer/nntrainer
+ * @author Donghak Park
+ * @bug No known bugs except for NYI items
+ */
+
+#ifndef NNTRAINER_ERNIE_MOE_LAYER_H
+#define NNTRAINER_ERNIE_MOE_LAYER_H
+#ifdef __cplusplus
+
+#include
+#include
+#include
+#include
+#include
+
+namespace causallm {
+
+class ErnieMoELayer : public nntrainer::LayerImpl {
+public:
+  ErnieMoELayer();
+
+  /**
+   * @brief Destructor of Mixture of Expert Layer
+   */
+  ~ErnieMoELayer() = default;
+
+  /**
+   * @brief Move constructor.
+   * @param[in] ErnieMoELayer &&
+   */
+  ErnieMoELayer(ErnieMoELayer &&rhs) noexcept = default;
+
+  /**
+   * @brief Move assignment operator.
+   * @param[in] rhs ErnieMoELayer to be moved.
+   */
+  ErnieMoELayer &operator=(ErnieMoELayer &&rhs) = default;
+
+  /**
+   * @copydoc Layer::finalize(InitLayerContext &context)
+   */
+  void finalize(nntrainer::InitLayerContext &context) override;
+
+  /**
+   * @copydoc Layer::forwarding(RunLayerContext &context, bool training)
+   */
+  void forwarding(nntrainer::RunLayerContext &context, bool training) override;
+
+  /**
+   * @copydoc Layer::incremental_forwarding(RunLayerContext &context,
+   * unsigned int from, unsigned int to, bool training)
+   */
+  void incremental_forwarding(nntrainer::RunLayerContext &context,
+                              unsigned int from, unsigned int to,
+                              bool training) override;
+
+  /**
+   * @copydoc Layer::calcDerivative(RunLayerContext &context)
+   */
+  void calcDerivative(nntrainer::RunLayerContext &context) override;
+
+  /**
+   * @copydoc Layer::calcGradient(RunLayerContext &context)
+   */
+  void calcGradient(nntrainer::RunLayerContext &context) override;
+
+  /**
+   * @copydoc Layer::setProperty(const std::vector<std::string> &values)
+   */
+  void setProperty(const std::vector<std::string> &values) override;
+
+  /**
+   * @copydoc Layer::exportTo(Exporter &exporter, const ml::train::ExportMethods
+   * &methods)
+   */
+  void exportTo(nntrainer::Exporter &exporter,
+                const ml::train::ExportMethods &method) const override;
+
+  /**
+   * @copydoc Layer::getType()
+   */
+  const std::string getType() const override { return ErnieMoELayer::type; };
+
+  /**
+   * @brief Layer::supportBackwarding()
+   */
+  bool supportBackwarding() const override { return false; }
+
+  /**
+   * @brief Update tensors by input dimensions
+   * @param context RunLayerContext
+   * @param input_dimensions Input dimensions
+   */
+  WIN_EXPORT void updateTensorsByInputDimensions(
+    nntrainer::RunLayerContext &context,
+    std::vector<ml::train::TensorDim> input_dimensions) override;
+
+  static constexpr const char *type = "ernie_moe"; /**< type of the layer */
+
+private:
+  unsigned int num_experts;        /**< number of experts */
+  unsigned int num_shared_experts; /**< number of shared experts */
+  unsigned int topk;               /**< number of experts per token, i.e., topk */
+  nntrainer::ActiFunc acti_func;   /**< activation function for the expert */
+  std::tuple<props::NumExperts, props::NumExpertsPerToken,
+             nntrainer::props::Unit, props::MoEActivation,
+             props::NumSharedExperts, props::MoENormMin>
+    moe_props;
+
+  // weight indices
+  std::vector<unsigned int> expert_gate_proj_indices;
+  std::vector<unsigned int> expert_up_proj_indices;
+  std::vector<unsigned int> expert_down_proj_indices;
+  unsigned int shared_gate_proj_idx;
+  unsigned int shared_up_proj_idx;
+  unsigned int shared_down_proj_idx;
+
+  std::list<int> loaded_expert_deque;
+  std::unordered_map<int, std::list<int>::iterator> iteration_map;
+  std::unordered_map<int, float> expert_predict_scores;
+  std::vector<bool> need_load;
+  std::mutex cache_mutex;
+
+  unsigned int gate_idx;
+  unsigned int e_score_correction_bias_idx;
+
+  unsigned int router_logits_idx;
+  unsigned int expert_mask_idx;
+
+  /**
+   * @brief Compute expert forward pass
+   * @param input Input tensor
+   * @param output Output tensor
+   * @param token_assignments Token assignments
+   * @param gate_proj Gate projection weight
+   * @param up_proj Up projection weight
+   * @param down_proj Down projection weight
+   * @param hidden_size Hidden size
+   */
+  inline void compute_expert_forward(
+    const nntrainer::Tensor &input, nntrainer::Tensor &output,
+    const std::vector<std::pair<unsigned int, float>> &token_assignments,
+    const nntrainer::Tensor &gate_proj, const nntrainer::Tensor &up_proj,
+    const nntrainer::Tensor &down_proj, unsigned int hidden_size);
+};
+} // namespace causallm
+
+#endif /* __cplusplus */
+#endif /* NNTRAINER_ERNIE_MOE_LAYER_H */
\ No newline at end of file
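The routing in incremental_forwarding is easy to lose in the tensor plumbing. Condensed to scalar C++ under the same conventions — softmax first, e_score_correction_bias only biases expert *selection*, and the unbiased probabilities are renormalized over the selected top-k with moe_norm_min as a division floor. Illustrative sketch, not the patch's code:

    #include <algorithm>
    #include <cmath>
    #include <numeric>
    #include <utility>
    #include <vector>

    std::vector<std::pair<int, float>>
    route_token(std::vector<float> logits, const std::vector<float> &bias,
                int topk, float moe_norm_min = 1e-12f) {
      // softmax over all experts
      float mx = *std::max_element(logits.begin(), logits.end());
      float sum = 0.0f;
      for (auto &v : logits) { v = std::exp(v - mx); sum += v; }
      for (auto &v : logits) v /= sum;

      // pick top-k by biased score (bias = e_score_correction_bias)
      std::vector<int> idx(logits.size());
      std::iota(idx.begin(), idx.end(), 0);
      std::partial_sort(idx.begin(), idx.begin() + topk, idx.end(),
                        [&](int a, int b) {
                          return logits[a] + bias[a] > logits[b] + bias[b];
                        });

      // renormalize the unbiased probabilities of the selected experts
      float topk_sum = 0.0f;
      for (int k = 0; k < topk; ++k) topk_sum += logits[idx[k]];
      std::vector<std::pair<int, float>> out;
      for (int k = 0; k < topk; ++k)
        out.emplace_back(idx[k],
                         logits[idx[k]] / std::max(topk_sum, moe_norm_min));
      return out; // (expert index, mixing weight) pairs
    }
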
ernie to main & meson build Add ernie model & Layer to main, meson build **Self evaluation:** 1. Build test: [X]Passed [ ]Failed [ ]Skipped 2. Run test: [X]Passed [ ]Failed [ ]Skipped Signed-off-by: Donghak PARK --- Applications/CausalLM/layers/meson.build | 16 +++++++++++++++- Applications/CausalLM/main.cpp | 10 +++++++++- Applications/CausalLM/meson.build | 2 ++ 3 files changed, 26 insertions(+), 2 deletions(-) diff --git a/Applications/CausalLM/layers/meson.build b/Applications/CausalLM/layers/meson.build index 58f462cd52..ce7becb4ad 100644 --- a/Applications/CausalLM/layers/meson.build +++ b/Applications/CausalLM/layers/meson.build @@ -13,7 +13,7 @@ causallm_cached_slim_moe_layer_src_abs = [meson.current_source_dir() / 'qwen_moe causallm_qkv_layer_src_abs = [meson.current_source_dir() / 'qkv_layer.cpp'] causallm_gptoss_moe_layer_src_abs = [meson.current_source_dir() / 'gpt_oss_moe_layer.cpp'] causallm_gptoss_moe_layer_cached_src_abs = [meson.current_source_dir() / 'gpt_oss_moe_layer_cached.cpp'] - +causallm_ernie_moe_layer_src_abs = [meson.current_source_dir() / 'ernie_moe_layer.cpp'] openmp_dep = dependency('openmp') @@ -175,3 +175,17 @@ causallm_cached_slim_gpt_oss_moe_layer_dep = declare_dependency( link_with: causallm_cached_slim_gptoss_moe_layer, include_directories: causallm_layer_inc ) + +causallm_ernie_moe_layer = shared_library( + 'ernie_moe_layer', + causallm_ernie_moe_layer_src_abs, + include_directories: causallm_layer_inc, + dependencies: [nntrainer_dep, nntrainer_ccapi_dep], + install: true, + install_dir: application_install_dir +) + +causallm_ernie_moe_layer_dep = declare_dependency( + link_with: causallm_ernie_moe_layer, + include_directories: causallm_layer_inc +) \ No newline at end of file diff --git a/Applications/CausalLM/main.cpp b/Applications/CausalLM/main.cpp index 5a12735f33..377d0026f9 100644 --- a/Applications/CausalLM/main.cpp +++ b/Applications/CausalLM/main.cpp @@ -30,6 +30,7 @@ #include #include "causal_lm.h" +#include "ernie_causallm.h" #include "gptoss_cached_slim_causallm.h" #include "gptoss_causallm.h" #include "nntr_qwen3_causallm.h" @@ -38,10 +39,10 @@ #include "qwen3_causallm.h" #include "qwen3_moe_causallm.h" #include "qwen3_slim_moe_causallm.h" -#include #include #include +#include #include using json = nlohmann::json; @@ -155,6 +156,13 @@ int main(int argc, char *argv[]) { cfg, generation_cfg, nntr_cfg); }); + causallm::Factory::Instance().registerModel( + "Ernie4_5_MoeForCausalLM", + [](json cfg, json generation_cfg, json nntr_cfg) { + return std::make_unique( + cfg, generation_cfg, nntr_cfg); + }); + // Validate arguments if (argc < 2) { std::cerr << "Usage: " << argv[0] << " [input_prompt]\n" diff --git a/Applications/CausalLM/meson.build b/Applications/CausalLM/meson.build index 4ec2ee26ac..e364b90f4a 100644 --- a/Applications/CausalLM/meson.build +++ b/Applications/CausalLM/meson.build @@ -13,6 +13,7 @@ causallm_src += [ meson.current_source_dir() / 'nntr_qwen3_moe_causallm.cpp', meson.current_source_dir() / 'gptoss_causallm.cpp', meson.current_source_dir() / 'gptoss_cached_slim_causallm.cpp', + meson.current_source_dir() / 'ernie_causallm.cpp', ] executable_src = [ @@ -39,6 +40,7 @@ causallm_layer_dependencies = [ causallm_qkv_layer_dep, casuallm_gptoss_moe_layer_dep, causallm_cached_slim_gpt_oss_moe_layer_dep, + causallm_ernie_moe_layer_dep, ] if (get_option('platform') == 'windows') and (build_machine.system() == 'windows') From 4e8a12dde1c1919d46e424b6b6178acdfcb5c7de Mon Sep 17 00:00:00 2001 From: Donghak PARK Date: Tue, 2 Dec 2025 
From 4e8a12dde1c1919d46e424b6b6178acdfcb5c7de Mon Sep 17 00:00:00 2001
From: Donghak PARK
Date: Tue, 2 Dec 2025 15:35:31 +0900
Subject: [PATCH 5/6] [CausalLM] Implement ERNIE's GLM Style RoPE

Implement GLM Style RoPE at MHA CORE

**Self evaluation:**
1. Build test: [X]Passed [ ]Failed [ ]Skipped
2. Run test: [X]Passed [ ]Failed [ ]Skipped

Signed-off-by: Donghak PARK
---
 Applications/CausalLM/ernie_causallm.cpp  |   2 +-
 Applications/CausalLM/ernie_causallm.h    |   8 +-
 .../CausalLM/layers/ernie_moe_layer.cpp   |   2 +-
 .../CausalLM/layers/ernie_moe_layer.h     |   9 +-
 Applications/CausalLM/layers/meson.build  |   2 +-
 Applications/CausalLM/layers/mha_core.cpp | 164 ++++++++++++++++++
 Applications/CausalLM/layers/mha_core.h   |  21 +++
 Applications/CausalLM/meson.build         |   2 +-
 8 files changed, 203 insertions(+), 7 deletions(-)

diff --git a/Applications/CausalLM/ernie_causallm.cpp b/Applications/CausalLM/ernie_causallm.cpp
index 84fcb62165..4eccd65904 100644
--- a/Applications/CausalLM/ernie_causallm.cpp
+++ b/Applications/CausalLM/ernie_causallm.cpp
@@ -153,4 +153,4 @@ void Ernie4_5_MoeForCausalLM::registerCustomLayers() {
   }
 }
 
-} // namespace causallm
\ No newline at end of file
+} // namespace causallm
diff --git a/Applications/CausalLM/ernie_causallm.h b/Applications/CausalLM/ernie_causallm.h
index f54109c364..6ad9653b91 100644
--- a/Applications/CausalLM/ernie_causallm.h
+++ b/Applications/CausalLM/ernie_causallm.h
@@ -15,6 +15,10 @@
 
 namespace causallm {
 
+/**
+ * @class Ernie4_5_MoeForCausalLM
+ * @brief ERNIE 4.5 MoE model for CausalLM
+ */
 class Ernie4_5_MoeForCausalLM : public CausalLM {
 public:
   static constexpr const char *architecture = "Ernie4_5_MoeForCausalLM";
@@ -77,8 +81,8 @@ class Ernie4_5_MoeForCausalLM : public CausalLM {
   unsigned int MOE_INTERMEDIATE_SIZE; /**< MoE intermediate size */
   float MOE_NORM_MIN; /**< MoE normalization minimum */
 
-  std::vector<std::string> LAYER_TYPES; /**< Layer types */
-  float ATTENTION_ROPE_SCALING_FACTOR; /**< Attention RoPE scaling factor */
+  std::vector<std::string> LAYER_TYPES;  /**< Layer types */
+  float ATTENTION_ROPE_SCALING_FACTOR;   /**< Attention RoPE scaling factor */
 };
 
 } // namespace causallm
diff --git a/Applications/CausalLM/layers/ernie_moe_layer.cpp b/Applications/CausalLM/layers/ernie_moe_layer.cpp
index e024c1e87e..44f8b62746 100644
--- a/Applications/CausalLM/layers/ernie_moe_layer.cpp
+++ b/Applications/CausalLM/layers/ernie_moe_layer.cpp
@@ -517,4 +517,4 @@ void ErnieMoELayer::updateTensorsByInputDimensions(
   context.updateOutput(SINGLE_INOUT_IDX, output_dim);
 }
 
-} // namespace causallm
\ No newline at end of file
+} // namespace causallm
diff --git a/Applications/CausalLM/layers/ernie_moe_layer.h b/Applications/CausalLM/layers/ernie_moe_layer.h
index 6d2a076f3d..796dd84960 100644
--- a/Applications/CausalLM/layers/ernie_moe_layer.h
+++ b/Applications/CausalLM/layers/ernie_moe_layer.h
@@ -21,8 +21,15 @@
 
 namespace causallm {
 
+/**
+ * @class ErnieMoELayer
+ * @brief Mixture of Expert Layer for ERNIE 4.5
+ */
 class ErnieMoELayer : public nntrainer::LayerImpl {
 public:
+  /**
+   * @brief Constructor of Mixture of Expert Layer
+   */
   ErnieMoELayer();
 
   /**
@@ -151,4 +158,4 @@ class ErnieMoELayer : public nntrainer::LayerImpl {
 } // namespace causallm
 
 #endif /* __cplusplus */
-#endif /* NNTRAINER_ERNIE_MOE_LAYER_H */
\ No newline at end of file
+#endif /* NNTRAINER_ERNIE_MOE_LAYER_H */
diff --git a/Applications/CausalLM/layers/meson.build b/Applications/CausalLM/layers/meson.build
index ce7becb4ad..46dedb838c 100644
--- a/Applications/CausalLM/layers/meson.build
+++ b/Applications/CausalLM/layers/meson.build
@@ -188,4 +188,4 @@
 causallm_ernie_moe_layer_dep = declare_dependency(
   link_with: causallm_ernie_moe_layer,
   include_directories: causallm_layer_inc
-)
\ No newline at end of file
+)
diff --git a/Applications/CausalLM/layers/mha_core.cpp b/Applications/CausalLM/layers/mha_core.cpp
index 9d9123c00f..85f3d29d5b 100644
--- a/Applications/CausalLM/layers/mha_core.cpp
+++ b/Applications/CausalLM/layers/mha_core.cpp
@@ -1138,4 +1138,168 @@
 
 #endif
 
+void MHACoreLayer::precompute_freqs_ernie(int head_dim, unsigned int seq_len,
+                                          float theta) {
+  // compute the freqs only when it is the first time to call this function
+  if (freqs_cos != nullptr && freqs_cos->size() == seq_len)
+    return;
+  if (rope_scaling_type == "default")
+    _compute_default_parameters(head_dim, theta);
+  else if (rope_scaling_type == "yarn")
+    _compute_yarn_parameters(head_dim, theta);
+  else
+    NNTR_THROW_IF(true, std::invalid_argument) << "Unsupported rope type!";
+
+  // cos / sin
+  unsigned int half_ = head_dim / 2;
+  auto cos = new std::vector<std::vector<float>>();
+  cos->assign(seq_len, std::vector<float>(head_dim, 0));
+  auto sin = new std::vector<std::vector<float>>();
+  sin->assign(seq_len, std::vector<float>(head_dim, 0));
+
+  // update cos / sin frequency
+  for (unsigned int i = 0; i < seq_len; ++i) {
+#ifdef USE_NEON
+    nntrainer::calc_trigonometric_vals_dup(half_, thetas.data(),
+                                           (*cos)[i].data(), (*sin)[i].data(),
+                                           i, attention_scaling);
+#else
+    for (unsigned int j = 0; j < half_; ++j) {
+      double angle = (double)i * thetas[j];
+      (*cos)[i][2 * j] = (float)(std::cos(angle) * (double)attention_scaling);
+      (*cos)[i][2 * j + 1] =
+        (float)(std::cos(angle) *
+                (double)attention_scaling); // repeated 2 times
+
+      (*sin)[i][2 * j] = (float)(std::sin(angle) * (double)attention_scaling);
+      (*sin)[i][2 * j + 1] =
+        (float)(std::sin(angle) *
+                (double)attention_scaling); // repeated 2 times
+    }
+#endif
+  }
+  freqs_cos = cos;
+  freqs_sin = sin;
+
+#ifdef ENABLE_FP16
+  // cos / sin for FP16
+  auto cos_fp16 = new std::vector<std::vector<_FP16>>();
+  cos_fp16->assign(seq_len, std::vector<_FP16>(head_dim, 0));
+  auto sin_fp16 = new std::vector<std::vector<_FP16>>();
+  sin_fp16->assign(seq_len, std::vector<_FP16>(head_dim, 0));
+  for (unsigned int i = 0; i < seq_len; ++i) {
+    for (unsigned int j = 0; j < head_dim; ++j) {
+      (*cos_fp16)[i][j] = (_FP16)(*cos)[i][j];
+      (*sin_fp16)[i][j] = (_FP16)(*sin)[i][j];
+    }
+  }
+  freqs_cos_fp16 = cos_fp16;
+  freqs_sin_fp16 = sin_fp16;
+#endif
+}
+
+void MHACoreLayer::apply_rotary_emb_tensor_ernie(nntrainer::Tensor &in,
+                                                 nntrainer::Tensor &out,
+                                                 unsigned int dim,
+                                                 unsigned int from,
+                                                 bool convert_only) {
+  unsigned int half_ = dim / 2;
+  unsigned int max_timestep =
+    std::get<props::MaxTimestep>(mha_core_props).get();
+
+  if (in.getDataType() == ml::train::TensorDim::DataType::FP32) {
+    std::vector<float> *cos_ = nullptr;
+    std::vector<float> *sin_ = nullptr;
+    for (unsigned int b = 0; b < in.batch(); b++) {
+      for (unsigned int c = 0; c < in.channel(); c++) {
+        for (unsigned int h = 0; h < in.height(); h++) {
+          if (from < max_timestep) {
+            if (from + h >= freqs_cos->size()) {
+              throw std::runtime_error(
+                "RoPE index out of bounds: " + std::to_string(from + h) +
+                " >= " + std::to_string(freqs_cos->size()));
+            }
+            cos_ = &(*freqs_cos)[from + h];
+            sin_ = &(*freqs_sin)[from + h];
+          }
+          float *in_ptr = in.getData() +
+                          b * in.channel() * in.height() * in.width() +
+                          c * in.height() * in.width() + h * in.width();
+
+          if (out.getDataType() == ml::train::TensorDim::DataType::FP32) {
+            float *out_ptr = out.getData() +
+                             b * out.channel() * out.height() * out.width() +
+                             c * out.height() * out.width() + h * out.width();
+            for (unsigned int w = 0; w < in.width(); w += dim) {
+              for (unsigned int i = 0; i < dim; i += 2) {
+                float in0 = in_ptr[w + i];
+                float in1 = in_ptr[w + i + 1];
+                float c = (*cos_)[i];
+                float s = (*sin_)[i];
+                out_ptr[w + i] = in0 * c - in1 * s;
+                out_ptr[w + i + 1] = in1 * c + in0 * s;
+              }
+            }
+          } else if (out.getDataType() ==
+                       ml::train::TensorDim::DataType::UINT16 ||
+                     out.getDataType() ==
+                       ml::train::TensorDim::DataType::FP16) {
+            uint16_t *out_ptr = out.getData<uint16_t>() +
+                                b * out.channel() * out.height() * out.width() +
+                                c * out.height() * out.width() +
+                                h * out.width();
+            for (unsigned int w = 0; w < in.width(); w += dim) {
+              for (unsigned int i = 0; i < dim; i += 2) {
+                float in0 = in_ptr[w + i];
+                float in1 = in_ptr[w + i + 1];
+                float c = (*cos_)[i];
+                float s = (*sin_)[i];
+                float out0 = in0 * c - in1 * s;
+                float out1 = in1 * c + in0 * s;
+                out_ptr[w + i] = nntrainer::compute_fp32_to_fp16(out0);
+                out_ptr[w + i + 1] = nntrainer::compute_fp32_to_fp16(out1);
+              }
+            }
+          }
+        }
+      }
+    }
+  } else if (in.getDataType() == ml::train::TensorDim::DataType::FP16) {
+#ifdef ENABLE_FP16
+    std::vector<_FP16> *cos_ = nullptr;
+    std::vector<_FP16> *sin_ = nullptr;
+    for (unsigned int b = 0; b < in.batch(); b++) {
+      for (unsigned int c = 0; c < in.channel(); c++) {
+        for (unsigned int h = 0; h < in.height(); h++) {
+          if (from < max_timestep) {
+            if (from + h >= freqs_cos_fp16->size()) {
+              throw std::runtime_error(
+                "RoPE index out of bounds (FP16): " + std::to_string(from + h) +
+                " >= " + std::to_string(freqs_cos_fp16->size()));
+            }
+            cos_ = &(*freqs_cos_fp16)[from + h];
+            sin_ = &(*freqs_sin_fp16)[from + h];
+          }
+          _FP16 *in_ptr = in.getData<_FP16>() +
+                          b * in.channel() * in.height() * in.width() +
+                          c * in.height() * in.width() + h * in.width();
+          _FP16 *out_ptr = out.getData<_FP16>() +
+                           b * out.channel() * out.height() * out.width() +
+                           c * out.height() * out.width() + h * out.width();
+
+          for (unsigned int w = 0; w < in.width(); w += dim) {
+            for (unsigned int i = 0; i < dim; i += 2) {
+              float in0 = (float)in_ptr[w + i];
+              float in1 = (float)in_ptr[w + i + 1];
+              float c = (float)(*cos_)[i];
+              float s = (float)(*sin_)[i];
+              out_ptr[w + i] = (_FP16)(in0 * c - in1 * s);
+              out_ptr[w + i + 1] = (_FP16)(in1 * c + in0 * s);
+            }
+          }
+        }
+      }
+    }
+#else
+    NNTR_THROW_IF(true, std::invalid_argument) << "enable-fp16 is not set!";
+#endif
+  }
+}
+
 } // namespace causallm
diff --git a/Applications/CausalLM/layers/mha_core.h b/Applications/CausalLM/layers/mha_core.h
index e9225c259c..3770899fde 100644
--- a/Applications/CausalLM/layers/mha_core.h
+++ b/Applications/CausalLM/layers/mha_core.h
@@ -422,6 +422,27 @@ WIN_EXPORT class MHACoreLayer : public nntrainer::LayerImpl {
    */
   void calcCommonDerivative(nntrainer::RunLayerContext &context);
 
+  /**
+   * @brief pre_compute frequencies for Rotary Embedding for ERNIE.
+   * @note it is expected to be called only once at the finalize.
+   * @param[in] head_dim dimension of head
+   * @param[in] seq_len sequence length
+   * @param[in] theta base of theta (default = 10000)
+   */
+  void precompute_freqs_ernie(int head_dim, unsigned int seq_len, float theta);
+
+  /**
+   * @brief apply rotary embedding for ERNIE
+   * @param[in] in input tensor
+   * @param[out] out output tensor
+   * @param[in] dim hidden dim size
+   * @param[in] from sequence order
+   * @param[in] convert_only - conversion only
+   */
+  void apply_rotary_emb_tensor_ernie(nntrainer::Tensor &in,
+                                     nntrainer::Tensor &out, unsigned int dim,
+                                     unsigned int from, bool convert_only);
+
   size_t calc_attn_index(size_t i);
 }; // end of class MHACoreLayer
 
diff --git a/Applications/CausalLM/meson.build b/Applications/CausalLM/meson.build
index e364b90f4a..5a2239fb0e 100644
--- a/Applications/CausalLM/meson.build
+++ b/Applications/CausalLM/meson.build
@@ -83,4 +83,4 @@ e = executable('nntr_causallm',
   executable_src,
   include_directories: causallm_inc,
   dependencies: [nntrainer_dep, nntrainer_ccapi_dep, causallm_layer_dependencies, causallm_dep],
-)
\ No newline at end of file
+)
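The kernel above rotates adjacent pairs (x[2j], x[2j+1]), which is why the precomputed cos/sin tables store each value twice and the inner loop reads (*cos_)[i] for both elements of a pair. The same math on a single head vector, assuming the default theta^(-2j/d) frequency schedule — illustrative only, not the patch's code:

    #include <cmath>
    #include <cstddef>
    #include <vector>

    void rope_pairwise(std::vector<float> &x, std::size_t pos,
                       float theta = 10000.0f) {
      const std::size_t d = x.size(); // head_dim, assumed even
      for (std::size_t j = 0; 2 * j < d; ++j) {
        const float freq =
          static_cast<float>(std::pow(theta, -2.0f * j / (float)d));
        const float c = std::cos(pos * freq);
        const float s = std::sin(pos * freq);
        const float x0 = x[2 * j], x1 = x[2 * j + 1];
        x[2 * j] = x0 * c - x1 * s;     // matches out_ptr[w + i]
        x[2 * j + 1] = x1 * c + x0 * s; // matches out_ptr[w + i + 1]
      }
    }
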
From 40ecec908e0e8754e294689f1135ccc6e9c3204f Mon Sep 17 00:00:00 2001
From: Donghak PARK
Date: Thu, 4 Dec 2025 13:55:45 +0900
Subject: [PATCH 6/6] [CausalLM] Avoid race condition on expert eviction

Avoid a race condition when evicting experts: the cache size is now
checked under the same lock that guards the expert deque.

**Self evaluation:**
1. Build test: [X]Passed [ ]Failed [ ]Skipped
2. Run test: [X]Passed [ ]Failed [ ]Skipped

Signed-off-by: Donghak PARK
---
 Applications/CausalLM/layers/ernie_moe_layer.cpp          | 6 +++++-
 Applications/CausalLM/layers/gpt_oss_moe_layer_cached.cpp | 5 ++++-
 Applications/CausalLM/layers/qwen_moe_layer_cached.cpp    | 4 +++-
 3 files changed, 12 insertions(+), 3 deletions(-)

diff --git a/Applications/CausalLM/layers/ernie_moe_layer.cpp b/Applications/CausalLM/layers/ernie_moe_layer.cpp
index 44f8b62746..5a6180b436 100644
--- a/Applications/CausalLM/layers/ernie_moe_layer.cpp
+++ b/Applications/CausalLM/layers/ernie_moe_layer.cpp
@@ -458,10 +458,14 @@ void ErnieMoELayer::incremental_forwarding(nntrainer::RunLayerContext &context,
 
 // Evict experts
 #pragma omp parallel
-    while (loaded_expert_deque.size() > 16) {
+    while (true) {
       int target_idx;
       {
         std::lock_guard<std::mutex> lock(cache_mutex);
+
+        if (loaded_expert_deque.size() <= 16)
+          break;
+
         target_idx = loaded_expert_deque.front();
         loaded_expert_deque.pop_front();
         iteration_map.erase(target_idx);
diff --git a/Applications/CausalLM/layers/gpt_oss_moe_layer_cached.cpp b/Applications/CausalLM/layers/gpt_oss_moe_layer_cached.cpp
index 52c90bed02..679ebb2148 100644
--- a/Applications/CausalLM/layers/gpt_oss_moe_layer_cached.cpp
+++ b/Applications/CausalLM/layers/gpt_oss_moe_layer_cached.cpp
@@ -378,10 +378,13 @@ void CachedSlimGptOssMoELayer::incremental_forwarding(
 
 // Evict experts
 #pragma omp parallel
-  while (loaded_expert_deque.size() > 16) {
+  while (true) {
    int target_idx;
    {
      std::lock_guard<std::mutex> lock(cache_mutex);
+      if (loaded_expert_deque.size() <= 16)
+        break;
+
      target_idx = loaded_expert_deque.front();
      loaded_expert_deque.pop_front();
      iteration_map.erase(target_idx);
diff --git a/Applications/CausalLM/layers/qwen_moe_layer_cached.cpp b/Applications/CausalLM/layers/qwen_moe_layer_cached.cpp
index 2a7e415b25..2b75e1823a 100644
--- a/Applications/CausalLM/layers/qwen_moe_layer_cached.cpp
+++ b/Applications/CausalLM/layers/qwen_moe_layer_cached.cpp
@@ -438,10 +438,12 @@ void CachedSlimMoELayer::incremental_forwarding(
 
 // Evict experts
 #pragma omp parallel
-  while (loaded_expert_deque.size() > 32) {
+  while (true) {
    int target_idx;
    {
      std::lock_guard<std::mutex> lock(cache_mutex);
+      if (loaded_expert_deque.size() <= 32)
+        break;
      target_idx = loaded_expert_deque.front();
      loaded_expert_deque.pop_front();
      iteration_map.erase(target_idx);
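The pattern this series settles on: every check and mutation of the shared LRU deque happens inside one critical section, while the expensive deactivate() runs outside it. Note the loop must stop once the cache is back within capacity, i.e. the check under the lock is size() <= capacity. A self-contained sketch, with plain types standing in for the nntrainer ones:

    #include <cstddef>
    #include <deque>
    #include <mutex>

    struct ExpertCache {
      std::deque<int> loaded; // LRU order: front = oldest expert
      std::mutex cache_mutex;

      template <typename Deactivate>
      void evict_down_to(std::size_t capacity, Deactivate deactivate) {
        while (true) {
          int victim;
          {
            std::lock_guard<std::mutex> lock(cache_mutex);
            if (loaded.size() <= capacity) // checked under the lock
              break;
            victim = loaded.front();
            loaded.pop_front();
          }
          deactivate(victim); // slow work outside the critical section
        }
      }
    };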