156 changes: 156 additions & 0 deletions Applications/CausalLM/ernie_causallm.cpp
@@ -0,0 +1,156 @@
// SPDX-License-Identifier: Apache-2.0
/**
*
* @file ernie_causallm.cpp
* @brief ERNIE 4.5 CausalLM implementation
* @date 02 December 2025
* @see https://github.com/nnstreamer/nntrainer
* @author Donghak Park <[email protected]>
* @bug No known bugs except for NYI items
*/

#include <app_context.h>
#include <engine.h>
#include <ernie_causallm.h>
#include <ernie_moe_layer.h>
#include <llm_util.hpp>

namespace causallm {

std::vector<LayerHandle>
Ernie4_5_MoeForCausalLM::createMlp(const int layer_id, int dim, int hidden_dim,
std::string input_name) {
std::vector<LayerHandle> layers;
if (layer_id == 0) {
int ffn_hidden_dim = INTERMEDIATE_SIZE; // ERNIE's first layer is dense (no MoE)

layers.push_back(createLayer(
"fully_connected",
{withKey("name", "layer" + std::to_string(layer_id) + "_ffn_up"),
withKey("unit", ffn_hidden_dim), withKey("disable_bias", "true"),
withKey("input_layers", input_name),
withKey("weight_initializer", "ones")}));

layers.push_back(createLayer(
"fully_connected",
{withKey("name", "layer" + std::to_string(layer_id) + "_ffn_gate"),
withKey("unit", ffn_hidden_dim), withKey("disable_bias", "true"),
withKey("input_layers", input_name),
withKey("weight_initializer", "ones")}));

layers.push_back(createLayer(
"swiglu",
{withKey("name", "layer" + std::to_string(layer_id) + "_ffn_swiglu"),
withKey("input_layers", "layer" + std::to_string(layer_id) + "_ffn_up," +
"layer" + std::to_string(layer_id) +
"_ffn_gate")}));

layers.push_back(createLayer(
"fully_connected",
{withKey("name", "layer" + std::to_string(layer_id) + "_ffn_down"),
withKey("unit", dim), withKey("disable_bias", "true"),
withKey("input_layers",
"layer" + std::to_string(layer_id) + "_ffn_swiglu"),
withKey("weight_initializer", "ones")}));

} else {
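// Every layer after the first uses the custom ErnieMoELayer: a router
// selects num_experts_per_token of num_experts experts for each token,
// and num_shared_experts experts are additionally applied to every token.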
layers.push_back(createLayer(
"ernie_moe",
{withKey("name", "layer" + std::to_string(layer_id) + "_ffn_down"),
withKey("input_layers", input_name),
withKey("unit", MOE_INTERMEDIATE_SIZE),
withKey("num_experts", NUM_EXPERTS),
withKey("num_shared_experts", NUM_SHARED_EXPERTS),
withKey("num_experts_per_token", NUM_EXPERTS_PER_TOK),
withKey("moe_norm_min", std::to_string(MOE_NORM_MIN)),
withKey("moe_activation", "swish")}));
}
return layers;
}

std::vector<LayerHandle> Ernie4_5_MoeForCausalLM::createAttention(
const int layer_id, int seq_len, int n_heads, int head_dim,
std::string query_name, std::string key_name, std::string value_name) {

std::vector<LayerHandle> layers;
auto Q = "layer" + std::to_string(layer_id) + "_wq";
auto K = "layer" + std::to_string(layer_id) + "_wk";
auto V = "layer" + std::to_string(layer_id) + "_wv";
auto A = "layer" + std::to_string(layer_id) + "_attention";
auto O = "layer" + std::to_string(layer_id) + "_attention_out";
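
// Grouped-query attention: the K and V projections produce only
// n_heads / GQA_SIZE heads, while Q keeps the full n_heads; mha_core
// then shares each KV head across a group of query heads.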

// V layer
std::vector<std::string> v_params = {
withKey("name", V), withKey("unit", head_dim * n_heads / GQA_SIZE),
withKey("disable_bias", "true"), withKey("input_layers", value_name),
withKey("weight_initializer", "ones")};
layers.push_back(createLayer("fully_connected", v_params));

// K layer
std::vector<std::string> k_params = {
withKey("name", K), withKey("unit", head_dim * n_heads / GQA_SIZE),
withKey("disable_bias", "true"), withKey("input_layers", key_name),
withKey("weight_initializer", "ones")};
layers.push_back(createLayer("fully_connected", k_params));

// Q layer
std::vector<std::string> q_params = {
withKey("name", Q), withKey("unit", head_dim * n_heads),
withKey("disable_bias", "true"), withKey("input_layers", query_name),
withKey("weight_initializer", "ones")};
layers.push_back(createLayer("fully_connected", q_params));

// Attention core layer
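// max_timestep bounds the KV cache: the prompt (INIT_SEQ_LEN) plus all
// generated tokens (NUM_TO_GENERATE) must fit.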
std::vector<std::string> a_params = {
withKey("name", A),
withKey("num_heads", n_heads),
withKey("num_heads_kv", n_heads / GQA_SIZE),
withKey("max_timestep", std::to_string(INIT_SEQ_LEN + NUM_TO_GENERATE)),
withKey("sliding_window", SLIDING_WINDOW),
withKey("rope_theta", ROPE_THETA),
withKey("max_position_embeddings", MAX_POSITION_EMBEDDINGS),
withKey("max_new_tokens", std::to_string(NUM_TO_GENERATE)),
withKey("input_layers", {Q, K, V})};
layers.push_back(createLayer("mha_core", a_params));

// O layer
std::vector<std::string> o_params = {
withKey("name", O), withKey("unit", DIM), withKey("disable_bias", "true"),
withKey("input_layers", A), withKey("weight_initializer", "ones")};
layers.push_back(createLayer("fully_connected", o_params));

return layers;
}

void Ernie4_5_MoeForCausalLM::setupParameters(json &cfg, json &generation_cfg,
json &nntr_cfg) {

try {
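// These keys mirror ERNIE 4.5's HuggingFace config.json fields; a missing
// or mistyped key throws and is rethrown as the runtime_error below.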
NUM_EXPERTS = cfg["moe_num_experts"].get<unsigned int>();
NUM_EXPERTS_PER_TOK = cfg["num_experts_per_tok"].get<unsigned int>();
MOE_INTERMEDIATE_SIZE = cfg["moe_intermediate_size"].get<unsigned int>();
INTERMEDIATE_SIZE = cfg["intermediate_size"].get<unsigned int>();
NUM_SHARED_EXPERTS = cfg["moe_num_shared_experts"].get<unsigned int>();
MOE_NORM_MIN =
cfg.contains("moe_norm_min") ? cfg["moe_norm_min"].get<float>() : 1e-12f;

} catch (const std::exception &e) {
throw std::runtime_error(
std::string("Ernie Causallm: config parsing error: ") + e.what());
}
}

void Ernie4_5_MoeForCausalLM::registerCustomLayers() {
CausalLM::registerCustomLayers();
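
// Look up the CPU app context and register the ERNIE-specific MoE layer
// factory so the "ernie_moe" layers created in createMlp can be resolved.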
auto &ct_engine = nntrainer::Engine::Global();
auto app_context =
static_cast<nntrainer::AppContext *>(ct_engine.getRegisteredContext("cpu"));

try {
app_context->registerFactory(
nntrainer::createLayer<causallm::ErnieMoELayer>);
} catch (std::invalid_argument &e) {
std::cerr << "failed to register factory, reason: " << e.what()
<< std::endl;
}
}

} // namespace causallm
89 changes: 89 additions & 0 deletions Applications/CausalLM/ernie_causallm.h
@@ -0,0 +1,89 @@
// SPDX-License-Identifier: Apache-2.0
/**
*
* @file ernie_causallm.h
* @brief ERNIE 4.5 CausalLM header
* @date 02 December 2025
* @see https://github.com/nnstreamer/nntrainer
* @author Donghak Park <[email protected]>
* @bug No known bugs except for NYI items
*/

#ifndef NNTRAINER_ERNIE_CAUSALLM_H
#define NNTRAINER_ERNIE_CAUSALLM_H
#include <causal_lm.h>

namespace causallm {

/**
* @class Ernie4_5_MoeForCausalLM
* @brief ERNIE 4.5 Mixture-of-Experts model for causal language modeling
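*
* A minimal usage sketch (schematic: the surrounding application parses
* config.json / generation_config.json into json objects beforehand):
* @code
* json cfg, generation_cfg, nntr_cfg; // parsed model / generation / nntrainer configs
* causallm::Ernie4_5_MoeForCausalLM model(cfg, generation_cfg, nntr_cfg);
* model.registerCustomLayers(); // registers ErnieMoELayer with the cpu context
* @endcode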
*/
class Ernie4_5_MoeForCausalLM : public CausalLM {
public:
static constexpr const char *architecture = "Ernie4_5_MoeForCausalLM";
Ernie4_5_MoeForCausalLM(json &cfg, json &generation_cfg, json &nntr_cfg) :
CausalLM(cfg, generation_cfg, nntr_cfg) {
setupParameters(cfg, generation_cfg, nntr_cfg);
}

virtual ~Ernie4_5_MoeForCausalLM() = default;

/**
* @brief Create MLP layer
* @param layer_id Layer ID
* @param dim Dimension
* @param hidden_dim Hidden dimension
* @param input_name Input name
* @return std::vector<LayerHandle> Vector of layer handles
*/
std::vector<LayerHandle> createMlp(const int layer_id, int dim,
int hidden_dim,
std::string input_name) override;

/**
* @brief Create Attention layer
* @param layer_id Layer ID
* @param seq_len Sequence length
* @param n_heads Number of heads
* @param head_dim Head dimension
* @param query_name Query name
* @param key_name Key name
* @param value_name Value name
* @return std::vector<LayerHandle> Vector of layer handles
*/
std::vector<LayerHandle> createAttention(int layer_id, int seq_len,
int n_heads, int head_dim,
std::string query_name,
std::string key_name,
std::string value_name) override;
/**
* @brief Setup parameters for the model
* @param cfg Configuration json
* @param generation_cfg Generation configuration json
* @param nntr_cfg NNtrainer configuration json
*/
void setupParameters(json &cfg, json &generation_cfg,
json &nntr_cfg) override;

/**
* @brief Register custom layers
*/
void registerCustomLayers() override;

private:
unsigned int NUM_EXPERTS; /**< Number of experts */
unsigned int NUM_EXPERTS_PER_TOK; /**< Number of experts per token */
unsigned int NUM_SHARED_EXPERTS; /**< Number of shared experts */
unsigned int MOE_INTERMEDIATE_SIZE; /**< MoE intermediate size */
float MOE_NORM_MIN; /**< MoE normalization minimum */

std::vector<std::string> LAYER_TYPES; /**< Layer types */
float ATTENTION_ROPE_SCALING_FACTOR; /**< Attention RoPE scaling factor */
};

} // namespace causallm
#endif // NNTRAINER_ERNIE_CAUSALLM_H
18 changes: 18 additions & 0 deletions Applications/CausalLM/layers/causallm_common_properties.h
@@ -66,6 +66,24 @@ class NumExpertsPerToken : public nntrainer::PositiveIntegerProperty {
using prop_tag = nntrainer::uint_prop_tag; /**< property type */
};

/**
* @brief NumSharedExperts, Number of shared experts property
*/
class NumSharedExperts : public nntrainer::Property<unsigned int> {
public:
static constexpr const char *key = "num_shared_experts";
using prop_tag = nntrainer::uint_prop_tag;
};

/**
* @brief MoENormMin, Minimum value for MoE normalization
*/
class MoENormMin : public nntrainer::Property<float> {
public:
static constexpr const char *key = "moe_norm_min";
using prop_tag = nntrainer::float_prop_tag;
};

/**
* @brief unit property, unit is used to measure how many weights are there
*