DeepSeek3.2: Onboard sparse attention #2933
base: main
Conversation
RissyRan left a comment
Thanks for the change! I took a look at the indexer part, and overall it looks good functionally. It also has an indexer logits kernel for performance; I will take a look at that.
I will take a look at the MLA part shortly.
src/MaxText/layers/attention_mla.py
Outdated
```python
# Internal Indexer Cache (distinct from main MLA KV Cache)
# Shape: [Batch, MaxLen, HeadDim]
self.k_cache = nnx.Variable(jnp.zeros((config.max_target_length, self.head_dim), dtype=jnp.bfloat16))
```
Why is this 2D, while the note says [Batch, MaxLen, HeadDim]? We could remove this if it is not needed for training so far.
I will remove the k_cache since we are focusing on training for now.
src/MaxText/layers/attention_mla.py
Outdated
```python
k = self._apply_partial_rope(k, positions)
k = k.squeeze(2)  # Back to [B, S, D]

# 3. Cache Update (Functional NNX update)
```
Shall we also apply `rotate_activation` from the reference? It seems related to the quantization strategy, with zero accuracy change without it.
What `rotate_activation` does: Hadamard transform
- `rotate_activation(q) @ rotate_activation(k).T` gives the same result as `q @ k.T` (see below)
- in their original implementation, they apply `rotate_activation` before fp8 quantization for stability.

We don't need it
- we are not using quantization for now, and `rotate_activation` does not affect the result, so I will omit it from our implementation to avoid the additional computation
- the unit test passes: MaxText without `rotate_activation` MATCHES the reference with `rotate_activation`
Equivalence:

Property of the Hadamard matrix: for an $n \times n$ Hadamard matrix $H$, $H H^\top = H^\top H = n I_n$.
- Let $\tilde{H} = H/\sqrt{n}$. We have $\tilde{H} \tilde{H}^\top = \tilde{H}^\top \tilde{H} = I_n$. That is, the scaled Hadamard matrix $\tilde{H}$ is orthogonal.

Property of the Hadamard transform: mathematically, it is equivalent to multiplying the input vector by a Hadamard matrix with some scale.
- $(\tilde{H} k)^\top (\tilde{H} q) = k^\top \tilde{H}^\top \tilde{H} q = k^\top I_n q = k^\top q$. That is, the Hadamard transform does not change the qk product.
Code example: the product of the scaled Hadamard matrix with its transpose is the identity

```python
import torch
import scipy.linalg

dim = 4
# Hadamard matrix, dim x dim
H = torch.tensor(scipy.linalg.hadamard(dim), dtype=torch.float32)
print(H)
# Scaled Hadamard matrix, an orthogonal matrix
scale = dim ** -0.5  # unitary scale
H_scaled = H * scale
# product is the identity
print(H_scaled @ H_scaled.T)
print(H_scaled.T @ H_scaled)
```
Code example: the qk product with and without `rotate_activation` (Hadamard transform) is the same

```python
import torch
import torch.nn.functional as F
import scipy.linalg

def rotate_activation(x: torch.Tensor) -> torch.Tensor:
    hidden_size = x.size(-1)
    return F.linear(x, torch.tensor(scipy.linalg.hadamard(hidden_size), dtype=x.dtype)) * hidden_size**-0.5

b, s, h, d = 1, 2, 1, 4
q = torch.randn(b, s, h, d)
k = torch.randn(b, s, d)
out1 = torch.einsum("bshd, btd -> bsth", q, k)
out2 = torch.einsum("bshd, btd -> bsth", rotate_activation(q), rotate_activation(k))
torch.testing.assert_close(out1, out2)
```
Reference:
- https://en.wikipedia.org/wiki/Hadamard_matrix
- https://en.wikipedia.org/wiki/Hadamard_transform
- https://github.com/Dao-AILab/fast-hadamard-transform: `hadamard_transform(x, scale=1.0)` is equivalent to `F.linear(x, torch.tensor(scipy.linalg.hadamard(dim))) * scale`.
src/MaxText/layers/attention_mla.py
Outdated
```python
self.k_cache.value = updated_cache

# Active Keys: [B, TotalLen, D]
k_active = jax.lax.dynamic_slice(updated_cache, (0, 0, 0), (bsz, end_pos, self.head_dim))
```
Is this the same as slicing with `start_pos:end_pos` (i.e. up to `end_pos`)?
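For reference, a minimal sketch (hypothetical shapes, not the MaxText code) showing that `jax.lax.dynamic_slice` with zero start indices is equivalent to basic slicing up to `end_pos`:

```python
import jax
import jax.numpy as jnp

# Hypothetical shapes, purely to illustrate the question about the slice.
bsz, max_len, head_dim, end_pos = 2, 8, 4, 5
updated_cache = jax.random.normal(jax.random.PRNGKey(0), (bsz, max_len, head_dim))

# dynamic_slice starting at (0, 0, 0) ...
sliced_dynamic = jax.lax.dynamic_slice(updated_cache, (0, 0, 0), (bsz, end_pos, head_dim))
# ... is the same as basic slicing up to end_pos along the sequence axis.
sliced_basic = updated_cache[:, :end_pos, :]

assert jnp.array_equal(sliced_dynamic, sliced_basic)
```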
I will remove the k_cache since we are focusing on training for now.
src/MaxText/layers/attention_mla.py
Outdated
```python
seq_idx = jnp.arange(seqlen)[None, :, None]

# JAX scatter update
bias_mask = bias_mask.at[batch_idx, seq_idx, topk_indices].set(0.0)
```
I recall this `set` will be very inefficient. We will need to consider other operations later, leveraging matmul.
Changed from `set` to `jnp.where` as you suggested in the other comment. We can explore other operations in the future.
```python
# We check a config flag to see if we are in Sparse/DeepSeek3.2 mode
self.use_sparse_indexer = getattr(config, "use_sparse_indexer", False)
if self.use_sparse_indexer:
  indexer_rope = copy.copy(self.rotary_embedding)
```
Why do we need a copy of `rotary_embedding`?
Both MLA and the Indexer use RoPE with the YaRN extension and the same YaRN frequency. They differ in the input layout: MLA YaRN uses interleave=true and the indexer uses interleave=false. Making a copy keeps the two processes isolated.
To clarify the use of YaRN, I have added comments in attention_mla and YarnRotaryEmbedding.
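To illustrate the layout difference, here is a simplified sketch (assuming `cos`/`sin` of shape `[..., d/2]`; this is not the MaxText `YarnRotaryEmbedding` code): with interleave=true the rotation pairs adjacent features, while with interleave=false it pairs the first and second halves of the feature dimension.

```python
import jax.numpy as jnp

def rope_interleaved(x, cos, sin):
  # interleave=true: rotate pairs of adjacent features (x0, x1), (x2, x3), ...
  x_even, x_odd = x[..., 0::2], x[..., 1::2]
  rotated = jnp.stack(
      [x_even * cos - x_odd * sin, x_even * sin + x_odd * cos], axis=-1)
  return rotated.reshape(x.shape)

def rope_half_split(x, cos, sin):
  # interleave=false: rotate feature i together with feature i + d/2
  d_half = x.shape[-1] // 2
  x1, x2 = x[..., :d_half], x[..., d_half:]
  return jnp.concatenate(
      [x1 * cos - x2 * sin, x1 * sin + x2 * cos], axis=-1)
```

Both layouts use the same frequencies (cos/sin); only the pairing of feature dimensions differs, which is why the indexer can reuse the YaRN frequency setup but needs its own interleave setting.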
src/MaxText/layers/attention_mla.py
Outdated
```python
  out = unnormalized_out / (exp_sum + 1e-9) if exp_sum is not None else unnormalized_out
else:
  out = self.attention_op(query, key, value, decoder_segment_ids, model_mode, cached_values)
# ds3.2, MHA mode for train / prefill, TODO: MQA mode for decode (mathematically equivalent but faster)?
```
In MQA, the size of the Key-Value (KV) cache is reduced by a factor equal to the number of heads.
You may have this diff already, but just FYI. V3 MLA vs. V3.2: https://diff.googleplex.com/#key=3JSmf20vQG8U
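A rough back-of-the-envelope sketch of that saving (the sizes below are hypothetical, not the deepseek3.2-671b config):

```python
# Hypothetical sizes, just to illustrate the MQA cache saving.
batch, seq_len, num_heads, head_dim, bytes_per_elem = 1, 4096, 128, 128, 2

mha_kv_bytes = 2 * batch * seq_len * num_heads * head_dim * bytes_per_elem  # K and V per head
mqa_kv_bytes = 2 * batch * seq_len * 1 * head_dim * bytes_per_elem          # one shared KV head

print(f"MHA KV cache: {mha_kv_bytes / 2**20:.1f} MiB")
print(f"MQA KV cache: {mqa_kv_bytes / 2**20:.1f} MiB")
print(f"Reduction: {mha_kv_bytes // mqa_kv_bytes}x (= num_heads)")
```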
Thank you for the diff, this is really helpful!
I will leave inference-related code to future efforts.
src/MaxText/layers/attention_mla.py
Outdated
```python
batch_idx = jnp.arange(bsz)[:, None, None]
seq_idx = jnp.arange(seqlen)[None, :, None]
# JAX scatter update
index_mask = index_mask.at[batch_idx, seq_idx, topk_indices].set(0.0)
```
Later: we could update it to `jnp.where` in this case. `set()` is a scatter operation, which is often less efficient.
```python
import jax.numpy as jnp

# Assuming:
#   topk_indices shape: [B, S, K]
#   T is the target dimension size (index_score.shape[-1])
def get_mask_efficient(topk_indices, T, default_value):
  # 1. Create a range [0, 1, ..., T-1]
  # 2. Broadcast-compare against [B, S, K] to get [B, S, K, T]
  # 3. Use .any() to see if a T-index is present in any of the K slots
  is_topk = (jnp.arange(T) == topk_indices[..., None]).any(axis=-2)
  # 4. Use where to select between 0.0 and the mask value
  return jnp.where(is_topk, 0.0, default_value)
```
To generate the mask from top-k indices, I tested both the `.at[...].set` (v1) and `jnp.where` (v2) approaches. On CPU, v1 is better. On TPU, v2 is better for both speed and memory. So I will go with v2. Thanks for your great suggestion!
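For reference, a minimal sketch of how the two variants can be compared (the shapes and timing harness below are my own assumptions, not the actual benchmark that was run):

```python
import timeit
import jax
import jax.numpy as jnp

B, S, K, T = 4, 1024, 64, 1024
NEG_INF = -1e9
topk_indices = jax.random.randint(jax.random.PRNGKey(0), (B, S, K), 0, T)

@jax.jit
def mask_v1_set(topk_indices):
  # v1: scatter the selected positions to 0.0
  mask = jnp.full((B, S, T), NEG_INF)
  batch_idx = jnp.arange(B)[:, None, None]
  seq_idx = jnp.arange(S)[None, :, None]
  return mask.at[batch_idx, seq_idx, topk_indices].set(0.0)

@jax.jit
def mask_v2_where(topk_indices):
  # v2: build a boolean membership mask and select with jnp.where
  is_topk = (jnp.arange(T) == topk_indices[..., None]).any(axis=-2)
  return jnp.where(is_topk, 0.0, NEG_INF)

# Both variants produce the same mask.
assert jnp.array_equal(mask_v1_set(topk_indices), mask_v2_where(topk_indices))

# Crude timing: compile first, then time the steady state.
for fn in (mask_v1_set, mask_v2_where):
  fn(topk_indices).block_until_ready()
  t = timeit.timeit(lambda: fn(topk_indices).block_until_ready(), number=20)
  print(f"{fn.__name__}: {t / 20 * 1e3:.3f} ms")
```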
force-pushed from 94b73d8 to fe2ea34
Description
Background
DeepSeek V3.2 differs from DeepSeek V3 solely in the attention mechanism, aiming for efficiency in long-context scenarios. While DeepSeek V3 uses Multi-head Latent Attention (MLA), DeepSeek V3.2 uses DeepSeek Sparse Attention (DSA). DSA augments MLA with two components: an indexer that scores key tokens for each query, and top-k selection that restricts attention to the highest-scoring tokens.
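At a high level the flow is: the indexer scores every key for every query, top-k selection turns those scores into an additive mask, and attention then runs over the masked keys. Below is an illustrative sketch only, with simplified shapes and names that are not the MaxText implementation:

```python
import jax
import jax.numpy as jnp

def dsa_sketch(q, k, v, index_scores, top_k):
  """Illustrative only. q: [B, Sq, H, D]; k, v: [B, Sk, D]; index_scores: [B, Sq, Sk]."""
  # 1. Top-k selection: pick the highest-scoring key positions per query.
  _, topk_indices = jax.lax.top_k(index_scores, top_k)               # [B, Sq, K]
  # 2. Convert indices to an additive mask: 0 for selected keys, -inf otherwise.
  is_topk = (jnp.arange(k.shape[1]) == topk_indices[..., None]).any(axis=-2)
  mask = jnp.where(is_topk, 0.0, -1e9)                               # [B, Sq, Sk]
  # 3. Ordinary attention restricted to the selected keys.
  logits = jnp.einsum("bqhd,bkd->bqhk", q, k) / jnp.sqrt(q.shape[-1])
  probs = jax.nn.softmax(logits + mask[:, :, None, :], axis=-1)
  return jnp.einsum("bqhk,bkd->bqhd", probs, v)
```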
What this PR does
1. Naive implementation of DeepSeek Sparse Attention (DSA)
   - Indexer
   - Top-k selection for qkv attention
   - Training only (no prefill / decode)
   - See changes in `attention_mla.py`, `attention_op.py`
2. Onboard deepseek3.2-671b config
   - `deepseek3.2-671b.yml`: v3 has 671026419200 (671.03B) parameters and v3.2 has 671877944064 (671.88B) parameters.
3. Unit test: ahead-of-time train compile for deepseek3.2-671b
4. Unit test: compare output against torch code for Indexer and MLA
   - `check_deepseek32_vs_reference.py`

Reference
Future work
Tests
Unit test against torch code (adapted from reference): indexer, MLA
Unit test for train compile
Checklist
Before submitting this PR, please make sure (put X in square brackets):
gemini-review label.