
Conversation

@RissyRan (Collaborator) commented Jan 13, 2026

Description

Background

DeepSeek V3.2 differs from DeepSeek V3 solely in the attention mechanism, aiming for efficiency in long-context scenarios. While DeepSeek V3 uses Multi-head Latent Attention (MLA), DeepSeek V3.2 uses DeepSeek Sparse Attention (DSA). DSA augments MLA with two components:

  • Indexer: parametric; computes a qk product to produce index scores
  • Top-k token selection: non-parametric; selects the top-k key/value tokens for each query, introducing sparsity into the qkv attention (see the sketch below)
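
For intuition, here is a minimal JAX sketch of the two components under simplified assumptions (the names index_q, index_k, head_weights and the weighted-ReLU scoring formula are illustrative, not the PR's actual API):

import jax
import jax.numpy as jnp

def indexer_scores_and_topk(index_q, index_k, head_weights, k):
  """Simplified index scoring + top-k selection (illustrative only).

  index_q:      [B, S, H, D]  per-head indexer queries
  index_k:      [B, T, D]     indexer keys, one per token
  head_weights: [B, S, H]     per-query weights over indexer heads
  k:            number of tokens each query may attend to (index_topk)
  """
  # Parametric part: q·k logits per indexer head, combined by a weighted sum over heads.
  logits = jnp.einsum("bshd,btd->bsht", index_q, index_k)                  # [B, S, H, T]
  scores = jnp.einsum("bsht,bsh->bst", jax.nn.relu(logits), head_weights)  # [B, S, T]
  # Non-parametric part: keep only the k highest-scoring key/value tokens per query.
  _, topk_indices = jax.lax.top_k(scores, k)                               # [B, S, k]
  return scores, topk_indices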

What this PR does

1. Naive implementation of DeepSeek Sparse Attention (DSA)

  • Indexer:

    • qk product: currently implemented as a plain dot product to obtain index scores. To be optimized.
    • (minor) RoPE: the indexer applies partial RoPE to q and k based on the YaRN extension. It uses the same YaRN frequencies as MLA, but with a concatenated layout rather than an interleaved layout.
    • Based on the index scores, compute the top-k indices and the index mask.
  • Top-k selection for qkv attention:

    • This is currently implemented inside dot-product attention by adding the index mask to the regular attention mask (a minimal sketch follows this list). To be optimized.
  • Training only (no prefill / decode).

  • See changes in attention_mla.py and attention_op.py.
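
As a rough illustration of how the index mask is folded into regular attention (a sketch under assumed shapes, not the code in attention_op.py; DEFAULT_MASK_VALUE stands for the large negative constant used in the attention mask):

import jax
import jax.numpy as jnp

DEFAULT_MASK_VALUE = -0.7 * float(jnp.finfo(jnp.float32).max)  # illustrative constant

def sparse_attention_naive(q, k, v, topk_indices, causal_mask):
  """q: [B, S, H, D]; k, v: [B, T, H, D]; topk_indices: [B, S, K]; causal_mask: [B, S, T] additive."""
  seq_t = k.shape[1]
  # Index mask: 0.0 where a key token is in the query's top-k set, large negative otherwise.
  is_topk = (jnp.arange(seq_t) == topk_indices[..., None]).any(axis=-2)  # [B, S, T]
  index_mask = jnp.where(is_topk, 0.0, DEFAULT_MASK_VALUE)
  # Regular dot-product attention with the index mask simply added to the causal mask.
  logits = jnp.einsum("bshd,bthd->bhst", q, k) / jnp.sqrt(q.shape[-1])
  logits = logits + (causal_mask + index_mask)[:, None, :, :]
  weights = jax.nn.softmax(logits, axis=-1)
  return jnp.einsum("bhst,bthd->bshd", weights, v)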

2. Onboard deepseek3.2-671b config

  • deepseek3.2-671b.yml
  • DeepSeek v3.2 vs. v3 HF config diff: additional configs for the indexer:
"index_head_dim": 128, "index_n_heads": 64, "index_topk": 2048,
  • Number of parameters: (1) As with v3, the HF safetensors of v3.2 contain an extra layer for MTP, which we omit. (2) Note that the indexer contains extra parameters. (3) By counting, v3 has 671,026,419,200 (671.03B) parameters and v3.2 has 671,877,944,064 (671.88B), a difference of 851,524,864 (≈0.85B) coming from the indexer.

3. Unit test: ahead-of-time train compile for deepseek3.2-671b

4. Unit test: compare outputs against the torch reference code for the Indexer and MLA

  • check_deepseek32_vs_reference.py
  • The original torch reference can only run on GPU, due to quantization and fp8 kernels (act_quant, fp8_gemm, fp8_index): https://github.com/deepseek-ai/DeepSeek-V3.2-Exp/blob/main/inference/model.py
  • We adapted the reference to run on CPU:
    • Remove quantization and use float32 for dtype and weight_dtype
    • Replace the fp8 kernels with naive dot products (see the stand-ins after this list)
    • (minor) Replace fast_hadamard_transform.hadamard_transform with F.linear
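
For illustration, the stand-ins below sketch the kind of replacements involved (hypothetical helper names and shapes; the adapted reference may organize this differently):

import torch

def quantize_noop(x):
  # Stand-in for act_quant on CPU: no fp8 quantization, just float32 and a unit scale.
  return x.float(), torch.ones_like(x[..., :1])

def gemm_fp32(a, b):
  # Stand-in for the fp8 GEMM kernel: a plain float32 matmul, ignoring quantization scales.
  return a.float() @ b.float().T

def index_score_fp32(q, k):
  # Stand-in for the fp8 index kernel: a naive per-head dot product of indexer queries and keys.
  # Assumed shapes: q [B, S, H, D], k [B, T, D].
  return torch.einsum("bshd,btd->bsht", q.float(), k.float())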

Reference

Future work

  • verify end-to-end training logits for deepseek3.2
  • more efficient implementation of DSA

Tests

Unit test against torch code (adapted from reference): indexer, MLA

python3 -m pytest -v --pyargs tests.check_deepseek32_vs_reference -rP -s

Unit test for train compile

python3 -m pytest -v --pyargs tests.train_compile_test -rP -s -k "test_deepseek32"

Checklist

Before submitting this PR, please make sure (put X in square brackets):

  • I have performed a self-review of my code. For an optional AI review, add the gemini-review label.
  • I have necessary comments in my code, particularly in hard-to-understand areas.
  • I have run end-to-end tests and provided workload links above if applicable.
  • I have made or will make corresponding changes to the doc if needed, including adding new documentation pages to the relevant Table of Contents (toctree directive) as explained in our documentation.

@codecov

codecov bot commented Jan 13, 2026

Codecov Report

❌ Patch coverage is 51.25000% with 39 lines in your changes missing coverage. Please review.

Files with missing lines            | Patch % | Lines
src/MaxText/layers/attention_mla.py | 53.94%  | 31 Missing and 4 partials ⚠️
src/MaxText/layers/attention_op.py  | 0.00%   | 2 Missing and 2 partials ⚠️


@RissyRan (Collaborator, Author) left a comment:


Thanks for the change! I took a look at the indexer part, and overall it looks good for functionality. It also has an indexer logit kernel for performance; I will take a look there.

I will take a look at the MLA part shortly.


# Internal Indexer Cache (distinct from main MLA KV Cache)
# Shape: [Batch, MaxLen, HeadDim]
self.k_cache = nnx.Variable(jnp.zeros((config.max_target_length, self.head_dim), dtype=jnp.bfloat16))
Collaborator Author:

Why is this 2D, while the comment notes [Batch, MaxLen, HeadDim]? We could remove this if it is not needed for training so far.

Collaborator:

I will remove the k_cache since we are focusing on training for now.

k = self._apply_partial_rope(k, positions)
k = k.squeeze(2) # Back to [B, S, D]

# 3. Cache Update (Functional NNX update)
Collaborator Author:

Shall we also apply rotate_activation from the reference? It seems related to the quantization strategy, and there is zero accuracy change without it.

@shuningjin (Collaborator) commented Jan 16, 2026:

What rotate_activation does: a Hadamard transform.

  • rotate_activation(q) @ rotate_activation(k).T gives the same result as q @ k.T (see below).
  • In the original implementation, they apply rotate_activation before fp8 quantization for stability.

We don't need it:

  • We are not using quantization for now, and rotate_activation does not affect the result, so we will omit it from our implementation to avoid additional computation.
  • The unit test passes: MaxText without rotate_activation MATCHES the reference with rotate_activation.

Equivalence:

Property of a Hadamard matrix: for an $n\times n$ Hadamard matrix $H$, $H H^\top = n I_n$. (By transposing both sides, we also have $H^\top H = n I_n$.)

  • Let $\tilde{H} = H/ \sqrt{n}$. We have $\tilde{H} \tilde{H}^\top = \tilde{H}^\top \tilde{H} = I_n$. That is, the scaled Hadamard matrix $\tilde{H}$ is orthogonal.

Property of the Hadamard transform: mathematically, it is equivalent to multiplying the input vector by a Hadamard matrix with some scale.

  • $(\tilde{H} k)^\top (\tilde{H} q) = k^\top \tilde{H}^\top \tilde{H} q = k^\top I_n q = k^\top q$. That is, the Hadamard transform does not change the qk product.

Code example: product of scaled Hadamard matrix is identity

import torch
import scipy.linalg
dim = 4
# Hadamard matrix, dim x dim
H = torch.tensor(scipy.linalg.hadamard(dim), dtype=torch.float32) 
print(H)
# Scaled Hadamard matrix, orthogonal matrix
scale = dim ** -0.5  # unitary scale
H_scaled = H * scale 
# product is identity
print(H_scaled @ H_scaled.T)
print(H_scaled.T @ H_scaled)

Code example: the qk product with or without rotate_activation (Hadamard transform) is the same

import torch
import torch.nn.functional as F
import scipy.linalg

def rotate_activation(x: torch.Tensor) -> torch.Tensor:
  hidden_size = x.size(-1)
  return F.linear(x, torch.tensor(scipy.linalg.hadamard(hidden_size), dtype=x.dtype)) * hidden_size**-0.5

b, s, h, d = 1, 2, 1, 4
q = torch.randn(b, s, h, d)
k = torch.randn(b, s, d)
out1 = torch.einsum("bshd, btd -> bsth", q, k)
out2 = torch.einsum("bshd, btd -> bsth", rotate_activation(q), rotate_activation(k))
torch.testing.assert_close(out1, out2)

Reference:

self.k_cache.value = updated_cache

# Active Keys: [B, TotalLen, D]
k_active = jax.lax.dynamic_slice(updated_cache, (0, 0, 0), (bsz, end_pos, self.head_dim))
Collaborator Author:

Is this the same as slicing start_pos:end_pos for end_pos?

Collaborator:

I will remove the k_cache since we are focusing on training for now.

seq_idx = jnp.arange(seqlen)[None, :, None]

# JAX scatter update
bias_mask = bias_mask.at[batch_idx, seq_idx, topk_indices].set(0.0)
Collaborator Author:

I recall this set will be very inefficient. We will need to consider other operations later, leveraging matmul.

Collaborator:

Changed from set to jnp.where as you suggested in the other comment. We can explore other operations in the future.

# We check a config flag to see if we are in Sparse/DeepSeek3.2 mode
self.use_sparse_indexer = getattr(config, "use_sparse_indexer", False)
if self.use_sparse_indexer:
indexer_rope = copy.copy(self.rotary_embedding)
Collaborator Author:

Why do we need a copy of rotary_embedding?

@shuningjin (Collaborator) commented Jan 13, 2026:

Both MLA and the Indexer use RoPE with the YaRN extension, with the same YaRN frequencies. They differ in the input layout: MLA YaRN uses interleave=true and the indexer uses interleave=false. Making a copy keeps the two code paths isolated.

To clarify the use of YaRN, I have added comments in attention_mla and YarnRotaryEmbedding.
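
For reference, a small sketch of the layout difference (illustrative code, not MaxText's YarnRotaryEmbedding): both paths use the same cos/sin tables of shape [..., d/2]; the interleaved layout rotates adjacent pairs (x0, x1), (x2, x3), ..., while the concatenated layout rotates pairs (x_i, x_{i + d/2}).

import jax.numpy as jnp

def rope_interleaved(x, cos, sin):
  # Rotate adjacent pairs (x0, x1), (x2, x3), ... — the interleave=true layout.
  x1, x2 = x[..., 0::2], x[..., 1::2]
  out = jnp.stack([x1 * cos - x2 * sin, x1 * sin + x2 * cos], axis=-1)
  return out.reshape(x.shape)

def rope_concatenated(x, cos, sin):
  # Rotate pairs (x_i, x_{i + d/2}) — the interleave=false (half-split) layout.
  half = x.shape[-1] // 2
  x1, x2 = x[..., :half], x[..., half:]
  return jnp.concatenate([x1 * cos - x2 * sin, x1 * sin + x2 * cos], axis=-1)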

out = unnormalized_out / (exp_sum + 1e-9) if exp_sum is not None else unnormalized_out
else:
out = self.attention_op(query, key, value, decoder_segment_ids, model_mode, cached_values)
# ds3.2, MHA mode for train / prefill; TODO: MQA mode for decode (mathematically equivalent but faster)?
Collaborator Author:

In MQA, the size of the Key-Value (KV) cache is reduced by a factor equal to the number of heads.

You may have this diff already, but just FYI. V3 MLA vs. V3.2: https://diff.googleplex.com/#key=3JSmf20vQG8U

Collaborator:

Thank you for the diff, this is really helpful!

I will leave inference-related code to future efforts.

batch_idx = jnp.arange(bsz)[:, None, None]
seq_idx = jnp.arange(seqlen)[None, :, None]
# JAX scatter update
index_mask = index_mask.at[batch_idx, seq_idx, topk_indices].set(0.0)
Collaborator Author:

Later: we could update this to jnp.where in this case. set() is a scatter operation, which is often less efficient.

# Assuming:
# topk_indices shape: [B, S, K]
# T is the target dimension size (index_score.shape[-1])
import jax.numpy as jnp

def get_mask_efficient(topk_indices, T, default_value):
    # 1. Create a range [0, 1, ..., T-1]
    # 2. Broadcast compare against [B, S, K] to get [B, S, K, T]
    # 3. Use .any() to see if a T-index is present in any of the K slots
    is_topk = (jnp.arange(T) == topk_indices[..., None]).any(axis=-2)

    # 4. Use where to select between 0.0 and the mask value
    return jnp.where(is_topk, 0.0, default_value)

@shuningjin (Collaborator) commented Jan 16, 2026:

To generate the mask from the top-k indices, I tested both the .at[].set() approach (v1) and the jnp.where approach (v2). On CPU, v1 is better. On TPU, v2 is better for both speed and memory. Here is the result. So I will go with v2. Thanks for your great suggestion!
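
For context, a sketch of the kind of micro-benchmark used for such a comparison (shapes, constants, and the timing harness are illustrative, not the actual test):

import time
import jax
import jax.numpy as jnp

B, S, T, K = 4, 2048, 2048, 512  # illustrative shapes
NEG = -1e9                       # illustrative mask value

def mask_v1(topk_indices):
  # v1: scatter update via .at[].set
  mask = jnp.full((B, S, T), NEG)
  b = jnp.arange(B)[:, None, None]
  s = jnp.arange(S)[None, :, None]
  return mask.at[b, s, topk_indices].set(0.0)

def mask_v2(topk_indices):
  # v2: broadcast compare + jnp.where
  is_topk = (jnp.arange(T) == topk_indices[..., None]).any(axis=-2)
  return jnp.where(is_topk, 0.0, NEG)

topk_indices = jax.random.randint(jax.random.PRNGKey(0), (B, S, K), 0, T)
for name, fn in [("v1 set", jax.jit(mask_v1)), ("v2 where", jax.jit(mask_v2))]:
  fn(topk_indices).block_until_ready()  # compile once before timing
  t0 = time.perf_counter()
  for _ in range(10):
    out = fn(topk_indices)
  out.block_until_ready()
  print(name, (time.perf_counter() - t0) / 10)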

@shuningjin changed the title [DO NO MERGE] Draft for sparse DeepSeek3.2: Onboard sparse attention on Jan 17, 2026
@shuningjin self-assigned this on Jan 17, 2026
