
Conversation


@Dutch-voyage Dutch-voyage commented Dec 11, 2025

Add mask_indptr conversion when passing "packed_custom_mask" without "custom_mask" in BatchPrefillWithPagedKVCacheWrapper.plan()

📌 Description

When custom masking is enabled, BatchPrefillWithPagedKVCacheWrapper.plan() can accept packed_custom_mask in place of custom_mask to specify the attention mask.


The mask_indptr should be 8-aligned, i.e. converted so that it indexes into the bit-packed uint8 buffer (packed_custom_mask). This conversion is performed correctly inside segment_packbits (the path taken when packed_custom_mask is None and custom_mask is not None). However, when only packed_custom_mask is passed, mask_indptr is never converted.

I am not sure why this does not trigger an illegal memory access. In my case the problem is fixed simply by adding an additional check (the current commit); alternatively, parameter validation could be added to reject the unintended API usage.
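For context, the conversion in question turns a per-segment bit-count indptr into a per-segment byte-count indptr, since each segment of mask bits occupies ceil(seglen / 8) uint8 bytes after packing. A minimal pure-Python sketch of that conversion (hypothetical helper name; the actual implementation lives in flashinfer/quantization.py and operates on torch tensors):

```python
def packed_mask_indptr(indptr):
    # Each segment of seglen boolean mask bits occupies (seglen + 7) // 8
    # uint8 bytes after bit-packing, so the packed indptr is the running
    # sum of those per-segment byte counts.
    out = [0]
    for start, end in zip(indptr[:-1], indptr[1:]):
        out.append(out[-1] + (end - start + 7) // 8)
    return out

# Example: segments of 3, 10, and 8 mask bits pack into 1, 2, and 1 bytes.
print(packed_mask_indptr([0, 3, 13, 21]))  # → [0, 1, 3, 4]
```

Passing packed_custom_mask with an unconverted mask_indptr means the kernel reads byte offsets using bit-count boundaries, which is the mismatch this PR addresses.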

🔍 Related Issues

None linked yet; the bug should cause precision issues (incorrect mask application) rather than a crash.

🚀 Pull Request Checklist

✅ Pre-commit Checks

  • I have installed pre-commit by running pip install pre-commit (or used your preferred method).
  • I have installed the hooks with pre-commit install.
  • I have run the hooks manually with pre-commit run --all-files and fixed any reported issues.

If you are unsure about how to set up pre-commit, see the pre-commit documentation.

🧪 Tests

  • Tests have been added or updated as needed.
  • All tests are passing (unittest, etc.).

Reviewer Notes

Summary by CodeRabbit

  • Refactor
    • Consolidated mask-handling logic to improve consistency and maintainability across mask-related operations.
  • Bug Fix
    • Ensures packed masks are correctly interpreted when supplied on their own, preserving existing behavior when explicit masks are provided.


… BatchPrefillWithPagedKVCacheWrapper.plan()

coderabbitai bot commented Dec 11, 2025

Note

Other AI code review bot(s) detected

CodeRabbit has detected other AI code review bot(s) in this pull request and will avoid duplicating their findings in the review comments. This may lead to a less comprehensive review.

Walkthrough

Extracts indptr computation for packed masks into a new internal helper _get_indptr_for_packed_mask in flashinfer/quantization.py, and reuses it from flashinfer/prefill.py::plan() when packed_custom_mask is provided without custom_mask; also adds a TODO about consistency between custom_mask and packed_custom_mask.

Changes

  • Internal helper extraction (flashinfer/quantization.py): Adds internal helper _get_indptr_for_packed_mask(indptr: torch.Tensor) -> torch.Tensor that computes per-segment packed lengths and the cumulative indptr; segment_packbits now calls this helper instead of computing indptr_new inline.
  • Helper reuse in prefill (flashinfer/prefill.py): Imports _get_indptr_for_packed_mask and, in plan(), computes mask_indptr from packed_custom_mask when custom_mask is absent; retains previous behavior when custom_mask is provided; adds a TODO about consistency between custom_mask and packed_custom_mask.

Estimated code review effort

🎯 3 (Moderate) | ⏱️ ~20 minutes

  • Verify _get_indptr_for_packed_mask reproduces the exact indptr output previously produced inline in segment_packbits.
  • Check segment_packbits behavior (packed output and returned indptr) remains unchanged.
  • Validate plan() in prefill.py correctly handles both custom_mask and packed_custom_mask code paths and that the TODO is addressed by reviewers if necessary.

Poem

🐰 I hopped through bits and sums today,
Pulled indices out to tidy the way.
One helper now guides two clever paws,
Packing masks without the old cause.
A tiny hop, a neater play. 🥕

Pre-merge checks and finishing touches

❌ Failed checks (1 warning)
  • Docstring Coverage: ⚠️ Warning. Docstring coverage is 75.00%, below the required 80.00% threshold. Resolution: run @coderabbitai generate docstrings to improve docstring coverage.
✅ Passed checks (2 passed)
  • Description check: ✅ Passed. The description provides a clear explanation of the problem, the motivation, and the solution; it includes related information, a visual aid, and acknowledges the implementation choice.
  • Title check: ✅ Passed. The title clearly and specifically describes the main change: adding mask_indptr conversion for BatchPrefillWithPagedKVCacheWrapper.plan() when packed_custom_mask is provided.

@gemini-code-assist

Summary of Changes

Hello @Dutch-voyage, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request addresses a potential issue in the BatchPrefillWithPagedKVCacheWrapper.plan() method where the mask_indptr was not being properly converted when only a packed_custom_mask was provided. The change introduces a dedicated conversion step to ensure correct 8-alignment and uint8 pointer conversion for mask_indptr in this specific scenario, preventing potential precision issues and improving robustness. A new helper function was also extracted to streamline this logic and improve code maintainability.

Highlights

  • Fix mask_indptr conversion: Ensures mask_indptr is correctly converted (8-aligned and to uint8 pointer) when packed_custom_mask is used without custom_mask in BatchPrefillWithPagedKVCacheWrapper.plan().
  • Introduce _get_indptr_for_packed_mask: A new utility function was added to flashinfer/quantization.py to handle the calculation of indptr_new for packed masks, improving code reusability and clarity.
  • Refactor segment_packbits: The existing segment_packbits function now utilizes the new _get_indptr_for_packed_mask helper, centralizing the logic for indptr calculation.

@Dutch-voyage Dutch-voyage changed the title Add mask_indptr conversion when passing 'packed_custom_mask' only for… Add mask_indptr conversion in BatchPrefillWithPagedKVCacheWrapper.plan() Dec 11, 2025

@gemini-code-assist gemini-code-assist bot left a comment


Code Review

The pull request introduces a new helper function, _get_indptr_for_packed_mask, in flashinfer/quantization.py to compute index pointers for packed masks, which refactors existing logic within segment_packbits. This new function is then imported and utilized in flashinfer/prefill.py to correctly handle mask_indptr when a packed_custom_mask is provided without a custom_mask, alongside a TODO comment regarding mask consistency. A review comment requested adding a return type hint to the newly defined _get_indptr_for_packed_mask function for better code clarity and maintainability.


@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 1

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (1)
flashinfer/prefill.py (1)

1774-1791: Critical: Logic error when both masks are provided.

When a caller provides both custom_mask and packed_custom_mask, the code:

  1. Computes mask_indptr for unpacked segments (lines 1775-1780)
  2. Skips the conversion at line 1782 because custom_mask is None evaluates to False
  3. Skips the packing at line 1785 because packed_custom_mask is None evaluates to False
  4. Result: mask_indptr points to unpacked segment boundaries, but packed_custom_mask (which is packed) will be used by the attention kernel, causing misaligned memory access and incorrect masking

Apply this diff to fix the logic by either converting mask_indptr whenever packed_custom_mask is present, or validating mutual exclusivity:

Option 1 (preferred): Always convert mask_indptr when packed_custom_mask is provided

 if custom_mask is not None or packed_custom_mask is not None:
     mask_indptr = _compute_page_mask_indptr(
         qo_indptr,
         paged_kv_indptr,
         paged_kv_last_page_len,
         page_size,
     )
 
-if packed_custom_mask is not None and custom_mask is None:
+if packed_custom_mask is not None:
+    # packed_custom_mask is provided, convert mask_indptr to packed format
     mask_indptr = _get_indptr_for_packed_mask(mask_indptr)
-
-if packed_custom_mask is None and custom_mask is not None:
+elif custom_mask is not None:
     # create packed custom mask from custom mask
     packed_custom_mask, mask_indptr = segment_packbits(
         custom_mask.contiguous().view(-1),
         mask_indptr,
         bitorder="little",
     )

Option 2: Add parameter validation

+if custom_mask is not None and packed_custom_mask is not None:
+    raise ValueError(
+        "Cannot provide both custom_mask and packed_custom_mask. "
+        "Either provide custom_mask (will be packed internally) or "
+        "packed_custom_mask (pre-packed with packbits/segment_packbits)."
+    )
+
 if custom_mask is not None or packed_custom_mask is not None:
     mask_indptr = _compute_page_mask_indptr(
         qo_indptr,
         paged_kv_indptr,
         paged_kv_last_page_len,
         page_size,
     )
📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 9f1cb89 and 9a3a75b.

📒 Files selected for processing (2)
  • flashinfer/prefill.py (2 hunks)
  • flashinfer/quantization.py (2 hunks)
🧰 Additional context used
🧬 Code graph analysis (1)
flashinfer/prefill.py (1)
flashinfer/quantization.py (3)
  • packbits (47-78)
  • segment_packbits (90-144)
  • _get_indptr_for_packed_mask (80-87)
🔇 Additional comments (3)
flashinfer/quantization.py (2)

80-87: LGTM! Clean extraction of indptr conversion logic.

The helper correctly computes the packed indptr by calculating segment lengths, determining the packed byte count per segment, and using cumsum to build the new indptr.


136-136: LGTM! Good refactoring to use the extracted helper.

This change centralizes the indptr conversion logic and makes it reusable.

flashinfer/prefill.py (1)

39-39: LGTM! Import is correct.

The internal helper is properly imported for use in the plan() method.

self._num_qo_heads = num_qo_heads
self._num_kv_heads = num_kv_heads

# TODO: should we check consistency between custom_mask and packed_custom_mask?


⚠️ Potential issue | 🟠 Major

Resolve the TODO by adding validation.

The TODO correctly identifies the need to check consistency between custom_mask and packed_custom_mask. As shown in the previous comment, allowing both to be provided leads to a critical bug. This TODO should be resolved by implementing either mutual exclusivity validation or preferential handling of packed_custom_mask.

🤖 Prompt for AI Agents
In flashinfer/prefill.py around line 1772, implement validation that prevents
inconsistent usage of custom_mask and packed_custom_mask: if both are provided
raise a ValueError stating they are mutually exclusive; if packed_custom_mask is
provided, unpack it and validate its shape/dtype matches expected mask shape and
values; if custom_mask is provided validate its shape/dtype and contents as
well; ensure any downstream code uses the validated mask object (or the unpacked
one) so behavior is deterministic.

Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>

@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 1

🧹 Nitpick comments (1)
flashinfer/quantization.py (1)

80-87: Refactoring looks solid.

The extracted helper correctly computes the new indptr for packed segments using the ceiling division formula (seglen + 7) // 8. The logic is sound and appropriately reusable across modules.

Optionally, consider adding a brief docstring to clarify the helper's purpose:

 def _get_indptr_for_packed_mask(
     indptr: torch.Tensor,
 ) -> torch.Tensor:
+    """Compute indptr for packed (bit-packed into uint8) segments.
+    
+    Given an indptr array defining segment boundaries, returns a new indptr
+    where each segment length is the ceiling division by 8 (for bit-packing).
+    """
     seglen = indptr[1:] - indptr[:-1]
📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 9a3a75b and b563db6.

📒 Files selected for processing (1)
  • flashinfer/quantization.py (2 hunks)
🔇 Additional comments (1)
flashinfer/quantization.py (1)

136-136: Clean delegation to the extracted helper.

The refactored code correctly uses _get_indptr_for_packed_mask and preserves the existing behavior of segment_packbits.

Comment on lines +80 to +87
def _get_indptr_for_packed_mask(
    indptr: torch.Tensor,
) -> torch.Tensor:
    seglen = indptr[1:] - indptr[:-1]
    packed_len = (seglen + 7) // 8
    indptr_new = torch.zeros(len(indptr), dtype=indptr.dtype, device=indptr.device)
    indptr_new[1:] = torch.cumsum(packed_len, 0)
    return indptr_new


⚠️ Potential issue | 🟡 Minor

🧩 Analysis chain

🏁 Script executed:

find . -type f -name '*test*.py' | head -20

Repository: flashinfer-ai/flashinfer

Length of output: 925


🏁 Script executed:

rg -l 'packed_custom_mask' --type=py

Repository: flashinfer-ai/flashinfer

Length of output: 160


🏁 Script executed:

rg -n 'packed_custom_mask' --type=py -A 3 -B 3

Repository: flashinfer-ai/flashinfer

Length of output: 20383


🏁 Script executed:

cat -n tests/utils/test_jit_example.py | sed -n '60,85p'

Repository: flashinfer-ai/flashinfer

Length of output: 1287


🏁 Script executed:

rg -n 'def test.*prefill' tests/attention/ --type=py | head -20

Repository: flashinfer-ai/flashinfer

Length of output: 1977


🏁 Script executed:

rg -n 'packed_custom_mask' tests/attention/ --type=py -B 5 -A 5

Repository: flashinfer-ai/flashinfer

Length of output: 50


🏁 Script executed:

rg -n 'plan\(' tests/ --type=py -A 3 | grep -A 3 'packed_custom_mask'

Repository: flashinfer-ai/flashinfer

Length of output: 50


🏁 Script executed:

rg -n 'batch_prefill_with_paged_kv_cache\(' tests/ --type=py -A 5 | head -40

Repository: flashinfer-ai/flashinfer

Length of output: 1959


🏁 Script executed:

# Check what parameters are actually passed in test files
rg -n 'packed_custom_mask' tests/ --type=py -B 10 -A 2 | head -60

Repository: flashinfer-ai/flashinfer

Length of output: 1218


🏁 Script executed:

rg -n 'def test_batch_prefill_with_paged_kv_cache_custom_mask' tests/attention/test_batch_prefill_kernels.py -A 30

Repository: flashinfer-ai/flashinfer

Length of output: 1035


🏁 Script executed:

# Look at how custom_mask is used in the test
rg -n 'custom_mask' tests/attention/test_batch_prefill_kernels.py -B 2 -A 2 | head -80

Repository: flashinfer-ai/flashinfer

Length of output: 2465


🏁 Script executed:

sed -n '1770,1800p' flashinfer/prefill.py

Repository: flashinfer-ai/flashinfer

Length of output: 1355


🏁 Script executed:

sed -n '1193,1200p' flashinfer/prefill.py

Repository: flashinfer-ai/flashinfer

Length of output: 379


Add test coverage for the packed_custom_mask-only code path.

The new code in BatchPrefillWithPagedKVCacheWrapper.plan() (lines 1782-1783) handles the case where packed_custom_mask is provided without custom_mask, calling _get_indptr_for_packed_mask. Existing tests only pass custom_mask directly. Without test coverage for this new path, especially given that this change addresses a precision issue, the fix cannot be verified.
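As a starting point, a test could check the converted indptr against independent per-segment packing with numpy.packbits. This is a pure-NumPy sketch: packed_mask_indptr is a hypothetical pure-Python mirror of the torch helper, inlined here so the example is self-contained rather than exercising the real flashinfer API.

```python
import numpy as np

def packed_mask_indptr(indptr):
    # Mirror of _get_indptr_for_packed_mask: each segment of seglen bits
    # packs into (seglen + 7) // 8 uint8 bytes.
    out = [0]
    for start, end in zip(indptr[:-1], indptr[1:]):
        out.append(out[-1] + (end - start + 7) // 8)
    return out

def test_packed_indptr_matches_packbits():
    rng = np.random.default_rng(0)
    indptr = [0, 5, 5, 19, 27]  # includes an empty segment
    mask = rng.integers(0, 2, indptr[-1]).astype(np.uint8)
    packed_indptr = packed_mask_indptr(indptr)
    # Packing each segment independently must land exactly on the
    # byte boundaries described by the converted indptr.
    for i in range(len(indptr) - 1):
        packed = np.packbits(mask[indptr[i]:indptr[i + 1]], bitorder="little")
        assert len(packed) == packed_indptr[i + 1] - packed_indptr[i]

test_packed_indptr_matches_packbits()
```

An end-to-end test for the precision issue would additionally run plan()/run() once with custom_mask and once with the equivalent packed_custom_mask and compare outputs, but that requires a GPU environment and is left as a follow-up.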

@Dutch-voyage Dutch-voyage changed the title Add mask_indptr conversion in BatchPrefillWithPagedKVCacheWrapper.plan() Fix: Add mask_indptr conversion in BatchPrefillWithPagedKVCacheWrapper.plan() Dec 11, 2025