diff --git a/convert_hf_to_gguf.py b/convert_hf_to_gguf.py
index 4e6c0f60c0621..7839875159cd9 100755
--- a/convert_hf_to_gguf.py
+++ b/convert_hf_to_gguf.py
@@ -687,6 +687,9 @@ def get_vocab_base_pre(self, tokenizer) -> str:
         if chkhsh == "d4c8f286ea6b520b3d495c4455483cfa2302c0cfcd4be05d781b6a8a0a7cdaf1":
             # ref: https://huggingface.co/Infinigence/Megrez-3B-Instruct
             res = "megrez"
+        if chkhsh == "877081d19cf6996e2c4ff0e1236341e9b7bde288f5311a56a937f0afbbb3aeb5":
+            # ref: https://huggingface.co/deepseek-ai/DeepSeek-V3
+            res = "deepseek-v3"
 
         if res is None:
             logger.warning("\n")
@@ -3831,6 +3834,7 @@ def prepare_tensors(self):
 
 
 @Model.register("DeepseekV2ForCausalLM")
+@Model.register("DeepseekV3ForCausalLM")
 class DeepseekV2Model(Model):
     model_arch = gguf.MODEL_ARCH.DEEPSEEK2
 
@@ -3852,6 +3856,15 @@ def set_gguf_parameters(self):
         self.gguf_writer.add_expert_count(hparams["n_routed_experts"])
         self.gguf_writer.add_expert_shared_count(hparams["n_shared_experts"])
         self.gguf_writer.add_expert_weights_scale(hparams["routed_scaling_factor"])
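+        # DeepSeek-V3 introduces norm_topk_prob (normalization of the selected
+        # expert weights) and a configurable expert scoring function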
+        self.gguf_writer.add_expert_weights_norm(hparams["norm_topk_prob"])
+
+        if hparams["scoring_func"] == "sigmoid":
+            self.gguf_writer.add_expert_gating_func(gguf.ExpertGatingFuncType.SIGMOID)
+        elif hparams["scoring_func"] == "softmax":
+            self.gguf_writer.add_expert_gating_func(gguf.ExpertGatingFuncType.SOFTMAX)
+        else:
+            raise ValueError(f"Unsupported scoring_func value: {hparams['scoring_func']}")
+
         self.gguf_writer.add_rope_dimension_count(hparams["qk_rope_head_dim"])
 
         if self.hparams.get("rope_scaling") is not None and "factor" in self.hparams["rope_scaling"]:
@@ -3864,6 +3877,16 @@ def set_gguf_parameters(self):
     _experts: list[dict[str, Tensor]] | None = None
 
     def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
+        # rename e_score_correction_bias tensors so that the standard ".bias"
+        # suffix handling of the tensor name mapping picks them up
+        if name.endswith("e_score_correction_bias"):
+            name = name.replace("e_score_correction_bias", "e_score_correction.bias")
+
+        # skip Multi-Token Prediction (MTP) layers
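+        # (DeepSeek-V3 stores its MTP module(s) after the regular transformer
+        # blocks; llama.cpp does not implement MTP, so these tensors are dropped)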
+        block_count = self.hparams["num_hidden_layers"]
+        match = re.match(r"model\.layers\.(\d+)", name)
+        if match and int(match.group(1)) >= block_count:
+            return []
+
         # process the experts separately
         if name.find("mlp.experts") != -1:
             n_experts = self.hparams["n_routed_experts"]
diff --git a/convert_hf_to_gguf_update.py b/convert_hf_to_gguf_update.py
index fea23ddb4ae48..56edc64a72761 100755
--- a/convert_hf_to_gguf_update.py
+++ b/convert_hf_to_gguf_update.py
@@ -107,6 +107,7 @@ class TOKENIZER_TYPE(IntEnum):
     {"name": "roberta-bpe",    "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/sentence-transformers/stsb-roberta-base"},
     {"name": "gigachat",       "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/ai-sage/GigaChat-20B-A3B-instruct"},
     {"name": "megrez",         "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/Infinigence/Megrez-3B-Instruct"},
+    {"name": "deepseek-v3",    "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/deepseek-ai/DeepSeek-V3"},
 ]
 
 
diff --git a/gguf-py/gguf/constants.py b/gguf-py/gguf/constants.py
index 273370370e6ca..ef795c04e1ca5 100644
--- a/gguf-py/gguf/constants.py
+++ b/gguf-py/gguf/constants.py
@@ -102,6 +102,8 @@ class LLM:
         EXPERT_USED_COUNT                 = "{arch}.expert_used_count"
         EXPERT_SHARED_COUNT               = "{arch}.expert_shared_count"
         EXPERT_WEIGHTS_SCALE              = "{arch}.expert_weights_scale"
+        EXPERT_WEIGHTS_NORM               = "{arch}.expert_weights_norm"
+        EXPERT_GATING_FUNC                = "{arch}.expert_gating_func"
         POOLING_TYPE                      = "{arch}.pooling_type"
         LOGIT_SCALE                       = "{arch}.logit_scale"
         DECODER_START_TOKEN_ID            = "{arch}.decoder_start_token_id"
@@ -312,6 +314,7 @@ class MODEL_TENSOR(IntEnum):
     FFN_GATE_SHEXP       = auto()
     FFN_DOWN_SHEXP       = auto()
     FFN_UP_SHEXP         = auto()
+    FFN_EXP_PROBS_B      = auto()
     ATTN_Q_NORM          = auto()
     ATTN_K_NORM          = auto()
     LAYER_OUT_NORM       = auto()
@@ -496,6 +499,7 @@ class MODEL_TENSOR(IntEnum):
     MODEL_TENSOR.FFN_GATE_EXP:              "blk.{bid}.ffn_gate_exps",
     MODEL_TENSOR.FFN_DOWN_EXP:              "blk.{bid}.ffn_down_exps",
     MODEL_TENSOR.FFN_UP_EXP:                "blk.{bid}.ffn_up_exps",
+    MODEL_TENSOR.FFN_EXP_PROBS_B:           "blk.{bid}.exp_probs_b",
     MODEL_TENSOR.LAYER_OUT_NORM:            "blk.{bid}.layer_output_norm",
     MODEL_TENSOR.SSM_IN:                    "blk.{bid}.ssm_in",
     MODEL_TENSOR.SSM_CONV1D:                "blk.{bid}.ssm_conv1d",
@@ -1276,6 +1280,7 @@ class MODEL_TENSOR(IntEnum):
         MODEL_TENSOR.FFN_GATE_SHEXP,
         MODEL_TENSOR.FFN_DOWN_SHEXP,
         MODEL_TENSOR.FFN_UP_SHEXP,
+        MODEL_TENSOR.FFN_EXP_PROBS_B,
     ],
     MODEL_ARCH.CHATGLM : [
         MODEL_TENSOR.TOKEN_EMBD,
@@ -1576,6 +1581,11 @@ class GGMLQuantizationType(IntEnum):
     TQ2_0   = 35
 
 
+class ExpertGatingFuncType(IntEnum):
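+    # gating function applied to the expert router logits before top-k selection,
+    # stored as a GGUF KV and mirrored by llama_expert_gating_func_type in llama.cpp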
+    SOFTMAX  = 1
+    SIGMOID  = 2
+
+
 # TODO: add GGMLFileType from ggml_ftype in ggml.h
 
 
diff --git a/gguf-py/gguf/gguf_writer.py b/gguf-py/gguf/gguf_writer.py
index 3023b539ae82b..4a0a65e3cc33e 100644
--- a/gguf-py/gguf/gguf_writer.py
+++ b/gguf-py/gguf/gguf_writer.py
@@ -26,6 +26,7 @@
     RopeScalingType,
     PoolingType,
     TokenType,
+    ExpertGatingFuncType,
 )
 
 from .quants import quant_shape_from_byte_shape
@@ -715,6 +716,12 @@ def add_expert_shared_count(self, count: int) -> None:
     def add_expert_weights_scale(self, value: float) -> None:
         self.add_float32(Keys.LLM.EXPERT_WEIGHTS_SCALE.format(arch=self.arch), value)
 
+    def add_expert_weights_norm(self, value: bool) -> None:
+        self.add_bool(Keys.LLM.EXPERT_WEIGHTS_NORM.format(arch=self.arch), value)
+
+    def add_expert_gating_func(self, value: ExpertGatingFuncType) -> None:
+        self.add_uint32(Keys.LLM.EXPERT_GATING_FUNC.format(arch=self.arch), value.value)
+
     def add_swin_norm(self, value: bool) -> None:
         self.add_bool(Keys.LLM.SWIN_NORM.format(arch=self.arch), value)
 
diff --git a/gguf-py/gguf/tensor_mapping.py b/gguf-py/gguf/tensor_mapping.py
index 7009a11d46bc8..efe2a4aa4fe28 100644
--- a/gguf-py/gguf/tensor_mapping.py
+++ b/gguf-py/gguf/tensor_mapping.py
@@ -276,6 +276,10 @@ class TensorNameMap:
             "model.layers.{bid}.mlp.shared_expert_gate", # qwen2moe
         ),
 
+        MODEL_TENSOR.FFN_EXP_PROBS_B: (
+            "model.layers.{bid}.mlp.gate.e_score_correction", # deepseek-v3
+        ),
+
         # Feed-forward up
         MODEL_TENSOR.FFN_UP: (
             "gpt_neox.layers.{bid}.mlp.dense_h_to_4h",                # gptneox
diff --git a/include/llama.h b/include/llama.h
index 7b305b299c195..a0d5ba5ddcfd2 100644
--- a/include/llama.h
+++ b/include/llama.h
@@ -105,6 +105,7 @@ extern "C" {
         LLAMA_VOCAB_PRE_TYPE_EXAONE         = 25,
         LLAMA_VOCAB_PRE_TYPE_CHAMELEON      = 26,
         LLAMA_VOCAB_PRE_TYPE_MINERVA        = 27,
+        LLAMA_VOCAB_PRE_TYPE_DEEPSEEK3_LLM  = 28,
     };
 
     enum llama_rope_type {
diff --git a/src/llama-arch.cpp b/src/llama-arch.cpp
index a6003838559a7..6ff91b6c262cf 100644
--- a/src/llama-arch.cpp
+++ b/src/llama-arch.cpp
@@ -91,6 +91,8 @@ static const std::map<llm_kv, const char *> LLM_KV_NAMES = {
     { LLM_KV_EXPERT_USED_COUNT,                 "%s.expert_used_count"                 },
     { LLM_KV_EXPERT_SHARED_COUNT,               "%s.expert_shared_count"               },
     { LLM_KV_EXPERT_WEIGHTS_SCALE,              "%s.expert_weights_scale"              },
+    { LLM_KV_EXPERT_WEIGHTS_NORM,               "%s.expert_weights_norm"               },
+    { LLM_KV_EXPERT_GATING_FUNC,                "%s.expert_gating_func"                },
     { LLM_KV_POOLING_TYPE,                      "%s.pooling_type"                      },
     { LLM_KV_LOGIT_SCALE,                       "%s.logit_scale"                       },
     { LLM_KV_DECODER_START_TOKEN_ID,            "%s.decoder_start_token_id"            },
@@ -968,6 +970,7 @@ static const std::map<llm_arch, std::map<llm_tensor, const char *>> LLM_TENSOR_N
             { LLM_TENSOR_FFN_GATE_SHEXP,     "blk.%d.ffn_gate_shexp" },
             { LLM_TENSOR_FFN_DOWN_SHEXP,     "blk.%d.ffn_down_shexp" },
             { LLM_TENSOR_FFN_UP_SHEXP,       "blk.%d.ffn_up_shexp" },
+            { LLM_TENSOR_FFN_EXP_PROBS_B,    "blk.%d.exp_probs_b" },
         },
     },
     {
@@ -1350,6 +1353,7 @@ static const std::map<llm_tensor, llm_tensor_info> LLM_TENSOR_INFOS = {
     {LLM_TENSOR_FFN_DOWN_EXPS,              {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT_ID}},
     {LLM_TENSOR_FFN_GATE_EXPS,              {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT_ID}},
     {LLM_TENSOR_FFN_UP_EXPS,                {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT_ID}},
+    {LLM_TENSOR_FFN_EXP_PROBS_B,            {LLM_TENSOR_LAYER_REPEATING, GGML_OP_ADD}},
     // this tensor is loaded for T5, but never used
     {LLM_TENSOR_DEC_CROSS_ATTN_REL_B,       {LLM_TENSOR_LAYER_REPEATING, GGML_OP_NONE}},
     {LLM_TENSOR_CONV1D,                     {LLM_TENSOR_LAYER_INPUT,     GGML_OP_IM2COL}},
diff --git a/src/llama-arch.h b/src/llama-arch.h
index 446e72eebf6d6..72588c46cffa6 100644
--- a/src/llama-arch.h
+++ b/src/llama-arch.h
@@ -95,6 +95,8 @@ enum llm_kv {
     LLM_KV_EXPERT_USED_COUNT,
     LLM_KV_EXPERT_SHARED_COUNT,
     LLM_KV_EXPERT_WEIGHTS_SCALE,
+    LLM_KV_EXPERT_WEIGHTS_NORM,
+    LLM_KV_EXPERT_GATING_FUNC,
     LLM_KV_POOLING_TYPE,
     LLM_KV_LOGIT_SCALE,
     LLM_KV_DECODER_START_TOKEN_ID,
@@ -230,6 +232,7 @@ enum llm_tensor {
     LLM_TENSOR_FFN_DOWN_SHEXP,
     LLM_TENSOR_FFN_GATE_SHEXP,
     LLM_TENSOR_FFN_UP_SHEXP,
+    LLM_TENSOR_FFN_EXP_PROBS_B,
     LLM_TENSOR_ATTN_Q_NORM,
     LLM_TENSOR_ATTN_K_NORM,
     LLM_TENSOR_LAYER_OUT_NORM,
diff --git a/src/llama-chat.cpp b/src/llama-chat.cpp
index a07e9cf00b942..44670d3d83934 100644
--- a/src/llama-chat.cpp
+++ b/src/llama-chat.cpp
@@ -45,6 +45,7 @@ static const std::map<std::string, llm_chat_template> LLM_CHAT_TEMPLATES = {
     { "vicuna-orca",       LLM_CHAT_TEMPLATE_VICUNA_ORCA       },
     { "deepseek",          LLM_CHAT_TEMPLATE_DEEPSEEK          },
     { "deepseek2",         LLM_CHAT_TEMPLATE_DEEPSEEK_2        },
+    { "deepseek3",         LLM_CHAT_TEMPLATE_DEEPSEEK_3        },
     { "command-r",         LLM_CHAT_TEMPLATE_COMMAND_R         },
     { "llama3",            LLM_CHAT_TEMPLATE_LLAMA_3           },
     { "chatglm3",          LLM_CHAT_TEMPLATE_CHATGML_3         },
@@ -148,6 +149,8 @@ llm_chat_template llm_chat_detect_template(const std::string & tmpl) {
         return LLM_CHAT_TEMPLATE_MINICPM;
     } else if (tmpl_contains("'Assistant: ' + message['content'] + eos_token")) {
         return LLM_CHAT_TEMPLATE_DEEPSEEK_2;
+    } else if (tmpl_contains(LU8("'<|Assistant|>' + message['content'] + '<|end▁of▁sentence|>'"))) {
+        return LLM_CHAT_TEMPLATE_DEEPSEEK_3;
     } else if (tmpl_contains("[|system|]") && tmpl_contains("[|assistant|]") && tmpl_contains("[|endofturn|]")) {
         // ref: https://huggingface.co/LGAI-EXAONE/EXAONE-3.0-7.8B-Instruct/discussions/8#66bae61b1893d14ee8ed85bb
         // EXAONE-3.0-7.8B-Instruct
@@ -453,6 +456,21 @@ int32_t llm_chat_apply_template(
         if (add_ass) {
             ss << "Assistant:";
         }
+    } else if (tmpl == LLM_CHAT_TEMPLATE_DEEPSEEK_3) {
+        // DeepSeek-V3
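+        // renders as: {system}\n\n<|User|>{user}<|Assistant|>{assistant}<|end▁of▁sentence|>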
+        for (auto message : chat) {
+            std::string role(message->role);
+            if (role == "system") {
+                ss << message->content << "\n\n";
+            } else if (role == "user") {
+                ss << LU8("<|User|>") << message->content;
+            } else if (role == "assistant") {
+                ss << LU8("<|Assistant|>") << message->content << LU8("<|end▁of▁sentence|>");
+            }
+        }
+        if (add_ass) {
+            ss << LU8("<|Assistant|>");
+        }
     } else if (tmpl == LLM_CHAT_TEMPLATE_EXAONE_3) {
         // ref: https://huggingface.co/LGAI-EXAONE/EXAONE-3.0-7.8B-Instruct/discussions/8#66bae61b1893d14ee8ed85bb
         // EXAONE-3.0-7.8B-Instruct
diff --git a/src/llama-chat.h b/src/llama-chat.h
index 364318c2775db..b8e94d9ef2b3b 100644
--- a/src/llama-chat.h
+++ b/src/llama-chat.h
@@ -25,6 +25,7 @@ enum llm_chat_template {
     LLM_CHAT_TEMPLATE_VICUNA_ORCA,
     LLM_CHAT_TEMPLATE_DEEPSEEK,
     LLM_CHAT_TEMPLATE_DEEPSEEK_2,
+    LLM_CHAT_TEMPLATE_DEEPSEEK_3,
     LLM_CHAT_TEMPLATE_COMMAND_R,
     LLM_CHAT_TEMPLATE_LLAMA_3,
     LLM_CHAT_TEMPLATE_CHATGML_3,
diff --git a/src/llama-hparams.h b/src/llama-hparams.h
index 3a76b71a4e18e..a29f20ec49665 100644
--- a/src/llama-hparams.h
+++ b/src/llama-hparams.h
@@ -6,7 +6,13 @@
 
 // bump if necessary
 #define LLAMA_MAX_LAYERS  512
-#define LLAMA_MAX_EXPERTS 160  // DeepSeekV2
+#define LLAMA_MAX_EXPERTS 256  // DeepSeekV3
+
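+// gating function used by the MoE expert router
+// (must be kept in sync with ExpertGatingFuncType in gguf-py/gguf/constants.py)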
+enum llama_expert_gating_func_type {
+    LLAMA_EXPERT_GATING_FUNC_TYPE_NONE    = 0,
+    LLAMA_EXPERT_GATING_FUNC_TYPE_SOFTMAX = 1,
+    LLAMA_EXPERT_GATING_FUNC_TYPE_SIGMOID = 2,
+};
 
 struct llama_hparams_posnet {
     uint32_t n_embd;
@@ -54,7 +60,9 @@ struct llama_hparams {
     uint32_t n_expert_shared    = 0;
     uint32_t n_norm_groups      = 0;
 
-    float expert_weights_scale = 0.0;
+    float    expert_weights_scale = 0.0;
+    bool     expert_weights_norm  = false;
+    uint32_t expert_gating_func   = LLAMA_EXPERT_GATING_FUNC_TYPE_NONE;
 
     float f_norm_eps;
     float f_norm_rms_eps;
diff --git a/src/llama-model.cpp b/src/llama-model.cpp
index ace0ba2629f68..d41e1b953118a 100644
--- a/src/llama-model.cpp
+++ b/src/llama-model.cpp
@@ -66,6 +66,7 @@ const char * llm_type_name(llm_type type) {
         case MODEL_70B:           return "70B";
         case MODEL_236B:          return "236B";
         case MODEL_314B:          return "314B";
+        case MODEL_671B:          return "671B";
         case MODEL_SMALL:         return "0.1B";
         case MODEL_MEDIUM:        return "0.4B";
         case MODEL_LARGE:         return "0.8B";
@@ -125,6 +126,14 @@ static std::string llama_model_ftype_name(llama_ftype ftype) {
     }
 }
 
+static const char * llama_expert_gating_func_name(llama_expert_gating_func_type type) {
+    switch (type) {
+        case LLAMA_EXPERT_GATING_FUNC_TYPE_SOFTMAX: return "softmax";
+        case LLAMA_EXPERT_GATING_FUNC_TYPE_SIGMOID: return "sigmoid";
+        default:                                    return "unknown";
+    }
+}
+
 std::string llama_model_arch_name (const llama_model & model) {
     return llm_arch_name(model.arch);
 }
@@ -923,11 +932,19 @@ void llm_load_hparams(llama_model_loader & ml, llama_model & model) {
                 ml.get_key(LLM_KV_EXPERT_FEED_FORWARD_LENGTH, hparams.n_ff_exp);
                 ml.get_key(LLM_KV_EXPERT_SHARED_COUNT, hparams.n_expert_shared);
                 ml.get_key(LLM_KV_EXPERT_WEIGHTS_SCALE, hparams.expert_weights_scale);
+                ml.get_key(LLM_KV_EXPERT_WEIGHTS_NORM, hparams.expert_weights_norm, false);
+                ml.get_key(LLM_KV_EXPERT_GATING_FUNC, hparams.expert_gating_func, false);
+                if (hparams.expert_gating_func == LLAMA_EXPERT_GATING_FUNC_TYPE_NONE) {
+                    // for compatibility with existing DeepSeek V2 and V2.5 GGUFs
+                    // that have no expert_gating_func model parameter set
+                    hparams.expert_gating_func = LLAMA_EXPERT_GATING_FUNC_TYPE_SOFTMAX;
+                }
                 ml.get_key(LLM_KV_ROPE_SCALING_YARN_LOG_MUL, hparams.rope_yarn_log_mul);
 
                 switch (hparams.n_layer) {
                     case 27: model.type = e_model::MODEL_16B; break;
                     case 60: model.type = e_model::MODEL_236B; break;
+                    case 61: model.type = e_model::MODEL_671B; break;
                     default: model.type = e_model::MODEL_UNKNOWN;
                 }
             } break;
@@ -1249,6 +1266,10 @@ void llm_load_vocab(llama_model_loader & ml, llama_model & model) {
                     tokenizer_pre == "deepseek-coder") {
                 vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_DEEPSEEK_CODER;
                 vocab.tokenizer_clean_spaces = false;
+            } else if (
+                    tokenizer_pre == "deepseek-v3") {
+                vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_DEEPSEEK3_LLM;
+                vocab.tokenizer_clean_spaces = false;
             } else if (
                     tokenizer_pre == "falcon") {
                 vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_FALCON;
@@ -1931,6 +1952,8 @@ void llm_load_print_meta(llama_model_loader & ml, llama_model & model) {
         LLAMA_LOG_INFO("%s: n_ff_exp             = %d\n",     __func__, hparams.n_ff_exp);
         LLAMA_LOG_INFO("%s: n_expert_shared      = %d\n",     __func__, hparams.n_expert_shared);
         LLAMA_LOG_INFO("%s: expert_weights_scale = %.1f\n",   __func__, hparams.expert_weights_scale);
+        LLAMA_LOG_INFO("%s: expert_weights_norm  = %d\n",     __func__, hparams.expert_weights_norm);
+        LLAMA_LOG_INFO("%s: expert_gating_func   = %s\n",     __func__, llama_expert_gating_func_name((enum llama_expert_gating_func_type) hparams.expert_gating_func));
         LLAMA_LOG_INFO("%s: rope_yarn_log_mul    = %.4f\n",   __func__, hparams.rope_yarn_log_mul);
     }
 
diff --git a/src/llama-model.h b/src/llama-model.h
index 01c780c4182b3..ce038932d4e2d 100644
--- a/src/llama-model.h
+++ b/src/llama-model.h
@@ -63,6 +63,7 @@ enum llm_type {
     MODEL_70B,
     MODEL_236B,
     MODEL_314B,
+    MODEL_671B,
     MODEL_SMALL,
     MODEL_MEDIUM,
     MODEL_LARGE,
@@ -213,6 +214,7 @@ struct llama_layer {
     struct ggml_tensor * ffn_down_b = nullptr; // b2
     struct ggml_tensor * ffn_up_b   = nullptr; // b3
     struct ggml_tensor * ffn_act    = nullptr;
+    struct ggml_tensor * ffn_exp_probs_b = nullptr;
 
     // mamba proj
     struct ggml_tensor * ssm_in  = nullptr;
diff --git a/src/llama-vocab.cpp b/src/llama-vocab.cpp
index 909e04871949e..3fcfcaa3f4117 100644
--- a/src/llama-vocab.cpp
+++ b/src/llama-vocab.cpp
@@ -382,6 +382,13 @@ struct llm_tokenizer_bpe : llm_tokenizer {
                     "\\p{N}+",
                 };
                 break;
+            case LLAMA_VOCAB_PRE_TYPE_DEEPSEEK3_LLM:
+                regex_exprs = {
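+                    // digits in runs of up to three, CJK ideographs and kana,
+                    // then a GPT-2-style pattern extended with \p{M} and \p{S}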
+                    "\\p{N}{1,3}",
+                    "[一-龥぀-ゟ゠-ヿ]+",
+                    "[!\"#$%&'()*+,\\-./:;<=>?@\\[\\\\\\]^_`{|}~][A-Za-z]+|[^\r\n\\p{L}\\p{P}\\p{S}]?[\\p{L}\\p{M}]+| ?[\\p{P}\\p{S}]+[\r\n]*|\\s*[\r\n]+|\\s+(?!\\S)|\\s+",
+                };
+                break;
             case LLAMA_VOCAB_PRE_TYPE_DEEPSEEK_CODER:
                 regex_exprs = {
                     "[\r\n]",
diff --git a/src/llama.cpp b/src/llama.cpp
index d7110b90bcce0..30aa2042334fc 100644
--- a/src/llama.cpp
+++ b/src/llama.cpp
@@ -1831,6 +1831,7 @@ static bool llm_load_tensors(
                             layer.ffn_up   = create_tensor(tn(LLM_TENSOR_FFN_UP,   "weight", i), {n_embd,   n_ff}, 0);
                         } else {
                             layer.ffn_gate_inp = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), {n_embd, n_expert}, 0);
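+                            // optional expert selection bias - present only in DeepSeek-V3 GGUFs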
+                            layer.ffn_exp_probs_b = create_tensor(tn(LLM_TENSOR_FFN_EXP_PROBS_B, "bias", i), {n_expert}, llama_model_loader::TENSOR_NOT_REQUIRED);
 
                             if (n_expert == 0) {
                                 throw std::runtime_error("n_expert must be > 0");
@@ -2811,12 +2812,14 @@ static struct ggml_tensor * llm_build_moe_ffn(
          struct ggml_tensor * up_exps,
          struct ggml_tensor * gate_exps,
          struct ggml_tensor * down_exps,
+         struct ggml_tensor * exp_probs_b,
                     int64_t   n_expert,
                     int64_t   n_expert_used,
             llm_ffn_op_type   type_op,
                        bool   norm_w,
                        bool   scale_w,
                       float   w_scale,
+llama_expert_gating_func_type gating_op,
          const llm_build_cb & cb,
                         int   il) {
     int64_t n_embd = cur->ne[0];
@@ -2825,11 +2828,31 @@ static struct ggml_tensor * llm_build_moe_ffn(
     ggml_tensor * logits = llm_build_lora_mm(lctx, ctx, gate_inp, cur); // [n_expert, n_tokens]
     cb(logits, "ffn_moe_logits", il);
 
-    ggml_tensor * probs = ggml_soft_max(ctx, logits); // [n_expert, n_tokens]
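+    // the gating function is now selectable: softmax (default) or sigmoid (DeepSeek V3)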
+    ggml_tensor * probs = nullptr;
+    switch (gating_op) {
+        case LLAMA_EXPERT_GATING_FUNC_TYPE_SOFTMAX:
+            {
+                probs = ggml_soft_max(ctx, logits); // [n_expert, n_tokens]
+            } break;
+        case LLAMA_EXPERT_GATING_FUNC_TYPE_SIGMOID:
+            {
+                probs = ggml_sigmoid(ctx, logits); // [n_expert, n_tokens]
+            } break;
+        default:
+            GGML_ABORT("fatal error");
+    }
     cb(probs, "ffn_moe_probs", il);
 
+    // add the expert selection bias - introduced in DeepSeek V3
+    // leave probs unbiased as it's later used to get the expert weights
+    ggml_tensor * selection_probs = probs;
+    if (exp_probs_b != nullptr) {
+        selection_probs = ggml_add(ctx, probs, exp_probs_b);
+        cb(selection_probs, "ffn_moe_probs_biased", il);
+    }
+
     // select experts
-    ggml_tensor * selected_experts = ggml_top_k(ctx, probs, n_expert_used); // [n_expert_used, n_tokens]
+    ggml_tensor * selected_experts = ggml_top_k(ctx, selection_probs, n_expert_used); // [n_expert_used, n_tokens]
     cb(selected_experts->src[0], "ffn_moe_argsort", il);
     cb(selected_experts, "ffn_moe_topk", il);
 
@@ -3950,9 +3973,11 @@ struct llm_build_context {
                         model.layers[il].ffn_up_exps,
                         model.layers[il].ffn_gate_exps,
                         model.layers[il].ffn_down_exps,
+                        nullptr,
                         n_expert, n_expert_used,
                         LLM_FFN_SILU, true,
                         false, 0.0,
+                        LLAMA_EXPERT_GATING_FUNC_TYPE_SOFTMAX,
                         cb, il);
                 cb(cur, "ffn_moe_out", il);
             }
@@ -4602,9 +4627,11 @@ struct llm_build_context {
                     model.layers[il].ffn_up_exps,
                     model.layers[il].ffn_gate_exps,
                     model.layers[il].ffn_down_exps,
+                    nullptr,
                     n_expert, n_expert_used,
                     LLM_FFN_GELU, true,
                     false, 0.0,
+                    LLAMA_EXPERT_GATING_FUNC_TYPE_SOFTMAX,
                     cb, il);
             cb(cur, "ffn_moe_out", il);
 
@@ -4743,9 +4770,11 @@ struct llm_build_context {
                     model.layers[il].ffn_up_exps,
                     model.layers[il].ffn_gate_exps,
                     model.layers[il].ffn_down_exps,
+                    nullptr,
                     n_expert, n_expert_used,
                     LLM_FFN_SILU, true,
                     false, 0.0,
+                    LLAMA_EXPERT_GATING_FUNC_TYPE_SOFTMAX,
                     cb, il);
             cb(cur, "ffn_moe_out", il);
 
@@ -5991,9 +6020,11 @@ struct llm_build_context {
                         model.layers[il].ffn_up_exps,
                         model.layers[il].ffn_gate_exps,
                         model.layers[il].ffn_down_exps,
+                        nullptr,
                         n_expert, n_expert_used,
                         LLM_FFN_SILU, false,
                         false, 0.0,
+                        LLAMA_EXPERT_GATING_FUNC_TYPE_SOFTMAX,
                         cb, il);
             cb(cur, "ffn_moe_out", il);
 
@@ -7985,9 +8016,11 @@ struct llm_build_context {
                     model.layers[il].ffn_up_exps,
                     model.layers[il].ffn_gate_exps,
                     model.layers[il].ffn_down_exps,
+                    nullptr,
                     n_expert, n_expert_used,
                     LLM_FFN_SILU, false,
                     false, 0.0,
+                    LLAMA_EXPERT_GATING_FUNC_TYPE_SOFTMAX,
                     cb, il);
             cb(cur, "ffn_moe_out", il);
 
@@ -8382,9 +8415,11 @@ struct llm_build_context {
                     model.layers[il].ffn_up_exps,
                     model.layers[il].ffn_gate_exps,
                     model.layers[il].ffn_down_exps,
+                    nullptr,
                     n_expert, n_expert_used,
                     LLM_FFN_SILU, true,
                     false, 0.0,
+                    LLAMA_EXPERT_GATING_FUNC_TYPE_SOFTMAX,
                     cb, il);
             cb(cur, "ffn_moe_out", il);
 
@@ -8523,9 +8558,11 @@ struct llm_build_context {
                             model.layers[il].ffn_up_exps,
                             model.layers[il].ffn_gate_exps,
                             model.layers[il].ffn_down_exps,
+                            nullptr,
                             n_expert, n_expert_used,
                             LLM_FFN_SILU, false,
                             false, hparams.expert_weights_scale,
+                            LLAMA_EXPERT_GATING_FUNC_TYPE_SOFTMAX,
                             cb, il);
                 cb(moe_out, "ffn_moe_out", il);
 
@@ -8752,9 +8789,11 @@ struct llm_build_context {
                             model.layers[il].ffn_up_exps,
                             model.layers[il].ffn_gate_exps,
                             model.layers[il].ffn_down_exps,
+                            model.layers[il].ffn_exp_probs_b,
                             n_expert, n_expert_used,
-                            LLM_FFN_SILU, false,
+                            LLM_FFN_SILU, hparams.expert_weights_norm,
                             true, hparams.expert_weights_scale,
+                            (enum llama_expert_gating_func_type) hparams.expert_gating_func,
                             cb, il);
                 cb(moe_out, "ffn_moe_out", il);
 
diff --git a/src/unicode.cpp b/src/unicode.cpp
index 8ed6b1a51c251..7aca6544bc73d 100644
--- a/src/unicode.cpp
+++ b/src/unicode.cpp
@@ -667,18 +667,24 @@ std::vector<std::string> unicode_regex_split(const std::string & text, const std
         { "\\p{N}", unicode_cpt_flags::NUMBER },
         { "\\p{L}", unicode_cpt_flags::LETTER },
         { "\\p{P}", unicode_cpt_flags::PUNCTUATION },
+        { "\\p{M}", unicode_cpt_flags::ACCENT_MARK },
+        { "\\p{S}", unicode_cpt_flags::SYMBOL },
     };
 
     static const std::map<int, int> k_ucat_cpt = {
         { unicode_cpt_flags::NUMBER,      0xD1 },
         { unicode_cpt_flags::LETTER,      0xD2 },
         { unicode_cpt_flags::PUNCTUATION, 0xD3 },
+        { unicode_cpt_flags::ACCENT_MARK, 0xD4 },
+        { unicode_cpt_flags::SYMBOL,      0xD5 },
     };
 
     static const std::map<int, std::string> k_ucat_map = {
         { unicode_cpt_flags::NUMBER,      "\x30-\x39" }, // 0-9
         { unicode_cpt_flags::LETTER,      "\x41-\x5A\x61-\x7A" }, // A-Za-z
         { unicode_cpt_flags::PUNCTUATION, "\x21-\x23\x25-\x2A\x2C-\x2F\x3A-\x3B\x3F-\x40\\\x5B-\\\x5D\x5F\\\x7B\\\x7D" }, // !-#%-*,-/:-;?-@\[-\]_\{\}
+        { unicode_cpt_flags::ACCENT_MARK, "" }, // no sub-128 codepoints
+        { unicode_cpt_flags::SYMBOL,      "\\\x24\\\x2B\x3C-\x3E\x5E\x60\\\x7C" }, // $+<=>^`|
     };
 
     // compute collapsed codepoints only if needed by at least one regex