Skip to content

[WIP] add deepseek-v3 #35926

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 101 commits into from
Mar 28, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
101 commits
Select commit Hold shift + click to select a range
b926c3d
init commit
ArthurZucker Sep 9, 2024
c62c5b7
Merge branch 'main' of github.com:huggingface/transformers into updat…
ArthurZucker Sep 9, 2024
5b85023
style
ArthurZucker Sep 9, 2024
3b76bda
take comments into account
ArthurZucker Sep 9, 2024
704767e
add deepseekv3 modeling
bzantium Jan 28, 2025
737ee3a
Merge branch 'main' into feature/#35425
bzantium Jan 28, 2025
fc3a4c7
Merge branch 'main' of https://github.com/bzantium/transformers into …
bzantium Jan 28, 2025
244e793
remove redundant code
bzantium Jan 28, 2025
0968df5
Merge branch 'feature/#35425' of https://github.com/bzantium/transfor…
bzantium Jan 28, 2025
4fb2a80
apply make style
bzantium Jan 28, 2025
6b002e5
apply fix-copies
bzantium Jan 28, 2025
4ec1e88
make format
bzantium Jan 28, 2025
114ab84
add init files
bzantium Jan 28, 2025
779f8d2
rename deepseekv3 into deepseek_v3 based on its model_type
bzantium Jan 28, 2025
22623a3
rename deepseekv3 into deepseek_v3 based on its model_type
bzantium Jan 28, 2025
78b19b0
deepseek-v3 not deepseek_v3
bzantium Jan 28, 2025
eb0e3a4
set model_type as deepseek_v3
bzantium Jan 28, 2025
57088cc
use default docs
bzantium Jan 28, 2025
0ef561b
apply make
bzantium Jan 28, 2025
9a75a56
fill type and docstring
bzantium Jan 28, 2025
cdf83e4
add rope_config_validation
bzantium Jan 29, 2025
51990b9
use custom DeepseekV3MLP
bzantium Jan 29, 2025
f4f0ebd
hold code only for checkpoints congifuration; remove redundant
bzantium Jan 30, 2025
4b72b30
revise rope yarn for DeepSeek variation
bzantium Jan 30, 2025
96562c4
Merge branch 'main' into feature/#35425
bzantium Jan 30, 2025
6792cb5
rename DeepSeek-V3
bzantium Jan 30, 2025
3bf3b32
some refactoring
ArthurZucker Jan 31, 2025
24bc8b2
revise load_hook to work properly; make moe func trainable; use llama…
bzantium Jan 31, 2025
5c0cd91
fix attention forward
bzantium Jan 31, 2025
8e994dd
use -1 for not-changing dim when to use exapnd
bzantium Feb 1, 2025
7405a95
refactor DeepseekV3TopkRouter
bzantium Feb 1, 2025
ea3c922
use reshape_for_rope instead of load_hook; revise attention forward f…
bzantium Feb 3, 2025
c813268
register pre_hook and hook both
bzantium Feb 3, 2025
4ab2f9e
make style
bzantium Feb 3, 2025
c5429ec
use n_shared_experts
bzantium Feb 10, 2025
4df42f0
Update src/transformers/models/deepseek_v3/configuration_deepseek_v3.py
bzantium Feb 14, 2025
e0a49ac
Merge branch 'main' of github.com:huggingface/transformers into featu…
Feb 14, 2025
dfd9abc
Merge branch 'feature/#35425' of github.com:bzantium/transformers int…
Feb 14, 2025
ba21b7c
add test file
bzantium Feb 15, 2025
2270173
Merge branch 'feature/#35425' of https://github.com/bzantium/transfor…
bzantium Feb 15, 2025
b5f420b
update modeling_file according to modular file
bzantium Feb 15, 2025
6bd75a9
make style
bzantium Feb 15, 2025
6ccbc66
add mapping for DeepseekV3ForSequenceClassification
bzantium Feb 15, 2025
a1c6274
remove aux_loss_alpha
bzantium Feb 15, 2025
a80462b
add deepseek_v3 for perf
bzantium Feb 15, 2025
dd78f48
add deepseek_v3
bzantium Feb 15, 2025
54481ef
rename test as deepseekv3
bzantium Feb 15, 2025
e0f1c2d
use tiny-deepseek-v3
bzantium Feb 15, 2025
23fb756
Merge branch 'main' into feature/#35425
bzantium Feb 15, 2025
5214741
remove DeepseekV3ForSequenceClassification
bzantium Feb 15, 2025
67f1f0c
cache before padding
bzantium Feb 15, 2025
f264f80
remote output_router_logits
bzantium Feb 18, 2025
d4c6a1b
Revert "remote output_router_logits"
bzantium Feb 18, 2025
c7c8d76
remove output_router_logits
bzantium Feb 18, 2025
0b5ff07
Merge branch 'main' into feature/#35425
bzantium Feb 18, 2025
ba6f7d4
make e_score_correction_bias as buffer
bzantium Feb 18, 2025
d7931b3
skip tests not compatible
bzantium Feb 18, 2025
92bd99c
make style
bzantium Feb 18, 2025
7d81efe
make e_score_correction_bias as buffer
bzantium Feb 18, 2025
b33fdb5
use rope_interleave instead of load_hook
bzantium Feb 19, 2025
7f859f8
skip tests not compatible with MLA
bzantium Feb 19, 2025
397ecf3
add doc for rope_interleave
bzantium Feb 19, 2025
2628438
fix typo
bzantium Feb 19, 2025
af3d328
remove torch.no_grad for selecting topk
bzantium Feb 19, 2025
f0357f9
Merge branch 'main' of github.com:huggingface/transformers into featu…
ArthurZucker Mar 24, 2025
14e7d4e
fix post merge issue
ArthurZucker Mar 24, 2025
5c85490
Merge branch 'main' of github.com:huggingface/transformers into updat…
ArthurZucker Mar 24, 2025
1d8516d
mrege with main and simplify
ArthurZucker Mar 24, 2025
9b4f433
nits
ArthurZucker Mar 24, 2025
abffdfe
final
ArthurZucker Mar 24, 2025
6c7eaa5
small fixes
ArthurZucker Mar 24, 2025
71d47f4
Merge branch 'main' of github.com:huggingface/transformers into updat…
ArthurZucker Mar 24, 2025
9e4965a
fix
ArthurZucker Mar 24, 2025
6bb8802
support TP better
ArthurZucker Mar 25, 2025
426d941
stash
ArthurZucker Mar 25, 2025
d4d60c3
Merge branch 'update-from-pretrained' of github.com:huggingface/trans…
ArthurZucker Mar 25, 2025
f0a8389
changes currently requires
ArthurZucker Mar 25, 2025
4b8a857
remove synch
ArthurZucker Mar 25, 2025
eedbf59
more fixes for TP
ArthurZucker Mar 25, 2025
409f341
temp fix for TP : some attention layers's FP8 scales are too small + …
ArthurZucker Mar 26, 2025
3fb9bea
updates to have generation work!
ArthurZucker Mar 27, 2025
7350a5d
push most of the changes
ArthurZucker Mar 28, 2025
a50c351
reorder functions + call for contributions!
ArthurZucker Mar 28, 2025
24557c3
update readme
ArthurZucker Mar 28, 2025
d7da38b
nits
ArthurZucker Mar 28, 2025
186e32b
update
ArthurZucker Mar 28, 2025
c198b4b
Merge branch 'main' of github.com:huggingface/transformers into featu…
ArthurZucker Mar 28, 2025
ee33cf7
ruff was updated on main
ArthurZucker Mar 28, 2025
f2bb6f9
merge with main and fix copies
ArthurZucker Mar 28, 2025
8cefd1c
revert unrelated changes
ArthurZucker Mar 28, 2025
a8fff20
route all tokens to all experts when testing to avoid no gradient iddues
ArthurZucker Mar 28, 2025
13019a7
finish fixing all tests
ArthurZucker Mar 28, 2025
9b310a1
fixup
ArthurZucker Mar 28, 2025
e3628a3
nit
ArthurZucker Mar 28, 2025
9eb38e6
clean config
ArthurZucker Mar 28, 2025
8cb959b
last readme changes
ArthurZucker Mar 28, 2025
a55630b
nit
ArthurZucker Mar 28, 2025
bce2073
do cnit
ArthurZucker Mar 28, 2025
a1f1f3f
typo
ArthurZucker Mar 28, 2025
d2ae072
last nit
ArthurZucker Mar 28, 2025
372efd6
one more one more
ArthurZucker Mar 28, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions docs/source/en/_toctree.yml
Original file line number Diff line number Diff line change
Expand Up @@ -415,6 +415,8 @@
title: DeBERTa
- local: model_doc/deberta-v2
title: DeBERTa-v2
- local: model_doc/deepseek_v3
title: DeepSeek-V3
- local: model_doc/dialogpt
title: DialoGPT
- local: model_doc/diffllama
Expand Down
184 changes: 184 additions & 0 deletions docs/source/en/model_doc/deepseek_v3.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,184 @@
<!--Copyright 2025 The HuggingFace Team. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.

