Fix: Add mask_indptr conversion in BatchPrefillWithPagedKVCacheWrapper.plan() #2201
Conversation
Note: CodeRabbit has detected other AI code review bot(s) in this pull request and will avoid duplicating their findings in the review comments. This may lead to a less comprehensive review.

Walkthrough: Extracts the indptr computation for packed masks into a new internal helper, _get_indptr_for_packed_mask, in flashinfer/quantization.py, and uses it in BatchPrefillWithPagedKVCacheWrapper.plan() to convert mask_indptr when only packed_custom_mask is provided.
Estimated code review effort: 🎯 3 (Moderate) | ⏱️ ~20 minutes
Summary of Changes: Hello @Dutch-voyage, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed! This pull request addresses a potential issue in BatchPrefillWithPagedKVCacheWrapper.plan(): when only packed_custom_mask is provided, mask_indptr was not converted to packed (uint8) offsets. A new helper, _get_indptr_for_packed_mask, is added in flashinfer/quantization.py and used in flashinfer/prefill.py to perform that conversion.
Code Review
The pull request introduces a new helper function, _get_indptr_for_packed_mask, in flashinfer/quantization.py to compute index pointers for packed masks, which refactors existing logic within segment_packbits. This new function is then imported and utilized in flashinfer/prefill.py to correctly handle mask_indptr when a packed_custom_mask is provided without a custom_mask, alongside a TODO comment regarding mask consistency. A review comment requested adding a return type hint to the newly defined _get_indptr_for_packed_mask function for better code clarity and maintainability.
Actionable comments posted: 1
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (1)
flashinfer/prefill.py (1)
1774-1791: Critical: Logic error when both masks are provided.

When a caller provides both custom_mask and packed_custom_mask, the code:

- Computes mask_indptr for unpacked segments (lines 1775-1780)
- Skips the conversion at line 1782 because custom_mask is None evaluates to False
- Skips the packing at line 1785 because packed_custom_mask is None evaluates to False
- Result: mask_indptr points to unpacked segment boundaries, but packed_custom_mask (which is packed) will be used by the attention kernel, causing misaligned memory access and incorrect masking

Apply this diff to fix the logic by either converting mask_indptr whenever packed_custom_mask is present, or validating mutual exclusivity:

Option 1 (preferred): Always convert mask_indptr when packed_custom_mask is provided

```diff
 if custom_mask is not None or packed_custom_mask is not None:
     mask_indptr = _compute_page_mask_indptr(
         qo_indptr,
         paged_kv_indptr,
         paged_kv_last_page_len,
         page_size,
     )
-if packed_custom_mask is not None and custom_mask is None:
+if packed_custom_mask is not None:
+    # packed_custom_mask is provided, convert mask_indptr to packed format
     mask_indptr = _get_indptr_for_packed_mask(mask_indptr)
-
-if packed_custom_mask is None and custom_mask is not None:
+elif custom_mask is not None:
     # create packed custom mask from custom mask
     packed_custom_mask, mask_indptr = segment_packbits(
         custom_mask.contiguous().view(-1),
         mask_indptr,
         bitorder="little",
     )
```

Option 2: Add parameter validation

```diff
+if custom_mask is not None and packed_custom_mask is not None:
+    raise ValueError(
+        "Cannot provide both custom_mask and packed_custom_mask. "
+        "Either provide custom_mask (will be packed internally) or "
+        "packed_custom_mask (pre-packed with packbits/segment_packbits)."
+    )
+
 if custom_mask is not None or packed_custom_mask is not None:
     mask_indptr = _compute_page_mask_indptr(
         qo_indptr,
         paged_kv_indptr,
         paged_kv_last_page_len,
         page_size,
     )
```
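To make the mismatch concrete, here is a small worked example (an illustrative sketch with made-up segment lengths, not code from this PR):

```python
import torch

# Hypothetical mask lengths for two requests: 10 and 20 boolean entries.
mask_indptr = torch.tensor([0, 10, 30])        # offsets into the *unpacked* bool mask

# After bit-packing into uint8 (8 mask bits per byte), each segment shrinks to
# ceil(len / 8) bytes, so the kernel must index the packed buffer with:
seglen = mask_indptr[1:] - mask_indptr[:-1]    # tensor([10, 20])
packed_len = (seglen + 7) // 8                 # tensor([2, 3])
packed_indptr = torch.zeros_like(mask_indptr)
packed_indptr[1:] = torch.cumsum(packed_len, 0)
print(packed_indptr)                           # tensor([0, 2, 5])

# The packed mask buffer for this batch is only 5 bytes long, so an unpacked
# indptr of [0, 10, 30] would point past the end of segment 0 and out of
# bounds for segment 1.
```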
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (2)
- flashinfer/prefill.py (2 hunks)
- flashinfer/quantization.py (2 hunks)
🧰 Additional context used
🧬 Code graph analysis (1)
flashinfer/prefill.py (1)
flashinfer/quantization.py (3)
- packbits (47-78)
- segment_packbits (90-144)
- _get_indptr_for_packed_mask (80-87)
🔇 Additional comments (3)
flashinfer/quantization.py (2)
80-87: LGTM! Clean extraction of indptr conversion logic. The helper correctly computes the packed indptr by calculating segment lengths, determining the packed byte count per segment, and using cumsum to build the new indptr.
136-136: LGTM! Good refactoring to use the extracted helper. This change centralizes the indptr conversion logic and makes it reusable.
flashinfer/prefill.py (1)
39-39: LGTM! Import is correct. The internal helper is properly imported for use in the plan() method.
```python
self._num_qo_heads = num_qo_heads
self._num_kv_heads = num_kv_heads
```

```python
# TODO: should we check consistency between custom_mask and packed_custom_mask?
```
Resolve the TODO by adding validation.
The TODO correctly identifies the need to check consistency between custom_mask and packed_custom_mask. As shown in the previous comment, allowing both to be provided leads to a critical bug. This TODO should be resolved by implementing either mutual exclusivity validation or preferential handling of packed_custom_mask.
🤖 Prompt for AI Agents
In flashinfer/prefill.py around line 1772, implement validation that prevents
inconsistent usage of custom_mask and packed_custom_mask: if both are provided
raise a ValueError stating they are mutually exclusive; if packed_custom_mask is
provided, unpack it and validate its shape/dtype matches expected mask shape and
values; if custom_mask is provided validate its shape/dtype and contents as
well; ensure any downstream code uses the validated mask object (or the unpacked
one) so behavior is deterministic.
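A minimal sketch of what that validation could look like (the helper name _validate_mask_args is hypothetical and not part of this PR):

```python
import torch

def _validate_mask_args(custom_mask, packed_custom_mask, mask_indptr):
    # Reject ambiguous usage: the two mask arguments are mutually exclusive.
    if custom_mask is not None and packed_custom_mask is not None:
        raise ValueError(
            "custom_mask and packed_custom_mask are mutually exclusive; "
            "provide only one of them."
        )
    if packed_custom_mask is not None:
        # A packed mask must already be bit-packed into uint8 bytes.
        if packed_custom_mask.dtype != torch.uint8:
            raise ValueError("packed_custom_mask must be a uint8 tensor.")
    elif custom_mask is not None:
        # An unpacked mask must cover exactly mask_indptr[-1] boolean entries.
        expected = int(mask_indptr[-1].item())
        if custom_mask.numel() != expected:
            raise ValueError(
                f"custom_mask has {custom_mask.numel()} elements, expected {expected}."
            )
```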
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
Actionable comments posted: 1
🧹 Nitpick comments (1)
flashinfer/quantization.py (1)
80-87: Refactoring looks solid. The extracted helper correctly computes the new indptr for packed segments using the ceiling-division formula (seglen + 7) // 8. The logic is sound and appropriately reusable across modules.

Optionally, consider adding a brief docstring to clarify the helper's purpose:

```diff
 def _get_indptr_for_packed_mask(
     indptr: torch.Tensor,
 ) -> torch.Tensor:
+    """Compute indptr for packed (bit-packed into uint8) segments.
+
+    Given an indptr array defining segment boundaries, returns a new indptr
+    where each segment length is the ceiling division by 8 (for bit-packing).
+    """
     seglen = indptr[1:] - indptr[:-1]
```
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (1)
- flashinfer/quantization.py (2 hunks)
🔇 Additional comments (1)
flashinfer/quantization.py (1)
136-136: Clean delegation to the extracted helper. The refactored code correctly uses _get_indptr_for_packed_mask and preserves the existing behavior of segment_packbits.
```python
def _get_indptr_for_packed_mask(
    indptr: torch.Tensor,
) -> torch.Tensor:
    seglen = indptr[1:] - indptr[:-1]
    packed_len = (seglen + 7) // 8
    indptr_new = torch.zeros(len(indptr), dtype=indptr.dtype, device=indptr.device)
    indptr_new[1:] = torch.cumsum(packed_len, 0)
    return indptr_new
```
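For reference, a quick sanity check of what this helper returns (illustrative values only, assuming the function above is in scope):

```python
import torch

# Segment boundaries in the unpacked boolean mask: lengths 3 and 7.
indptr = torch.tensor([0, 3, 10], dtype=torch.int32)

# ceil(3 / 8) = 1 byte and ceil(7 / 8) = 1 byte, so the packed indptr is [0, 1, 2].
packed = _get_indptr_for_packed_mask(indptr)
print(packed)  # tensor([0, 1, 2], dtype=torch.int32)
```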
🧩 Analysis chain
🏁 Scripts executed (repository: flashinfer-ai/flashinfer):

```shell
find . -type f -name '*test*.py' | head -20
rg -l 'packed_custom_mask' --type=py
rg -n 'packed_custom_mask' --type=py -A 3 -B 3
cat -n tests/utils/test_jit_example.py | sed -n '60,85p'
rg -n 'def test.*prefill' tests/attention/ --type=py | head -20
rg -n 'packed_custom_mask' tests/attention/ --type=py -B 5 -A 5
rg -n 'plan\(' tests/ --type=py -A 3 | grep -A 3 'packed_custom_mask'
rg -n 'batch_prefill_with_paged_kv_cache\(' tests/ --type=py -A 5 | head -40
# Check what parameters are actually passed in test files
rg -n 'packed_custom_mask' tests/ --type=py -B 10 -A 2 | head -60
rg -n 'def test_batch_prefill_with_paged_kv_cache_custom_mask' tests/attention/test_batch_prefill_kernels.py -A 30
# Look at how custom_mask is used in the test
rg -n 'custom_mask' tests/attention/test_batch_prefill_kernels.py -B 2 -A 2 | head -80
sed -n '1770,1800p' flashinfer/prefill.py
sed -n '1193,1200p' flashinfer/prefill.py
```
Add test coverage for the packed_custom_mask-only code path.
The new code in BatchPrefillWithPagedKVCacheWrapper.plan() (lines 1782-1783) handles the case where packed_custom_mask is provided without custom_mask, calling _get_indptr_for_packed_mask. Existing tests only pass custom_mask directly. Without test coverage for this new path, especially given that this change addresses a precision issue, the fix cannot be verified.
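A sketch of such a test follows (a hedged example with made-up shapes; the wrapper construction and plan()/run() calls mirror how the existing custom-mask tests use the public API):

```python
import torch
import flashinfer
from flashinfer.quantization import segment_packbits

def test_packed_custom_mask_only_path():
    torch.manual_seed(0)
    num_qo_heads = num_kv_heads = 4
    head_dim, page_size = 128, 1

    # Two requests with qo lengths [7, 11] and kv lengths [13, 17].
    qo_indptr = torch.tensor([0, 7, 18], dtype=torch.int32, device="cuda")
    kv_indptr = torch.tensor([0, 13, 30], dtype=torch.int32, device="cuda")
    kv_indices = torch.arange(30, dtype=torch.int32, device="cuda")
    kv_last_page_len = torch.ones(2, dtype=torch.int32, device="cuda")

    q = torch.randn(18, num_qo_heads, head_dim, dtype=torch.float16, device="cuda")
    kv_cache = torch.randn(
        30, 2, page_size, num_kv_heads, head_dim, dtype=torch.float16, device="cuda"
    )

    # Random boolean mask, flattened per (qo_len_i * kv_len_i) segment.
    mask = torch.rand(7 * 13 + 11 * 17, device="cuda") > 0.5
    mask_indptr = torch.tensor(
        [0, 7 * 13, 7 * 13 + 11 * 17], dtype=torch.int32, device="cuda"
    )
    packed_mask, _ = segment_packbits(mask, mask_indptr, bitorder="little")

    workspace = torch.empty(128 * 1024 * 1024, dtype=torch.uint8, device="cuda")

    def plan_and_run(**mask_kwargs):
        wrapper = flashinfer.BatchPrefillWithPagedKVCacheWrapper(workspace, "NHD")
        wrapper.plan(
            qo_indptr, kv_indptr, kv_indices, kv_last_page_len,
            num_qo_heads, num_kv_heads, head_dim, page_size, **mask_kwargs,
        )
        return wrapper.run(q, kv_cache)

    out_ref = plan_and_run(custom_mask=mask)                   # already-covered path
    out_packed = plan_and_run(packed_custom_mask=packed_mask)  # path fixed by this PR
    torch.testing.assert_close(out_ref, out_packed, rtol=1e-3, atol=1e-3)
```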
Add mask_indptr conversion when passing "packed_custom_mask" without "custom_mask" in BatchPrefillWithPagedKVCacheWrapper.plan()
📌 Description
When custom masking is enabled, BatchPrefillWithPagedKVCacheWrapper.plan() can accept packed_custom_mask instead of custom_mask to specify the mask.

In that case, mask_indptr should be converted to offsets into the bit-packed uint8 buffer (packed_custom_mask), i.e. each segment length rounded up to the next multiple of 8 bits and divided by 8. This conversion is performed inside segment_packbits (when packed_custom_mask is None and custom_mask is not None). However, when only packed_custom_mask is passed, mask_indptr is not converted. I am not sure why this does not cause an illegal memory access, but in my case it can be fixed simply by adding the additional check in this commit, or alternatively by adding a parameter check that rejects the unsupported API usage.
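For concreteness, the call pattern in question looks roughly like this (an illustrative sketch with made-up sizes; only packed_custom_mask is passed to plan()):

```python
import torch
import flashinfer
from flashinfer.quantization import segment_packbits

# Two requests with qo lengths [2, 3] and kv lengths [4, 5] (page_size = 1).
qo_indptr = torch.tensor([0, 2, 5], dtype=torch.int32, device="cuda")
kv_indptr = torch.tensor([0, 4, 9], dtype=torch.int32, device="cuda")
kv_indices = torch.arange(9, dtype=torch.int32, device="cuda")
kv_last_page_len = torch.ones(2, dtype=torch.int32, device="cuda")

# Flattened boolean mask (2*4 + 3*5 = 23 bits), bit-packed into uint8 up front.
mask = torch.rand(23, device="cuda") > 0.5
mask_indptr = torch.tensor([0, 8, 23], dtype=torch.int32, device="cuda")
packed_mask, _ = segment_packbits(mask, mask_indptr, bitorder="little")

workspace = torch.empty(128 * 1024 * 1024, dtype=torch.uint8, device="cuda")
wrapper = flashinfer.BatchPrefillWithPagedKVCacheWrapper(workspace, "NHD")
# Only packed_custom_mask is given; before this fix, plan() left mask_indptr
# in unpacked (bit) units instead of converting it to packed (byte) offsets.
wrapper.plan(
    qo_indptr, kv_indptr, kv_indices, kv_last_page_len,
    4, 4, 128, 1,  # num_qo_heads, num_kv_heads, head_dim, page_size
    packed_custom_mask=packed_mask,
)
```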
🔍 Related Issues
This should cause a precision issue.
🚀 Pull Request Checklist
✅ Pre-commit Checks
- I have installed pre-commit by running pip install pre-commit (or used your preferred method).
- I have installed the hooks with pre-commit install.
- I have run pre-commit run --all-files and fixed any reported issues.

🧪 Tests

- Tests have been added or updated as needed (unittest, etc.).

Reviewer Notes