Support attention data parallelism #2955
base: main
Conversation
Force-pushed: 9f95406 → 6f98db7, 01e8647 → 1cacf8f, and 1cacf8f → 863f779.
  # Parallelism
  shard_mode: "auto" # can be either auto or explicit
- mesh_axes: ['data', 'stage', 'fsdp', 'fsdp_transpose', 'sequence', 'context', 'context_autoregressive', 'tensor', 'tensor_transpose', 'tensor_sequence', 'expert', 'autoregressive']
+ mesh_axes: ['data', 'stage', 'attn_dp', 'fsdp', 'fsdp_transpose', 'sequence', 'context', 'context_autoregressive', 'tensor', 'tensor_transpose', 'tensor_sequence', 'expert', 'autoregressive']
I don't think we need attn_dp here since we are not using it for training.
# ATTENTION DATA PARALLELISM AUTO-CALCULATION
# When num_kv_heads < tensor_parallelism, KV heads would be duplicated across devices,
# wasting KV cache memory. Use attention DP to reduce per-device num_kv_heads instead.
# enable_attn_dp = getattr(self, "enable_attn_dp", False)
if self.ici_tensor_parallelism > self.num_kv_heads:
  # Check if user explicitly set attn_dp_parallelism (not default value of 1)
  user_set_ici_attn_dp = self.ici_attn_dp_parallelism != 1
  user_set_dcn_attn_dp = self.dcn_attn_dp_parallelism != 1

  if user_set_ici_attn_dp or user_set_dcn_attn_dp:
    raise ValueError(
        f"attn_dp_parallelism is auto-calculated and should not be set explicitly. "
        f"Found ici_attn_dp_parallelism={self.ici_attn_dp_parallelism}, "
        f"dcn_attn_dp_parallelism={self.dcn_attn_dp_parallelism}. "
        f"Please remove these settings or set enable_attn_dp=False to use manual values."
    )

  # Auto-calculate attn_dp based on KV heads and tensor parallelism
  num_kv_heads = self.num_kv_heads
  tp_size = self.ici_tensor_parallelism if self.ici_tensor_parallelism > 0 else 1

  if 0 < num_kv_heads < tp_size:
    calculated_attn_dp = tp_size // num_kv_heads
    logger.info(
        "Auto-calculating attention DP: num_kv_heads=%s < ici_tensor_parallelism=%s. "
        "Setting ici_attn_dp_parallelism=%s and "
        "ici_tensor_parallelism=%s to avoid KV head duplication.",
        num_kv_heads,
        tp_size,
        calculated_attn_dp,
        num_kv_heads,
    )
    self.ici_attn_dp_parallelism = calculated_attn_dp
    self.ici_tensor_parallelism = num_kv_heads
  else:
    logger.info(
        "Attention DP not needed: num_kv_heads=%s >= ici_tensor_parallelism=%s. "
        "Keeping ici_attn_dp_parallelism=1.",
        num_kv_heads,
        tp_size,
    )
    self.ici_attn_dp_parallelism = 1
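For intuition, here is a standalone sketch that reproduces just the arithmetic of this rule. The helper `auto_attn_dp` is hypothetical (not MaxText code); the config fields above are mirrored as plain arguments.

```python
def auto_attn_dp(num_kv_heads: int, ici_tensor_parallelism: int) -> tuple[int, int]:
  """Returns (attn_dp, tensor_parallelism) following the auto-calculation rule above."""
  tp_size = ici_tensor_parallelism if ici_tensor_parallelism > 0 else 1
  if 0 < num_kv_heads < tp_size:
    # Shrink TP to the number of KV heads and move the leftover factor into attention DP.
    return tp_size // num_kv_heads, num_kv_heads
  return 1, tp_size

# Llama3.1-8B-style setup from the Tests section: 2 KV heads, ici_tensor_parallelism=8.
print(auto_attn_dp(2, 8))  # (4, 2) -> attn_dp=4, tensor parallelism reduced to 2
# Qwen3-30B-MoE-style setup: 4 KV heads, ici_tensor_parallelism=8.
print(auto_attn_dp(4, 8))  # (2, 4)
```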
This logic sets ici_attn_dp_parallelism similar to how it's done in tpu-inference. Since the mesh itself is being initialized in tpu-inference, do we need this code in MaxText? I don't think we will have attn_dp in non-vLLM code paths, right?
I think this is also true of other changes in types.py
  if self.config.attention == "vllm_rpa":
    # vLLM uses 'model' as the tensor parallelism axis name
-   self._tensor_parallelism_name = "model"
+   self._tensor_parallelism_name = ("model", "attn_dp")
Does this also work for the case where attn_dp is 1?
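As a general JAX note (an illustrative sketch, not code from this PR): a PartitionSpec entry that names a tuple of mesh axes shards that dimension over the product of the axis sizes, so with attn_dp=1 the tuple ("model", "attn_dp") behaves the same as "model" alone. The mesh sizes below are made up for the example.

```python
import os
# Assumption: run on a CPU-only host; fake two devices so the mesh below can be built.
os.environ["XLA_FLAGS"] = "--xla_force_host_platform_device_count=2"

import jax
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# Illustrative degenerate mesh: model=2, attn_dp=1.
mesh = Mesh(np.array(jax.devices()).reshape(2, 1), axis_names=("model", "attn_dp"))

x = np.arange(8.0).reshape(2, 4)
# Dimension 0 is sharded model*attn_dp = 2*1 = 2 ways, same as sharding over "model" only.
y = jax.device_put(x, NamedSharding(mesh, P(("model", "attn_dp"), None)))
print(y.sharding)
```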
Description
This work was done in collaboration with @NicoGrande.
This PR introduces attention data parallelism (attn_dp) to optimize memory efficiency when the number of KV heads is less than tensor parallelism. The attention DP degree is auto-calculated based on the ratio of tensor parallelism to KV heads, ensuring optimal sharding without manual configuration.
New logical axes (attn_activation_length, attn_activation_embed) and corresponding sharding rules have been added to support attention-specific tensor partitioning separate from the rest of the model.
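To make the memory argument concrete, here is a rough back-of-the-envelope sketch. The helper name, the shapes (head_dim, seq_len, batch), and the assumption that attention DP splits the request batch are illustrative, not taken from MaxText or vLLM internals.

```python
def kv_cache_bytes_per_device(num_kv_heads, head_dim, seq_len, batch,
                              tp, attn_dp=1, bytes_per_elem=2):
  # Each TP shard holds at least one KV head; heads get duplicated once tp > num_kv_heads.
  heads_per_device = max(1, num_kv_heads // tp)
  # Assumption: attention DP splits the request batch across attn_dp groups.
  batch_per_device = batch // attn_dp
  return 2 * heads_per_device * head_dim * seq_len * batch_per_device * bytes_per_elem  # K and V

# 2 KV heads on 8 chips; illustrative head_dim=128, seq_len=8192, batch=8, bf16 cache.
duplicated = kv_cache_bytes_per_device(2, 128, 8192, 8, tp=8)               # heads duplicated
with_attn_dp = kv_cache_bytes_per_device(2, 128, 8192, 8, tp=2, attn_dp=4)  # attn_dp in use
print(duplicated // with_attn_dp)  # 4 -> attention DP cuts per-device KV cache 4x here
```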
Notice 1: Once all tests pass, the "pull ready" label will automatically be assigned.
This label is used for administrative purposes. Please do not add it manually.
Notice 2: For external contributions, our settings currently require an approval from a MaxText maintainer to trigger CI tests.
Tests
Ran Llama3.1-8B with num_kv_heads reduced to 2 and ici_tensor_parallelism=8. This auto-calculates attn_dp=4 and reduces tensor parallelism to 2.
vLLM mesh:
mesh=Mesh('data': 1, 'attn_dp': 4, 'expert': 1, 'model': 2, axis_types=(Auto, Auto, Auto, Auto))
Attention shardings:
Weight shardings:
Ran Qwen3-30B-MoE with tp=8; it has 4 KV heads, so model=4 and attn_dp=2.
Command:
Output: https://paste.googleplex.com/5980793354715136
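For reference, a minimal sketch that reconstructs only the shape of the Qwen3 run's mesh on a CPU host (assumption: eight simulated devices; this is not the vLLM/MaxText initialization path).

```python
import os
# Assumption: fake 8 host devices so this runs without TPUs.
os.environ["XLA_FLAGS"] = "--xla_force_host_platform_device_count=8"

import jax
import numpy as np
from jax.sharding import Mesh

# Same axis order as the vLLM mesh above, with the Qwen3-30B-MoE sizes:
# 4 KV heads -> model=4, attn_dp = 8 // 4 = 2.
devices = np.array(jax.devices()).reshape(1, 2, 1, 4)
mesh = Mesh(devices, axis_names=("data", "attn_dp", "expert", "model"))
print(mesh)  # axis sizes: data=1, attn_dp=2, expert=1, model=4
```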
Checklist
Before submitting this PR, please make sure (put X in square brackets):
Add the gemini-review label to request a Gemini review, if desired.