⚠️ Note that this file is in Markdown but contain specific syntax for our doc-builder (similar to MDX) that may not be
rendered properly in your Markdown viewer.

-->

# DeepSeek-V3

## Overview

The DeepSeek-V3 model was proposed in [DeepSeek-V3 Technical Report](https://arxiv.org/abs/2412.19437) by DeepSeek-AI Team.

The abstract from the paper is the following:
We present DeepSeek-V3, a strong Mixture-of-Experts (MoE) language model with 671B total parameters with 37B activated for each token. To achieve efficient inference and cost-effective training, DeepSeek-V3 adopts Multi-head Latent Attention (MLA) and DeepSeekMoE architectures, which were thoroughly validated in DeepSeek-V2. Furthermore, DeepSeek-V3 pioneers an auxiliary-loss-free strategy for load balancing and sets a multi-token prediction training objective for stronger performance. We pre-train DeepSeek-V3 on 14.8 trillion diverse and high-quality tokens, followed by Supervised Fine-Tuning and Reinforcement Learning stages to fully harness its capabilities. Comprehensive evaluations reveal that DeepSeek-V3 outperforms other open-source models and achieves performance comparable to leading closed-source models. Despite its excellent performance, DeepSeek-V3 requires only 2.788M H800 GPU hours for its full training. In addition, its training process is remarkably stable. Throughout the entire training process, we did not experience any irrecoverable loss spikes or perform any rollbacks. The model checkpoints are available at https://github.com/deepseek-ai/DeepSeek-V3.

## Limitations and call for contribution!

We are super happy to make this code community-powered, and would love to see how you can best optimize the following:

- current implementation uses the "naive" attention compution (so not really MLA)
- current implementation loops through the experts. This should be replaced. Pointers to use `get_packed_weights` from `intetrations/tensor_parallel`.
- current implementation uses the eleuther formula for ROPE, using the orginal one would be more efficient! (should still follow our API)
- static cache is not supported (this should be just a generation config issue / config shape issues)

### Usage tips
The model uses Multi-head Latent Attention (MLA) and DeepSeekMoE architectures for efficient inference and cost-effective training. It employs an auxiliary-loss-free strategy for load balancing and multi-token prediction training objective. The model can be used for various language tasks after being pre-trained on 14.8 trillion tokens and going through Supervised Fine-Tuning and Reinforcement Learning stages.

You can run the model in `FP8` automatically, using 2 nodes of 8 H100 should be more than enough!

```python
# `run_deepseek_v1.py`
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch
torch.manual_seed(30)

tokenizer = AutoTokenizer.from_pretrained("deepseek-r1")

chat = [
{"role": "user", "content": "Hello, how are you?"},
{"role": "assistant", "content": "I'm doing great. How can I help you today?"},
{"role": "user", "content": "I'd like to show off how chat templating works!"},
]


model = AutoModelForCausalLM.from_pretrained("deepseek-r1", device_map="auto", torch_dtype=torch.bfloat16)
inputs = tokenizer.apply_chat_template(chat, tokenize=True, add_generation_prompt=True, return_tensors="pt").to(model.device)
import time
start = time.time()
outputs = model.generate(inputs, max_new_tokens=50)
print(tokenizer.batch_decode(outputs))
print(time.time()-start)
```
This generated:

``````
<|Assistant|><think>
Okay, the user wants to demonstrate how chat templating works. Let me break down what that means. Chat templating is about structuring the conversation data, especially for models that need specific input formats. Maybe they're referring to something like how messages are formatted with roles (user, assistant, system) in APIs like OpenAI.

First, I should explain what chat templating is. It's the process of formatting conversation data into a structured format that the model can understand. This usually includes roles and content. For example, user messages, assistant responses, and system messages each have their own role tags.

They might want an example. Let me think of a simple conversation. The user says "Hello, how are you?" and the assistant responds "I'm doing great. How can I help you today?" Then the user follows up with wanting to show off chat templating. So the example should include the history and the new message.

In some frameworks, like Hugging Face's Transformers, chat templates are applied using Jinja2 templates. The template might look something like combining system messages, then looping through user and assistant messages with appropriate tags. For instance, using {% for message in messages %} and assigning roles like <|user|>, <|assistant|>, etc.

I should structure the example with the messages array, showing each role and content. Then apply a hypothetical template to convert that into a formatted string the model uses. Also, mention that different models have different templating requirements, like using special tokens or varying role labels.

Wait, the user mentioned "chat templating" in the context of showing off. Maybe they want a practical example they can present. So providing a code snippet or a structured data example would be helpful. Let me outline a typical messages array and then the templated output.

Also, it's important to note that proper templating ensures the model knows the conversation flow, which is crucial for generating coherent responses. Maybe include a note about why it's important, like maintaining context and role-specific processing.

Let me check if there are any common mistakes or things to avoid. For example, not closing tags properly, or mismatching roles. But maybe that's too detailed unless the user asks. Focus on the positive example first.

Putting it all together, the response should have an example messages array, the applied template, and the final formatted string. Maybe use angle brackets or special tokens as placeholders. Also, mention that this helps in training or fine-tuning models with structured data.

I think that's a solid approach. Let me structure it step by step to make it clear.
</think>

Chat templating is a way to structure conversation data (e.g., user/assistant interactions) into a format that language models understand. This is especially important for models trained to handle multi-turn dialogues, where the input must explicitly separate roles (user, assistant, system, etc.) and messages. Let’s break this down with an example!

---

### **Step 1: Raw Conversation History**
Suppose we have this conversation:
- **User**: "Hello, how are you?"
- **Assistant**: "I'm doing great. How can I help you today?"
- **User**: "I'd like to show off how chat templating works!"

---

### **Step 2: Structured Messages**
In frameworks like Hugging Face Transformers or OpenAI, conversations are often formatted as a list of dictionaries with `role` and `content`:
```python
messages = [
{"role": "user", "content": "Hello, how are you?"},
{"role": "assistant", "content": "I'm doing great. How can I help you today?"},
{"role": "user", "content": "I'd like to show off how chat templating works!"},
]
```

---

### **Step 3: Apply a Chat Template**
A **chat template** converts this structured data into a single string formatted for the model. For example, using a Jinja-style template (common in Hugging Face):

```jinja
{% for message in messages %}
{% if message['role'] == 'user' %}
<|user|>{{ message['content'] }}<|end|>
{% elif message['role'] == 'assistant' %}
<|assistant|>{{ message['content'] }}<|end|>
{% endif %}
{% endfor %}
<|assistant|>
```

---

### **Step 4: Final Templated Output**
Applying the template to our `messages` list would produce:
```text
<|user|>Hello, how are you?<|end|>
<|assistant|>I'm doing great. How can I help you today?<|end|>
<|user|>I'd like to show off how chat templating works!<|end|>
<|assistant|>
```

This tells the model:
1. The conversation history (user/assistant turns).
2. The model’s turn to generate a response (`<|assistant|>` at the end).

---

### **Key Notes**:
- **Role Separation**: Tags like `<|user|>` and `<|assistant|>` help the model distinguish speakers.
- **Special Tokens**: Models often use unique tokens (e.g., `<|end|>`) to mark message boundaries.
- **Flexibility**: Templates vary by model (e.g., OpenAI uses `{"role": "user", "content": "..."}` instead of tags).

---

### **Why This Matters**:
- **Consistency**: Ensures the model understands dialogue structure.
- **Context Preservation**: Maintains the flow of multi-turn conversations.
- **Alignment**: Matches the format the model was trained on for better performance.

Want to dive deeper or see a specific framework’s implementation (e.g., OpenAI, Llama, Mistral)? Let me know! 😊<|end▁of▁sentence|>
``````

Use the following to run it
```bash
torchrun --nproc_per_node=8 --nnodes=2 --node_rank=0|1 --rdzv-id an_id --rdzv-backend c10d --rdzv-endpoint master_addr:master_port run_deepseek_r1.py
```

If you have:
```bash
[rank0]: ncclInternalError: Internal check failed.
[rank0]: Last error:
[rank0]: Bootstrap : no socket interface found
```
error, it means NCCL was probably not loaded.


## DeepseekV3Config

[[autodoc]] DeepseekV3Config

## DeepseekV3Model

[[autodoc]] DeepseekV3Model
- forward

## DeepseekV3ForCausalLM

[[autodoc]] DeepseekV3ForCausalLM
- forward
16 changes: 16 additions & 0 deletions src/transformers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -345,6 +345,7 @@
],
"models.deberta_v2": ["DebertaV2Config"],
"models.decision_transformer": ["DecisionTransformerConfig"],
"models.deepseek_v3": ["DeepseekV3Config"],
"models.deformable_detr": ["DeformableDetrConfig"],
"models.deit": ["DeiTConfig"],
"models.deprecated": [],
Expand Down Expand Up @@ -2023,6 +2024,13 @@
"DecisionTransformerPreTrainedModel",
]
)
_import_structure["models.deepseek_v3"].extend(
[
"DeepseekV3ForCausalLM",
"DeepseekV3Model",
"DeepseekV3PreTrainedModel",
]
)
_import_structure["models.deformable_detr"].extend(
[
"DeformableDetrForObjectDetection",
Expand Down Expand Up @@ -5546,6 +5554,9 @@
from .models.decision_transformer import (
DecisionTransformerConfig,
)
from .models.deepseek_v3 import (
DeepseekV3Config,
)
from .models.deformable_detr import (
DeformableDetrConfig,
)
Expand Down Expand Up @@ -7175,6 +7186,11 @@
DecisionTransformerModel,
DecisionTransformerPreTrainedModel,
)
from .models.deepseek_v3 import (
DeepseekV3ForCausalLM,
DeepseekV3Model,
DeepseekV3PreTrainedModel,
)
from .models.deformable_detr import (
DeformableDetrForObjectDetection,
DeformableDetrModel,
Expand Down
32 changes: 19 additions & 13 deletions src/transformers/integrations/finegrained_fp8.py
Original file line number Diff line number Diff line change
Expand Up @@ -291,7 +291,7 @@ def w8a8_block_fp8_matmul_compile(
return output.to(output_dtype)


class FP8Linear(nn.Module):
class FP8Linear(nn.Linear):
dtype = torch.float8_e4m3fn

def __init__(
Expand All @@ -304,17 +304,20 @@ def __init__(
device=None,
activation_scheme="dynamic",
):
super().__init__()
super().__init__(in_features, out_features)
self.in_features = in_features
self.out_features = out_features

self.register_buffer("weight", torch.empty(out_features, in_features, dtype=FP8Linear.dtype, device=device))
self.weight = torch.nn.Parameter(torch.empty(out_features, in_features, dtype=FP8Linear.dtype, device=device))

scale_out_features = (out_features + block_size[0] - 1) // block_size[0]
scale_in_features = (in_features + block_size[1] - 1) // block_size[1]
self.register_buffer(
"weight_scale_inv", torch.empty(scale_out_features, scale_in_features, dtype=torch.float32, device=device)
)
if self.weight.element_size() == 1:
scale_out_features = (out_features + block_size[0] - 1) // block_size[0]
scale_in_features = (in_features + block_size[1] - 1) // block_size[1]
self.weight_scale_inv = nn.Parameter(
torch.empty(scale_out_features, scale_in_features, dtype=torch.float32, device=device)
)
else:
self.register_parameter("weight_scale_inv", None)

self.block_size = block_size

Expand All @@ -330,11 +333,11 @@ def forward(self, input: torch.Tensor) -> torch.Tensor:
return F.linear(input, self.weight, self.bias)
else:
# Context manager used to switch among the available cuda devices
with torch.cuda.device(input.device):
qinput, scale = act_quant(input, self.block_size[1])
# with torch.cuda.device(input.device):
qinput, scale = act_quant(input, self.block_size[1])
# Blocks the CPU until all CUDA operations on the specified device are complete. It is used to ensure that the results of the
# preceding operations are ready before proceeding
torch.cuda.synchronize(device=input.device)
# torch.cuda.synchronize(device=self.weight.device)
with torch.cuda.device(input.device):
output = w8a8_block_fp8_matmul_triton(
qinput,
Expand All @@ -344,14 +347,15 @@ def forward(self, input: torch.Tensor) -> torch.Tensor:
self.block_size,
output_dtype=input.dtype,
)
torch.cuda.synchronize(device=input.device)
torch.cuda.synchronize()
if self.bias is not None:
output = output + self.bias
return output.to(dtype=input.dtype)


def _replace_with_fp8_linear(
model,
tp_plan=None,
modules_to_not_convert=None,
current_key_name=None,
quantization_config=None,
Expand All @@ -378,10 +382,12 @@ def _replace_with_fp8_linear(
block_size=quantization_config.weight_block_size,
)
has_been_replaced = True
# when changing a layer the TP PLAN for that layer should be updated. TODO

if len(list(module.children())) > 0:
_, has_been_replaced = _replace_with_fp8_linear(
module,
tp_plan,
modules_to_not_convert,
current_key_name,
quantization_config,
Expand All @@ -404,9 +410,9 @@ def replace_with_fp8_linear(
if quantization_config.modules_to_not_convert is not None:
modules_to_not_convert.extend(quantization_config.modules_to_not_convert)
modules_to_not_convert = list(set(modules_to_not_convert))

model, has_been_replaced = _replace_with_fp8_linear(
model,
tp_plan=model._tp_plan,
modules_to_not_convert=modules_to_not_convert,
quantization_config=quantization_config,
)
Expand Down
12 changes: 9 additions & 3 deletions src/transformers/integrations/tensor_parallel.py
Original file line number Diff line number Diff line change
Expand Up @@ -231,8 +231,8 @@ def prepare_module_tp(self, module: nn.Module, device_mesh) -> nn.Module:
distribute_module(
module,
device_mesh,
partial(self._prepare_input_fn),
partial(self._prepare_output_fn),
partial(self._prepare_input_fn, None, None),
partial(self._prepare_output_fn, None, None),
)


Expand Down Expand Up @@ -484,7 +484,12 @@ def __init__(self):
# 1. We add hooks to the layer being loaded:
if current_module_plan is not None:
tp_layer = translate_to_torch_parallel_style(current_module_plan)
tp_layer.prepare_module_tp(module, device_mesh)
try:
tp_layer.prepare_module_tp(module, device_mesh)
except NotImplementedError as e:
print(
f"Trying to prepare {layer_name}, but it's not supported. Corresponding module: {module} Fix it's TP plan: {e}"
)

# 2. We add hooks to the parrent module if needed
if "." in layer_name:
Expand Down Expand Up @@ -531,6 +536,7 @@ def shard_and_distribute_module(
param, empty_param, param_type, param_casting_dtype, is_contiguous, rank, device_mesh
)
else:
# TODO log no plan modules in set
param = param[...].to(param_casting_dtype)
if is_contiguous:
param = param.contiguous()
Expand Down
Loading