llama: Attempt to add ModernBert #14014

Open · wants to merge 26 commits into master from huydt/mb

Changes from 17 commits

Commits (26)
045b1ac
llama: attempt to add modern-bert
huydt-bti Jun 4, 2025
95f49d9
Merge branch 'master' into huydt/mb
huydt-bti Jun 4, 2025
eab776e
re-format and delete unused implementations
huydt-bti Jun 4, 2025
7143840
overload set_swa_pattern for modern bert
huydt-bti Jun 4, 2025
6aa1335
modern-bert doesn't have bias
huydt-bti Jun 4, 2025
9e1179a
delete unnecessary files
huydt-bti Jun 4, 2025
fa23480
add build_attn_inp_no_cache_iswa with symmetric swa
huydt-bti Jun 5, 2025
a72cb3b
add modern-bert to llama_model::create_memory
huydt-bti Jun 5, 2025
adea1c9
fix lint
huydt-bti Jun 5, 2025
cfebb6e
access n_swa via hparams
huydt-bti Jun 5, 2025
31e87e4
revert changes in convert script
huydt-bti Jun 6, 2025
1004327
add set_vocab to modernbert convert class
huydt-bti Jun 6, 2025
81f4797
Merge branch 'master' into huydt/mb
huydt-bti Jun 6, 2025
03693fa
parmas-related fix
huydt-bti Jun 6, 2025
2f5a72f
handle mask token in modern-bert bpe
huydt-bti Jun 6, 2025
68f399e
add modern-bert to pre_type check
huydt-bti Jun 6, 2025
ad2a19a
change log warning when no mask token of modern-bert
huydt-bti Jun 6, 2025
c6b84e2
fix modern-bert swa logic
huydt-bti Jun 8, 2025
6751e69
fix modern-bert class register
huydt-bti Jun 8, 2025
8b794f9
Merge branch 'master' into huydt/mb
huydt-bti Jun 8, 2025
4d6b804
handle when no cls_b and cls_out_b
huydt-bti Jun 9, 2025
5821d6c
fix unnecessary operations
huydt-bti Jun 9, 2025
820cee1
revert incorrect change
huydt-bti Jun 9, 2025
4fc4bf6
Merge branch 'master' into huydt/mb
huydt-bti Jun 9, 2025
16b73d4
use build_ffn with LLM_FFN_GEGLU
huydt-bti Jun 9, 2025
333eeed
fix for use of models without n_swa
huydt-bti Jun 9, 2025
31 changes: 31 additions & 0 deletions convert_hf_to_gguf.py
@@ -809,6 +809,9 @@ def get_vocab_base_pre(self, tokenizer) -> str:
if chkhsh == "1431a23e583c97432bc230bff598d103ddb5a1f89960c8f1d1051aaa944d0b35":
# ref: https://huggingface.co/sapienzanlp/Minerva-7B-base-v1.0
res = "minerva-7b"
if chkhsh == "a0b64b4385f123663873756336c085744376d015ff328bb1d901598f63c44152":
# ref: https://huggingface.co/answerdotai/ModernBERT-base
res = "modern-bert"

if res is None:
logger.warning("\n")
@@ -3932,6 +3935,34 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
return super().modify_tensors(data_torch, name, bid)


@ModelBase.register("ModernBert", "ModernBertForMaskedLM", "ModernBertForSequenceClassification")
class ModernBertModel(BertModel):
model_arch = gguf.MODEL_ARCH.MODERN_BERT

def set_vocab(self):
self._set_vocab_gpt2()
self.gguf_writer.add_add_bos_token(True)
self.gguf_writer.add_add_eos_token(True)

def set_gguf_parameters(self):
super().set_gguf_parameters()
self.gguf_writer.add_sliding_window(self.hparams["local_attention"])
self.gguf_writer.add_rope_freq_base(self.hparams["global_rope_theta"])
self.gguf_writer.add_rope_freq_base_swa(self.hparams["local_rope_theta"])
self.gguf_writer.add_rope_scaling_type(gguf.RopeScalingType.NONE)
self.gguf_writer.add_vocab_size(self.hparams["vocab_size"])

def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
# These layers act as MLM head, so we don't need them
if name.startswith("decoder."):
return []

if name.startswith("model."):
name = name[6:]

return super().modify_tensors(data_torch, name, bid)


@ModelBase.register("RobertaModel", "RobertaForSequenceClassification")
class RobertaModel(BertModel):
model_arch = gguf.MODEL_ARCH.BERT
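As a quick reference for the converter change above, here is a minimal sketch of which Hugging Face `config.json` fields `ModernBertModel.set_gguf_parameters()` consumes and which writer call each one feeds. The numeric values below are hypothetical examples, not values taken from this PR or from ModernBERT-base.

```python
# Illustrative only: HF config fields read by ModernBertModel.set_gguf_parameters()
# and the gguf_writer calls they feed (the values here are made-up examples).
hf_config = {
    "local_attention":   128,       # -> gguf_writer.add_sliding_window(...)
    "global_rope_theta": 160000.0,  # -> gguf_writer.add_rope_freq_base(...)
    "local_rope_theta":  10000.0,   # -> gguf_writer.add_rope_freq_base_swa(...), the new "%s.rope.freq_base_swa" key
    "vocab_size":        50368,     # -> gguf_writer.add_vocab_size(...)
}

for field, value in hf_config.items():
    print(f"{field:18s} = {value}")
```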
1 change: 1 addition & 0 deletions convert_hf_to_gguf_update.py
@@ -128,6 +128,7 @@ class TOKENIZER_TYPE(IntEnum):
{"name": "llama4", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/meta-llama/Llama-4-Scout-17B-16E-Instruct", },
{"name": "pixtral", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/mistral-community/pixtral-12b", },
{"name": "seed-coder", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/ByteDance-Seed/Seed-Coder-8B-Base", },
{"name": "modern-bert", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/answerdotai/ModernBERT-base", },
]

# some models are known to be broken upstream, so we will skip them as exceptions
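For context on the two tokenizer changes above: convert_hf_to_gguf_update.py downloads each listed tokenizer, encodes a fixed test string, and records a hash of the resulting token IDs; get_vocab_base_pre() then matches that hash to pick the pre-tokenizer name ("modern-bert" here). A rough sketch of that mechanism, assuming the tokenizer is fetched via transformers and with chktxt standing in for the test string defined in the scripts:

```python
# Sketch of how the "chkhsh" checked in convert_hf_to_gguf.py is derived (assumed mechanism).
from hashlib import sha256
from transformers import AutoTokenizer  # assumption: tokenizer fetched via transformers

chktxt = "..."  # placeholder for the fixed multilingual test string in convert_hf_to_gguf_update.py

tokenizer = AutoTokenizer.from_pretrained("answerdotai/ModernBERT-base")
chktok = tokenizer.encode(chktxt)
chkhsh = sha256(str(chktok).encode()).hexdigest()

# With the real chktxt this should equal the hash registered above for res = "modern-bert":
# a0b64b4385f123663873756336c085744376d015ff328bb1d901598f63c44152
print(chkhsh)
```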
19 changes: 19 additions & 0 deletions gguf-py/gguf/constants.py
@@ -147,6 +147,7 @@ class Rope:
DIMENSION_COUNT = "{arch}.rope.dimension_count"
DIMENSION_SECTIONS = "{arch}.rope.dimension_sections"
FREQ_BASE = "{arch}.rope.freq_base"
FREQ_BASE_SWA = "{arch}.rope.freq_base_swa"
SCALING_TYPE = "{arch}.rope.scaling.type"
SCALING_FACTOR = "{arch}.rope.scaling.factor"
SCALING_ATTN_FACTOR = "{arch}.rope.scaling.attn_factor"
@@ -289,6 +290,7 @@ class MODEL_ARCH(IntEnum):
STARCODER = auto()
REFACT = auto()
BERT = auto()
MODERN_BERT = auto()
NOMIC_BERT = auto()
NOMIC_BERT_MOE = auto()
JINA_BERT_V2 = auto()
@@ -477,6 +479,7 @@ class MODEL_TENSOR(IntEnum):
ENC_FFN_UP = auto()
ENC_OUTPUT_NORM = auto()
CLS = auto() # classifier
CLS_NORM = auto() # classifier normalization
CLS_OUT = auto() # classifier output projection
CONV1D = auto()
CONVNEXT_DW = auto()
@@ -569,6 +572,7 @@ class MODEL_TENSOR(IntEnum):
MODEL_ARCH.STARCODER: "starcoder",
MODEL_ARCH.REFACT: "refact",
MODEL_ARCH.BERT: "bert",
MODEL_ARCH.MODERN_BERT: "modern-bert",
MODEL_ARCH.NOMIC_BERT: "nomic-bert",
MODEL_ARCH.NOMIC_BERT_MOE: "nomic-bert-moe",
MODEL_ARCH.JINA_BERT_V2: "jina-bert-v2",
@@ -757,6 +761,7 @@ class MODEL_TENSOR(IntEnum):
MODEL_TENSOR.ENC_FFN_UP: "enc.blk.{bid}.ffn_up",
MODEL_TENSOR.ENC_OUTPUT_NORM: "enc.output_norm",
MODEL_TENSOR.CLS: "cls",
MODEL_TENSOR.CLS_NORM: "cls.norm",
MODEL_TENSOR.CLS_OUT: "cls.output",
MODEL_TENSOR.CONV1D: "conv1d",
MODEL_TENSOR.CONVNEXT_DW: "convnext.{bid}.dw",
@@ -1047,6 +1052,20 @@ class MODEL_TENSOR(IntEnum):
MODEL_TENSOR.CLS,
MODEL_TENSOR.CLS_OUT,
],
MODEL_ARCH.MODERN_BERT: [
MODEL_TENSOR.TOKEN_EMBD,
MODEL_TENSOR.TOKEN_EMBD_NORM,
MODEL_TENSOR.ATTN_NORM,
MODEL_TENSOR.ATTN_QKV,
MODEL_TENSOR.ATTN_OUT,
MODEL_TENSOR.ATTN_OUT_NORM,
MODEL_TENSOR.FFN_DOWN,
MODEL_TENSOR.FFN_UP,
MODEL_TENSOR.ENC_OUTPUT_NORM,
MODEL_TENSOR.CLS,
MODEL_TENSOR.CLS_NORM,
MODEL_TENSOR.CLS_OUT,
],
MODEL_ARCH.NOMIC_BERT: [
MODEL_TENSOR.TOKEN_EMBD,
MODEL_TENSOR.TOKEN_EMBD_NORM,
3 changes: 3 additions & 0 deletions gguf-py/gguf/gguf_writer.py
@@ -813,6 +813,9 @@ def add_rope_dimension_sections(self, dims: Sequence[int]) -> None:
def add_rope_freq_base(self, value: float) -> None:
self.add_float32(Keys.Rope.FREQ_BASE.format(arch=self.arch), value)

def add_rope_freq_base_swa(self, value: float) -> None:
self.add_float32(Keys.Rope.FREQ_BASE_SWA.format(arch=self.arch), value)

def add_rope_scaling_type(self, value: RopeScalingType) -> None:
self.add_string(Keys.Rope.SCALING_TYPE.format(arch=self.arch), value.value)

14 changes: 14 additions & 0 deletions gguf-py/gguf/tensor_mapping.py
@@ -16,6 +16,7 @@ class TensorNameMap:
"model.embed_tokens", # llama-hf nemotron olmoe olmo2 rwkv6qwen2 glm4-0414
"tok_embeddings", # llama-pth
"embeddings.word_embeddings", # bert nomic-bert
"embeddings.tok_embeddings", # modern-bert
"language_model.embedding.word_embeddings", # persimmon
"wte", # gpt2
"transformer.embd.wte", # phi2
@@ -42,6 +43,7 @@ class TensorNameMap:
MODEL_TENSOR.TOKEN_EMBD_NORM: (
"word_embeddings_layernorm", # bloom
"embeddings.LayerNorm", # bert
"embeddings.norm", # modern-bert
"emb_ln", # nomic-bert
"transformer.norm", # openelm
"rwkv.blocks.0.pre_ln", # rwkv
@@ -134,6 +136,7 @@ class TensorNameMap:
"rwkv.blocks.{bid}.ln1", # rwkv6
"model.layers.{bid}.ln1", # rwkv7
"model.layers.{bid}.input_layernorm", # llama4
"layers.{bid}.attn_norm", # modern-bert
),

# Attention norm 2
@@ -161,6 +164,7 @@ class TensorNameMap:
"model.layers.{bid}.self_attn.qkv_proj", # phi3
"encoder.layers.{bid}.self_attention.query_key_value", # chatglm
"transformer.layers.{bid}.attn.qkv_proj", # openelm
"layers.{bid}.attn.Wqkv", # modern-bert
),

# Attention query
@@ -236,6 +240,7 @@ class TensorNameMap:
"transformer.layers.{bid}.attn.out_proj", # openelm
"transformer.h.{bid}.attn.attention.out_proj", # exaone
"model.layers.{bid}.self_attn.o_proj", # llama4
"layers.{bid}.attn.Wo", # modern-bert
),

# Attention output norm
@@ -245,6 +250,7 @@ class TensorNameMap:
"encoder.layers.{bid}.norm1", # nomic-bert
"transformer.decoder_layer.{bid}.rms_norm_1", # Grok
"transformer.blocks.{bid}.norm_attn_norm.norm_2", # dbrx
"layers.{bid}.mlp_norm" # modern-bert
),

MODEL_TENSOR.ATTN_POST_NORM: (
@@ -338,6 +344,7 @@ class TensorNameMap:
"encoder.layers.{bid}.mlp.dense_h_to_4h", # chatglm
"transformer.h.{bid}.mlp.c_fc_1", # exaone
"model.layers.{bid}.feed_forward.up_proj", # llama4
"layers.{bid}.mlp.Wi" # modern-bert
),

MODEL_TENSOR.FFN_UP_EXP: (
@@ -420,6 +427,7 @@ class TensorNameMap:
"encoder.layers.{bid}.mlp.dense_4h_to_h", # chatglm
"model.layers.h.{bid}.mlp.c_proj", # exaone
"model.layers.{bid}.feed_forward.down_proj", # llama4
"layers.{bid}.mlp.Wo" # modern-bert
),

MODEL_TENSOR.FFN_DOWN_EXP: (
@@ -830,12 +838,18 @@ class TensorNameMap:
# TODO: these do not belong to block_mappings_cfg - move them to mappings_cfg
MODEL_TENSOR.ENC_OUTPUT_NORM: (
"encoder.final_layer_norm", # t5
"final_norm", # modern-bert
),

MODEL_TENSOR.CLS: (
"classifier", # jina
"classifier.dense", # roberta
"pre_classifier", # distillbert
"head.dense", # modern-bert
),

MODEL_TENSOR.CLS_NORM: (
"head.norm", # modern-bert
),

MODEL_TENSOR.CLS_OUT: (
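Pulled together, the modern-bert entries added throughout this file map the Hugging Face checkpoint names onto the GGUF tensor names defined for the new architecture (the CLS_OUT mapping is elided in the collapsed hunk above, so it is omitted here). A small summary sketch, with {bid} as the layer index placeholder:

```python
# Consolidated view of the modern-bert tensor-name mappings added in tensor_mapping.py.
modern_bert_tensor_map = {
    "embeddings.tok_embeddings": "token_embd",
    "embeddings.norm":           "token_embd_norm",
    "layers.{bid}.attn_norm":    "blk.{bid}.attn_norm",
    "layers.{bid}.attn.Wqkv":    "blk.{bid}.attn_qkv",
    "layers.{bid}.attn.Wo":      "blk.{bid}.attn_output",
    "layers.{bid}.mlp_norm":     "blk.{bid}.attn_output_norm",
    "layers.{bid}.mlp.Wi":       "blk.{bid}.ffn_up",
    "layers.{bid}.mlp.Wo":       "blk.{bid}.ffn_down",
    "final_norm":                "enc.output_norm",
    "head.dense":                "cls",
    "head.norm":                 "cls.norm",
}

for hf_name, gguf_name in modern_bert_tensor_map.items():
    print(f"{hf_name:28s} -> {gguf_name}")
```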
20 changes: 20 additions & 0 deletions src/llama-arch.cpp
@@ -18,6 +18,7 @@ static const std::map<llm_arch, const char *> LLM_ARCH_NAMES = {
{ LLM_ARCH_STARCODER, "starcoder" },
{ LLM_ARCH_REFACT, "refact" },
{ LLM_ARCH_BERT, "bert" },
{ LLM_ARCH_MODERN_BERT, "modern-bert" },
{ LLM_ARCH_NOMIC_BERT, "nomic-bert" },
{ LLM_ARCH_NOMIC_BERT_MOE, "nomic-bert-moe" },
{ LLM_ARCH_JINA_BERT_V2, "jina-bert-v2" },
@@ -148,6 +149,7 @@ static const std::map<llm_kv, const char *> LLM_KV_NAMES = {
{ LLM_KV_ROPE_DIMENSION_COUNT, "%s.rope.dimension_count" },
{ LLM_KV_ROPE_DIMENSION_SECTIONS, "%s.rope.dimension_sections" },
{ LLM_KV_ROPE_FREQ_BASE, "%s.rope.freq_base" },
{ LLM_KV_ROPE_FREQ_BASE_SWA, "%s.rope.freq_base_swa" },
{ LLM_KV_ROPE_SCALE_LINEAR, "%s.rope.scale_linear" },
{ LLM_KV_ROPE_SCALING_TYPE, "%s.rope.scaling.type" },
{ LLM_KV_ROPE_SCALING_FACTOR, "%s.rope.scaling.factor" },
@@ -462,6 +464,23 @@ static const std::map<llm_arch, std::map<llm_tensor, const char *>> LLM_TENSOR_NAMES = {
{ LLM_TENSOR_CLS_OUT, "cls.output" },
},
},
{
LLM_ARCH_MODERN_BERT,
{
{ LLM_TENSOR_TOKEN_EMBD, "token_embd" },
{ LLM_TENSOR_TOKEN_EMBD_NORM, "token_embd_norm" },
{ LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" },
{ LLM_TENSOR_ATTN_QKV, "blk.%d.attn_qkv" },
{ LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" },
{ LLM_TENSOR_ATTN_OUT_NORM, "blk.%d.attn_output_norm" },
{ LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" },
{ LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" },
{ LLM_TENSOR_ENC_OUTPUT_NORM, "enc.output_norm" },
{ LLM_TENSOR_CLS, "cls" },
{ LLM_TENSOR_CLS_NORM, "cls.norm" },
{ LLM_TENSOR_CLS_OUT, "cls.output" },
},
},
{
LLM_ARCH_NOMIC_BERT,
{
@@ -1572,6 +1591,7 @@ static const std::map<llm_tensor, llm_tensor_info> LLM_TENSOR_INFOS = {
{LLM_TENSOR_OUTPUT, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL_MAT}},
{LLM_TENSOR_CLS, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL_MAT}},
{LLM_TENSOR_CLS_OUT, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL_MAT}},
{LLM_TENSOR_CLS_NORM, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL}},
{LLM_TENSOR_OUTPUT_NORM, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL}},
{LLM_TENSOR_DEC_OUTPUT_NORM, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL}},
{LLM_TENSOR_ENC_OUTPUT_NORM, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL}},
3 changes: 3 additions & 0 deletions src/llama-arch.h
@@ -22,6 +22,7 @@ enum llm_arch {
LLM_ARCH_STARCODER,
LLM_ARCH_REFACT,
LLM_ARCH_BERT,
LLM_ARCH_MODERN_BERT,
LLM_ARCH_NOMIC_BERT,
LLM_ARCH_NOMIC_BERT_MOE,
LLM_ARCH_JINA_BERT_V2,
@@ -152,6 +153,7 @@ enum llm_kv {
LLM_KV_ROPE_DIMENSION_COUNT,
LLM_KV_ROPE_DIMENSION_SECTIONS,
LLM_KV_ROPE_FREQ_BASE,
LLM_KV_ROPE_FREQ_BASE_SWA,
LLM_KV_ROPE_SCALE_LINEAR,
LLM_KV_ROPE_SCALING_TYPE,
LLM_KV_ROPE_SCALING_FACTOR,
@@ -347,6 +349,7 @@ enum llm_tensor {
LLM_TENSOR_ENC_FFN_UP,
LLM_TENSOR_ENC_OUTPUT_NORM,
LLM_TENSOR_CLS,
LLM_TENSOR_CLS_NORM,
LLM_TENSOR_CLS_OUT,
LLM_TENSOR_CONV1D,
LLM_TENSOR_CONVNEXT_DW,
1 change: 1 addition & 0 deletions src/llama-context.cpp
@@ -2027,6 +2027,7 @@ void llama_context::opt_init(struct llama_model * model, struct llama_opt_params
llama_set_param(model->output_norm_enc, param_filter, param_filter_ud);
llama_set_param(model->cls, param_filter, param_filter_ud);
llama_set_param(model->cls_b, param_filter, param_filter_ud);
llama_set_param(model->cls_norm, param_filter, param_filter_ud);
llama_set_param(model->cls_out, param_filter, param_filter_ud);
llama_set_param(model->cls_out_b, param_filter, param_filter_ud);

80 changes: 78 additions & 2 deletions src/llama-graph.cpp
@@ -278,7 +278,61 @@ void llm_graph_input_cross_embd::set_input(const llama_ubatch * ubatch) {

void llm_graph_input_attn_no_cache::set_input(const llama_ubatch * ubatch) {
if (kq_mask) {
if (cparams.causal_attn) {
// Check if we're using sliding window attention
if (n_swa > 0) {
const int64_t n_tokens = ubatch->n_tokens;
const int64_t n_seq_tokens = ubatch->n_seq_tokens;
[Review comment — Member] This branch is actually non-causal attention + sliding window. So merge it with the existing implementation below.

const int64_t n_seqs = ubatch->n_seqs;
const int64_t n_stride = ubatch->n_tokens;
const int64_t half_n_swa = n_swa / 2;

GGML_ASSERT(ggml_backend_buffer_is_host(kq_mask->buffer));
float * data = (float *) kq_mask->data;

// Implement symmetric sliding window attention
// token i attends to tokens [i - n_swa/2, i + n_swa/2]
for (int h = 0; h < 1; ++h) {
for (int s1 = 0; s1 < n_seqs; ++s1) {
const llama_seq_id seq_id = ubatch->seq_id[s1][0];

for (int j = 0; j < n_seq_tokens; ++j) {
const int32_t tj = s1*n_seq_tokens + j;
const int64_t pos_j = ubatch->pos[tj];

for (int s0 = 0; s0 < n_seqs; ++s0) {
for (int i = 0; i < n_seq_tokens; ++i) {
const int32_t ti = s0*n_seq_tokens + i;
float f = -INFINITY;

for (int s = 0; s < ubatch->n_seq_id[s0]; ++s) {
if (ubatch->seq_id[s0][s] == seq_id) {
const int64_t pos_i = ubatch->pos[ti];
const int64_t pos_diff = pos_j - pos_i;

// Apply sliding window constraint
// [i - n_swa/2, i + n_swa/2]
if (pos_diff >= -half_n_swa && pos_diff <= half_n_swa) {
if (hparams.use_alibi) {
f = -std::abs(pos_diff);
} else {
f = 0.0f;
}
}
break;
}
}

data[h*(n_tokens*n_tokens) + tj*n_stride + ti] = f;
}
}

for (int i = n_tokens; i < n_stride; ++i) {
data[h*(n_tokens*n_tokens) + tj*n_stride + i] = -INFINITY;
}
}
}
}
} else if (cparams.causal_attn) {
const int64_t n_kv = ubatch->n_tokens;
const int64_t n_tokens = ubatch->n_tokens;
const int64_t n_seq_tokens = ubatch->n_seq_tokens;
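As the review comment above notes, this branch is non-causal attention with a symmetric window on top. A minimal Python sketch of the mask rule implemented in the hunk (same sequence and |pos_j − pos_i| ≤ n_swa/2), assuming a flat list of positions and sequence ids rather than the real llama_ubatch layout:

```python
# Minimal sketch of the symmetric sliding-window mask built above.
# Assumes one flat batch: pos[i] and seq_id[i] per token; n_swa == 0 means no window.
NEG_INF = float("-inf")

def swa_mask(pos: list[int], seq_id: list[int], n_swa: int) -> list[list[float]]:
    half = n_swa // 2
    n = len(pos)
    mask = [[NEG_INF] * n for _ in range(n)]
    for j in range(n):          # query token
        for i in range(n):      # key token
            same_seq = seq_id[i] == seq_id[j]
            in_window = abs(pos[j] - pos[i]) <= half   # token j sees [pos_j - n_swa/2, pos_j + n_swa/2]
            if same_seq and (n_swa == 0 or in_window):
                mask[j][i] = 0.0
    return mask

# Example: 6 tokens of one sequence, window of 4 -> each token sees itself +/- 2 positions.
m = swa_mask(pos=[0, 1, 2, 3, 4, 5], seq_id=[0] * 6, n_swa=4)
print([v == 0.0 for v in m[0]])  # [True, True, True, False, False, False]
```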
@@ -1188,6 +1242,22 @@ llm_graph_input_attn_no_cache * llm_graph_context::build_attn_inp_no_cache() const {
return (llm_graph_input_attn_no_cache *) res->add_input(std::move(inp));
}

llm_graph_input_attn_no_cache * llm_graph_context::build_attn_inp_no_cache_iswa() const {
// Use the sliding window size from hyperparameters
// If hparams.n_swa is 0, use a default value (128)
const int n_swa = hparams.n_swa > 0 ? hparams.n_swa : 128;

auto inp = std::make_unique<llm_graph_input_attn_no_cache>(hparams, cparams, n_swa);

// note: there is no KV cache, so the number of KV values is equal to the number of tokens in the batch
inp->kq_mask = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_tokens, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD));
ggml_set_input(inp->kq_mask);

inp->kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->kq_mask, GGML_TYPE_F16) : inp->kq_mask;

return (llm_graph_input_attn_no_cache *) res->add_input(std::move(inp));
}

ggml_tensor * llm_graph_context::build_attn(
llm_graph_input_attn_no_cache * inp,
ggml_cgraph * gf,
@@ -1522,7 +1592,8 @@ void llm_graph_context::build_pooling(
ggml_tensor * cls,
ggml_tensor * cls_b,
ggml_tensor * cls_out,
ggml_tensor * cls_out_b) const {
ggml_tensor * cls_out_b,
ggml_tensor * cls_norm) const {
if (!cparams.embeddings) {
return;
}
Expand Down Expand Up @@ -1570,6 +1641,11 @@ void llm_graph_context::build_pooling(
cur = ggml_add(ctx0, ggml_mul_mat(ctx0, cls, inp), cls_b);
cur = ggml_tanh(ctx0, cur);

if (cls_norm) {
// normalization head
cur = build_norm(cur, cls_norm, nullptr, LLM_NORM, 0);
}

// some models don't have `cls_out`, for example: https://huggingface.co/jinaai/jina-reranker-v1-tiny-en
// https://huggingface.co/jinaai/jina-reranker-v1-tiny-en/blob/cb5347e43979c3084a890e3f99491952603ae1b7/modeling_bert.py#L884-L896
if (cls_out) {
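For the LLAMA_POOLING_TYPE_RANK path touched above, the head applied to the pooled embedding is now dense + tanh, an optional weight-only layer norm (cls_norm), then the optional output projection (cls_out). A rough numpy sketch of that sequence; the shapes and the eps value are illustrative assumptions, not taken from the PR:

```python
# Rough sketch of the RANK pooling head as wired in build_pooling() above.
import numpy as np

def rank_head(pooled, cls_w, cls_b, cls_norm_w=None, cls_out_w=None, cls_out_b=None, eps=1e-5):
    cur = np.tanh(cls_w @ pooled + cls_b)                  # cls / cls_b: dense + tanh
    if cls_norm_w is not None:                             # cls_norm: layer norm, weight only, no bias
        cur = (cur - cur.mean()) / np.sqrt(cur.var() + eps) * cls_norm_w
    if cls_out_w is not None:                              # cls_out / cls_out_b: score projection
        cur = cls_out_w @ cur + (cls_out_b if cls_out_b is not None else 0.0)
    return cur

# Example: hidden size 4, single relevance score.
rng = np.random.default_rng(0)
score = rank_head(pooled=rng.standard_normal(4),
                  cls_w=rng.standard_normal((4, 4)), cls_b=np.zeros(4),
                  cls_norm_w=np.ones(4),
                  cls_out_w=rng.standard_normal((1, 4)), cls_out_b=np.zeros(1))
print(score.shape)  # (1,)
```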