@@ -302,6 +302,8 @@ enum llm_kv {
302302 LLM_KV_POOLING_TYPE,
303303 LLM_KV_LOGIT_SCALE,
304304 LLM_KV_DECODER_START_TOKEN_ID,
305+ LLM_KV_ATTN_LOGIT_SOFTCAPPING,
306+ LLM_KV_FINAL_LOGIT_SOFTCAPPING,
305307
306308 LLM_KV_ATTENTION_HEAD_COUNT,
307309 LLM_KV_ATTENTION_HEAD_COUNT_KV,
@@ -392,6 +394,8 @@ static const std::map<llm_kv, const char *> LLM_KV_NAMES = {
392394 { LLM_KV_POOLING_TYPE , "%s.pooling_type" },
393395 { LLM_KV_LOGIT_SCALE, "%s.logit_scale" },
394396 { LLM_KV_DECODER_START_TOKEN_ID, "%s.decoder_start_token_id" },
397+ { LLM_KV_ATTN_LOGIT_SOFTCAPPING, "%s.attn_logit_softcapping" },
398+ { LLM_KV_FINAL_LOGIT_SOFTCAPPING, "%s.final_logit_softcapping" },
395399
396400 { LLM_KV_ATTENTION_HEAD_COUNT, "%s.attention.head_count" },
397401 { LLM_KV_ATTENTION_HEAD_COUNT_KV, "%s.attention.head_count_kv" },
@@ -2099,6 +2103,9 @@ struct llama_hparams {
20992103 float f_norm_eps;
21002104 float f_norm_rms_eps;
21012105
2106+ float f_attn_logit_softcapping = 50.0f;
2107+ float f_final_logit_softcapping = 30.0f;
2108+
21022109 float rope_attn_factor = 1.0f;
21032110 float rope_freq_base_train;
21042111 float rope_freq_scale_train;
@@ -2115,8 +2122,9 @@ struct llama_hparams {
21152122 float f_max_alibi_bias = 0.0f;
21162123 float f_logit_scale = 0.0f;
21172124
2118- bool causal_attn = true;
2119- bool use_alibi = false;
2125+ bool causal_attn = true;
2126+ bool use_alibi = false;
2127+ bool attn_soft_cap = false;
21202128
21212129 enum llama_pooling_type pooling_type = LLAMA_POOLING_TYPE_NONE;
21222130 enum llama_rope_type rope_type = LLAMA_ROPE_TYPE_NONE;
@@ -4702,6 +4710,9 @@ static void llm_load_hparams(
47024710 case LLM_ARCH_GEMMA2:
47034711 {
47044712 ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
4713+ ml.get_key(LLM_KV_ATTN_LOGIT_SOFTCAPPING, hparams.f_attn_logit_softcapping, false);
4714+ ml.get_key(LLM_KV_FINAL_LOGIT_SOFTCAPPING, hparams.f_final_logit_softcapping, false);
4715+ hparams.attn_soft_cap = true;
47054716
47064717 switch (hparams.n_layer) {
47074718 case 42: model.type = e_model::MODEL_9B; break;
@@ -7579,6 +7590,12 @@ static struct ggml_tensor * llm_build_kqv(
75797590 kq = ggml_scale(ctx, kq, 30);
75807591 }
75817592
7593+ if (hparams.attn_soft_cap) {
7594+ kq = ggml_scale(ctx, kq, 1.0f / hparams.f_attn_logit_softcapping);
7595+ kq = ggml_tanh(ctx, kq);
7596+ kq = ggml_scale(ctx, kq, hparams.f_attn_logit_softcapping);
7597+ }
7598+
75827599 kq = ggml_soft_max_ext(ctx, kq, kq_mask, kq_scale, hparams.f_max_alibi_bias);
75837600 cb(kq, "kq_soft_max_ext", il);
75847601
@@ -11039,7 +11056,7 @@ struct llm_build_context {
1103911056 ext_factor, attn_factor, beta_fast, beta_slow);
1104011057 cb(Qcur, "Qcur", il);
1104111058
11042- Qcur = ggml_scale(ctx0, Qcur, 1.0f / sqrtf(float(n_embd_head_k )));
11059+ Qcur = ggml_scale(ctx0, Qcur, 1.0f / sqrtf(float(n_embd / n_head )));
1104311060 cb(Qcur, "Qcur_scaled", il);
1104411061
1104511062 Kcur = ggml_rope_ext(
@@ -11106,6 +11123,12 @@ struct llm_build_context {
1110611123
1110711124 // lm_head
1110811125 cur = ggml_mul_mat(ctx0, model.output, cur);
11126+
11127+ // final logit soft-capping
11128+ cur = ggml_scale(ctx0, cur, 1.0f / hparams.f_final_logit_softcapping);
11129+ cur = ggml_tanh(ctx0, cur);
11130+ cur = ggml_scale(ctx0, cur, hparams.f_final_logit_softcapping);
11131+
1110911132 cb(cur, "result_output", -1);
1111011133
1111111134 ggml_build_forward_expand(gf, cur);
@@ -17379,6 +17402,12 @@ struct llama_context * llama_new_context_with_model(
1737917402 params.flash_attn = false;
1738017403 }
1738117404
17405+ if (params.flash_attn && model->hparams.attn_soft_cap) {
17406+ LLAMA_LOG_WARN("%s: flash_attn is not compatible with attn_soft_cap - forcing off\n", __func__);
17407+ params.flash_attn = false;
17408+ }
17409+
17410+
1738217411 if (params.flash_attn && model->hparams.n_embd_head_k != model->hparams.n_embd_head_v) {
1738317412 LLAMA_LOG_WARN("%s: flash_attn requires n_embd_head_k == n_embd_head_v - forcing off\n", __func__);
1738417413 params.flash_attn = false;
0 commit comments