Conversation

@khatwanimohit (Collaborator) commented Jan 15, 2026

Description

This work was done in collaboration with @NicoGrande

This PR introduces attention data parallelism (attn_dp) to improve memory efficiency when the number of KV heads is smaller than the tensor-parallelism degree. The attention DP degree is auto-calculated from the ratio of tensor parallelism to KV heads, so optimal sharding requires no manual configuration.
New logical axes (attn_activation_length, attn_activation_embed) and corresponding sharding rules have been added to support attention-specific tensor partitioning, separate from the rest of the model.
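
For illustration, here is a minimal Python sketch of the auto-calculation rule described above (auto_attn_dp is a hypothetical helper written for this description; the real logic is the config change shown in the review thread below):

def auto_attn_dp(num_kv_heads: int, ici_tensor_parallelism: int) -> tuple[int, int]:
  """Return (attn_dp degree, adjusted tensor-parallelism degree).

  When num_kv_heads < tensor parallelism, the surplus TP degree is moved into an
  attention data-parallel axis so KV heads are not duplicated across devices.
  """
  tp = ici_tensor_parallelism if ici_tensor_parallelism > 0 else 1
  if 0 < num_kv_heads < tp:
    return tp // num_kv_heads, num_kv_heads
  return 1, tp

# Llama3.1-8B run below: 2 KV heads, TP=8 -> attn_dp=4, model=2
assert auto_attn_dp(2, 8) == (4, 2)
# Qwen3-30B-A3B run below: 4 KV heads, TP=8 -> attn_dp=2, model=4
assert auto_attn_dp(4, 8) == (2, 4)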

Notice 1: Once all tests pass, the "pull ready" label will automatically be assigned.
This label is used for administrative purposes. Please do not add it manually.

Notice 2: For external contributions, our settings currently require an approval from a MaxText maintainer to trigger CI tests.

Tests

Ran Llama3.1-8B with num_kv_heads reduced to 2 and ici_tensor_parallelism=8. This auto-calculates attn_dp=4 (and reduces tensor parallelism to 2, matching the KV heads), as shown in the mesh below.
vllm mesh: mesh=Mesh('data': 1, 'attn_dp': 4, 'expert': 1, 'model': 2, axis_types=(Auto, Auto, Auto, Auto))

Attention shardings:

llama2.py:158] bfloat16[96,1,4096]............................................................. (None, None, ('model', 'attn_dp')).
attentions.py:1014] bfloat16[96,1,4096]............................................................. (None, 'attn_dp', 'model').
attentions.py:1078] bfloat16[96,1,32,128]........................................................... (None, 'attn_dp', 'model', None).
attentions.py:1079] bfloat16[96,1,8,128]............................................................ (None, 'attn_dp', 'model', None).
linears.py:527] bfloat16[96,1,14336]............................................................ (None, None, ('model', 'attn_dp')).

Weight shardings:

maxtext_utils.py:1197] decoder/decoder_norm/scale/value................................................ bfloat16[4096] (('model', 'attn_dp'),)
maxtext_utils.py:1197] decoder/layers_0/mlp/wi_0/kernel/value.......................................... bfloat16[4096,14336] (None, ('model', 'attn_dp'))
maxtext_utils.py:1197] decoder/layers_0/mlp/wi_1/kernel/value.......................................... bfloat16[4096,14336] (None, ('model', 'attn_dp'))
maxtext_utils.py:1197] decoder/layers_0/mlp/wo/kernel/value............................................ bfloat16[14336,4096] (('model', 'attn_dp'), None)
maxtext_utils.py:1197] decoder/layers_0/post_self_attention_layer_norm/scale/value..................... bfloat16[4096] (('model', 'attn_dp'),)
maxtext_utils.py:1197] decoder/layers_0/pre_self_attention_layer_norm/scale/value...................... bfloat16[4096] (('model', 'attn_dp'),)
maxtext_utils.py:1197] decoder/layers_0/self_attention/key/kernel/value................................ bfloat16[4096,8,128] (None, 'model', None)
maxtext_utils.py:1197] decoder/layers_0/self_attention/out/kernel/value................................ bfloat16[32,128,4096] ('model', None, None)
maxtext_utils.py:1197] decoder/layers_0/self_attention/query/kernel/value.............................. bfloat16[4096,32,128] (None, 'model', None)
maxtext_utils.py:1197] decoder/layers_0/self_attention/value/kernel/value.............................. bfloat16[4096,8,128] (None, 'model', None)
decoder/logits_dense/kernel/value............................................... bfloat16[4096,128256] (None, ('model', 'attn_dp'))
token_embedder/embedding/value.................................................. bfloat16[128256,4096] (('model', 'attn_dp'), None)
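
To give a rough picture of what the combined ('model', 'attn_dp') axis in these logs means, here is a minimal JAX sketch (an illustration, not the adapter code; it keeps only the two non-trivial mesh axes and assumes 8 devices, e.g. XLA_FLAGS=--xla_force_host_platform_device_count=8 on CPU):

import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# Mesh shaped like the logged vLLM mesh: attn_dp=4, model=2.
devices = np.array(jax.devices()[:8]).reshape(4, 2)
mesh = Mesh(devices, axis_names=("attn_dp", "model"))

# mlp/wi_0/kernel bfloat16[4096, 14336] with spec (None, ('model', 'attn_dp')):
# the 14336 dimension is split over model * attn_dp = 8 devices.
wi_0 = jax.device_put(
    jnp.zeros((4096, 14336), dtype=jnp.bfloat16),
    NamedSharding(mesh, P(None, ("model", "attn_dp"))),
)
print(wi_0.sharding.shard_shape(wi_0.shape))  # (4096, 1792), i.e. 1792 columns per device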

Ran Qwen3-30B-A3B (MoE) with ici_tensor_parallelism=8; it has 4 KV heads, so the mesh resolves to model=4 and attn_dp=2.
Command:

NEW_MODEL_DESIGN=1 python3 -m MaxText.vllm_decode     --model_name qwen3-30b-a3b  --hf_model_name Qwen/Qwen3-30B-A3B     --hf_config_path src/MaxText/integration/vllm/maxtext_vllm_adapter  --ici_tensor_parallelism 8  --gpu_memory_utilization 0.5  --prompt "Suggest some famous landmarks in London" --debug_sharding true --enable_dp_attention true --load_parameters_path gs://parambole-qwen3-moe-verification/unscanned/qwen3-30b-a3b-thinking-2507/14_08_2025/0/items 2>&1 | tee attn_output

Output: https://paste.googleplex.com/5980793354715136

Checklist

Before submitting this PR, please make sure (put X in square brackets):

  • I have performed a self-review of my code. For an optional AI review, add the gemini-review label.
  • I have necessary comments in my code, particularly in hard-to-understand areas.
  • I have run end-to-end tests and provided workload links above if applicable.
  • I have made or will make corresponding changes to the doc if needed, including adding new documentation pages to the relevant Table of Contents (toctree directive) as explained in our documentation.

codecov bot commented Jan 15, 2026

Codecov Report

❌ Patch coverage is 19.23077% with 21 lines in your changes missing coverage. Please review.

Files with missing lines                Patch %   Lines
src/MaxText/maxtext_utils.py            12.50%    7 Missing ⚠️
src/MaxText/layers/moe.py                0.00%    5 Missing and 1 partial ⚠️
src/MaxText/vllm_decode.py               0.00%    5 Missing ⚠️
src/MaxText/model_creation_utils.py     25.00%    3 Missing ⚠️


# Parallelism
shard_mode: "auto" # can be either auto or explicit
-mesh_axes: ['data', 'stage', 'fsdp', 'fsdp_transpose', 'sequence', 'context', 'context_autoregressive', 'tensor', 'tensor_transpose', 'tensor_sequence', 'expert', 'autoregressive']
+mesh_axes: ['data', 'stage', 'attn_dp', 'fsdp', 'fsdp_transpose', 'sequence', 'context', 'context_autoregressive', 'tensor', 'tensor_transpose', 'tensor_sequence', 'expert', 'autoregressive']

Collaborator:
I don't think we need attn_dp here since we are not using it for training.

Comment on lines +2179 to +2221
# ATTENTION DATA PARALLELISM AUTO-CALCULATION
# When num_kv_heads < tensor_parallelism, KV heads would be duplicated across devices,
# wasting KV cache memory. Use attention DP to reduce per-device num_kv_heads instead.
# enable_attn_dp = getattr(self, "enable_attn_dp", False)
if self.ici_tensor_parallelism > self.num_kv_heads:
  # Check if user explicitly set attn_dp_parallelism (not default value of 1)
  user_set_ici_attn_dp = self.ici_attn_dp_parallelism != 1
  user_set_dcn_attn_dp = self.dcn_attn_dp_parallelism != 1

  if user_set_ici_attn_dp or user_set_dcn_attn_dp:
    raise ValueError(
        f"attn_dp_parallelism is auto-calculated and should not be set explicitly. "
        f"Found ici_attn_dp_parallelism={self.ici_attn_dp_parallelism}, "
        f"dcn_attn_dp_parallelism={self.dcn_attn_dp_parallelism}. "
        f"Please remove these settings or set enable_attn_dp=False to use manual values."
    )

  # Auto-calculate attn_dp based on KV heads and tensor parallelism
  num_kv_heads = self.num_kv_heads
  tp_size = self.ici_tensor_parallelism if self.ici_tensor_parallelism > 0 else 1

  if 0 < num_kv_heads < tp_size:
    calculated_attn_dp = tp_size // num_kv_heads
    logger.info(
        "Auto-calculating attention DP: num_kv_heads=%s < ici_tensor_parallelism=%s. "
        "Setting ici_attn_dp_parallelism=%s and "
        "ici_tensor_parallelism=%s to avoid KV head duplication.",
        num_kv_heads,
        tp_size,
        calculated_attn_dp,
        num_kv_heads,
    )
    self.ici_attn_dp_parallelism = calculated_attn_dp
    self.ici_tensor_parallelism = num_kv_heads
  else:
    logger.info(
        "Attention DP not needed: num_kv_heads=%s >= ici_tensor_parallelism=%s. "
        "Keeping ici_attn_dp_parallelism=1.",
        num_kv_heads,
        tp_size,
    )
    self.ici_attn_dp_parallelism = 1

Collaborator:
This logic sets ici_attn_dp_parallelism similarly to how it's done in tpu-inference. Since the mesh itself is being initialized in tpu-inference, do we need this code in MaxText? I don't think we will have attn_dp in non-vLLM code paths, right?

Collaborator:
I think this is also true of other changes in types.py.

  if self.config.attention == "vllm_rpa":
    # vLLM uses 'model' as the tensor parallelism axis name
-   self._tensor_parallelism_name = "model"
+   self._tensor_parallelism_name = ("model", "attn_dp")

Collaborator:
Does this also work for the case where attn_dp is 1?
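
As a standalone illustration of the attn_dp=1 case (a JAX sketch, not the adapter code): when the attn_dp mesh axis has size 1, the tuple spec ('model', 'attn_dp') places data exactly like plain 'model', so the combined name should degenerate gracefully.

import jax
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# Degenerate mesh: attn_dp collapsed to 1, every device on the 'model' axis.
n = len(jax.devices())
mesh = Mesh(np.array(jax.devices()).reshape(1, n), axis_names=("attn_dp", "model"))

x = np.zeros((n, 4096), dtype=np.float32)
tuple_spec = jax.device_put(x, NamedSharding(mesh, P(("model", "attn_dp"), None)))
plain_spec = jax.device_put(x, NamedSharding(mesh, P("model", None)))

# With attn_dp == 1, the two shardings are equivalent.
assert tuple_spec.sharding.is_equivalent_to(plain_spec.sharding, x.ndim)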
