
Conversation

@nvpohanh
Contributor

@nvpohanh nvpohanh commented Nov 28, 2025

📌 Description

#2111 already enabled Hopper FA3 FP8 attention in prefill.py. This is just a follow-up PR to make the same change in decode.py because decode.py actually uses prefill kernels.
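
For context, a minimal usage sketch of the decode-side FA3 FP8 path this enables (tensor shapes, scale shapes, and the positional scale-passing convention below are illustrative assumptions based on the wrapper API and the review discussion, not code from this PR):

import torch
import flashinfer

batch_size, page_size, head_dim = 2, 16, 128
num_qo_heads, num_kv_heads = 8, 2
pages_per_seq = 4
total_pages = batch_size * pages_per_seq

workspace = torch.empty(128 * 1024 * 1024, dtype=torch.uint8, device="cuda")
wrapper = flashinfer.BatchDecodeWithPagedKVCacheWrapper(
    workspace, "NHD", use_tensor_cores=True, backend="fa3"
)

kv_indptr = torch.arange(batch_size + 1, dtype=torch.int32, device="cuda") * pages_per_seq
kv_indices = torch.arange(total_pages, dtype=torch.int32, device="cuda")
kv_last_page_len = torch.full((batch_size,), page_size, dtype=torch.int32, device="cuda")

wrapper.plan(
    kv_indptr, kv_indices, kv_last_page_len,
    num_qo_heads, num_kv_heads, head_dim, page_size,
    q_data_type=torch.float8_e4m3fn,
    kv_data_type=torch.float8_e4m3fn,
    o_data_type=torch.float16,  # FP8 inputs, FP16 output (new in this PR)
)

q = torch.randn(batch_size, num_qo_heads, head_dim, device="cuda").to(torch.float8_e4m3fn)
kv_cache = torch.randn(
    total_pages, 2, page_size, num_kv_heads, head_dim, device="cuda"
).to(torch.float8_e4m3fn)
scale_q = torch.ones(num_qo_heads, dtype=torch.float32, device="cuda")
scale_k = torch.ones(num_kv_heads, dtype=torch.float32, device="cuda")
scale_v = torch.ones(num_kv_heads, dtype=torch.float32, device="cuda")

# Per the review walkthrough below, FP8 q/k/v scale tensors ride along as extra
# positional run() arguments when q is FP8; exact scale shapes are assumptions.
out = wrapper.run(q, kv_cache, scale_q, scale_k, scale_v)  # float16, [batch, heads, head_dim]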

🔍 Related Issues

🚀 Pull Request Checklist

Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete.

✅ Pre-commit Checks

  • I have installed pre-commit by running pip install pre-commit (or used your preferred method).
  • I have installed the hooks with pre-commit install.
  • I have run the hooks manually with pre-commit run --all-files and fixed any reported issues.

If you are unsure about how to set up pre-commit, see the pre-commit documentation.

🧪 Tests

  • Tests have been added or updated as needed.
  • All tests are passing (unittest, etc.).

Reviewer Notes

Summary by CodeRabbit

  • New Features

    • Added selectable backend support (including a new backend option) and explicit output-dtype control for decode/prefill workflows.
  • Improvements

    • Improved FP8 handling and propagation of scales; runtime checks enforce output-dtype consistency and avoid unnecessary scaling when scale == 1.0.
    • Backend auto-selection logic enhanced to consider output dtype.
  • Documentation

    • FP8 guidance updated to allow float16 and bfloat16 outputs.
  • Tests

    • Added tests validating FP8 paged decoding with the new backend.


@coderabbitai
Contributor

coderabbitai bot commented Nov 28, 2025

📝 Walkthrough

Walkthrough

Adds an explicit output data-type parameter (o_data_type) and threads FP8 q/k/v scales through plan/run/prefill/decode paths, introduces backend-aware branching for fa2 vs fa3, validates cached output dtype at runtime, and adds FA3 FP8 paged-decode tests.
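
A condensed, paraphrased sketch of the o_data_type contract described above (illustrative only; the real wrapper threads many more parameters):

import torch

class _OutputDtypeSketch:
    """Illustrative only: mirrors the plan/run dtype contract described in the walkthrough."""

    def plan(self, q_data_type=torch.float16, o_data_type=None):
        # o_data_type defaults to the query dtype and is cached for run()
        self._cached_q_data_type = q_data_type
        self._cached_o_data_type = o_data_type if o_data_type is not None else q_data_type

    def run(self, q, out=None):
        out_dtype = getattr(self, "_cached_o_data_type", None) or q.dtype
        if out is None:
            out = torch.empty_like(q, dtype=out_dtype)
        elif out.dtype != out_dtype:
            # run() rejects user-provided buffers whose dtype disagrees with plan()
            raise ValueError(
                f"out dtype {out.dtype} does not match planned o_data_type {out_dtype}"
            )
        return out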

Changes

  • Public API & decode logic — flashinfer/decode.py
    Add o_data_type to BatchDecodeWithPagedKVCacheWrapper.plan(), cache _cached_o_data_type, use determine_attention_backend() for backend selection, and branch plan/run calls and argument lists for fa2 vs non-fa2; allocate/validate outputs using the resolved o_data_type.
  • FP8 scale propagation & run paths — flashinfer/decode.py, flashinfer/prefill.py
    Extract FP8 scale tensors (fp8_scale_q/k/v) from *args and forward them to the kernels; ensure output dtype follows o_data_type (defaulting to q dtype when None); apply v_scale only when it's a float != 1.0; runtime check in run() raises if a provided out dtype mismatches the cached o_data_type.
  • JIT attention module validation & tensors — flashinfer/jit/attention/modules.py
    Add runtime validation in gen_batch_prefill_module (backend must be "fa2" or "fa3"; FP8 outputs disallowed); expand fa2 public tensor/scalar lists with additional maybe_* tensors and token_pos_in_items_len, updating dtypes.
  • Docs & docstrings — flashinfer/prefill.py
    Update FP8 guidance to allow o_data_type as torch.float16 or torch.bfloat16 and reflect dtype-consistency checks in docstrings.
  • Tests & test utilities — tests/attention/test_hopper_fp8_attention.py, tests/utils/test_jit_example.py
    Add test_batch_decode_paged() for FA3 FP8 paged KV decode (compared to an FP16 reference), remove a debug print, and add backend="fa2" usage in a test helper to cover backend selection.

Sequence Diagram(s)

(omitted)

Estimated code review effort

🎯 3 (Moderate) | ⏱️ ~20 minutes

Possibly related PRs

Suggested reviewers

  • yzh119
  • cyx-6
  • nvmbreughe
  • aleozlx
  • bkryu
  • wenscarl
  • Anerudhan

"I hopped through bytes and dtype trees,
threading scales on tiny feets,
fa2 or fa3, I sniffed the breeze,
cached the output, passed the beats.
A carrot-coded test completes!" 🐇

Pre-merge checks and finishing touches

❌ Failed checks (1 warning)
  • Docstring Coverage — ⚠️ Warning: Docstring coverage is 75.00%, which is below the required threshold of 80.00%. You can run @coderabbitai generate docstrings to improve docstring coverage.
✅ Passed checks (2 passed)
  • Title check — ✅ Passed: The title accurately and specifically describes the main objective of the PR: enabling Hopper FA3 FP8 attention in decode.py.
  • Description check — ✅ Passed: The description provides context linking to PR #2111 and explains the purpose, with checklist items marked complete. However, the Related Issues section is left empty despite the PR referencing PR #2111.


@yzh119
Collaborator

yzh119 commented Nov 28, 2025

please refer to #2111 where we refactored fa3 and exposed the fp8 interface to python

@nvpohanh nvpohanh force-pushed the dev-nvpohanh-hopper-fp8-attention branch 4 times, most recently from ede67a3 to a8d9e6a on December 2, 2025 03:55
@nvpohanh nvpohanh force-pushed the dev-nvpohanh-hopper-fp8-attention branch 2 times, most recently from ee77217 to 83cfce9 on December 8, 2025 08:44
@nvpohanh nvpohanh marked this pull request as ready for review December 11, 2025 06:17
@nvpohanh nvpohanh changed the title from "Enable Hopper FA3 FP8 attention" to "Enable Hopper FA3 FP8 attention in decode.py" on Dec 11, 2025
@nvpohanh nvpohanh force-pushed the dev-nvpohanh-hopper-fp8-attention branch from 83cfce9 to 09a1ece on December 11, 2025 06:34
Contributor

@coderabbitai coderabbitai bot left a comment

Actionable comments posted: 0

🧹 Nitpick comments (2)
flashinfer/prefill.py (1)

2110-2114: Stronger out-dtype validation; consider tightening the error message

The explicit check that out.dtype matches the planned o_data_type enforces the plan/run contract and prevents silent dtype mismatches when callers reuse an out buffer. This is a solid safety improvement.

Ruff’s TRY003 warning about long exception messages could be addressed by shortening the message slightly or moving it into a shared constant/helper, but that’s stylistic and not functionally required.

flashinfer/decode.py (1)

1306-1313: Minor suggestion: align out shape check with allocation expression

Right now out is allocated with q.shape[:-1] + v_cache.shape[-1:] but validated against q.shape. Those are equal for today’s kernels (q and v share head_dim), but if q/v head dims ever diverge, the validation would become inconsistent. Consider using the same expression in both places for future-proofing.
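
A sketch of the suggested alignment (variable names follow the surrounding run() discussion; this is illustrative, not the repository code):

# Illustrative: derive the expected output shape once and reuse it for both
# allocation and validation, so q/v head-dim divergence cannot slip through.
out_shape = q.shape[:-1] + v_cache.shape[-1:]
out_dtype = getattr(self, "_cached_o_data_type", None) or q.dtype
if out is None:
    out = torch.empty(out_shape, dtype=out_dtype, device=q.device)
else:
    check_shape_dtype_device(out, out_shape, out_dtype, q.device, "out")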

📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between dc0ade7 and 09a1ece.

📒 Files selected for processing (3)
  • flashinfer/decode.py (17 hunks)
  • flashinfer/jit/attention/modules.py (1 hunks)
  • flashinfer/prefill.py (5 hunks)
🧰 Additional context used
🧬 Code graph analysis (1)
flashinfer/decode.py (5)
flashinfer/utils.py (3)
  • determine_attention_backend (450-492)
  • canonicalize_torch_dtype (241-249)
  • is_float8 (158-159)
flashinfer/logits_processor/types.py (2)
  • dtype (126-130)
  • device (119-123)
include/flashinfer/trtllm/common.h (1)
  • device (83-90)
flashinfer/attention.py (1)
  • plan (71-136)
flashinfer/pod.py (2)
  • plan (265-434)
  • plan (800-1014)
🪛 Ruff (0.14.8)
flashinfer/prefill.py

2112-2114: Avoid specifying long messages outside the exception class

(TRY003)

⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
  • GitHub Check: Deploy Docs
🔇 Additional comments (10)
flashinfer/jit/attention/modules.py (1)

987-992: Good preflight validation for backend and FP8 output dtype

The explicit backend whitelist and dtype_o FP8 guard make the FA2/FA3 path fail fast on unsupported configurations and align with the FP8 constraints used elsewhere in prefill/decode. Looks correct and consistent with the intended Hopper FA3 FP8 support.

flashinfer/prefill.py (3)

1704-1705: FP8 output dtype guidance in docstrings is clear and aligned

The added note that o_data_type for FP8 inputs should typically be torch.float16 or torch.bfloat16 matches the actual capabilities and helps users avoid unsupported FP8 outputs. No further changes needed here.

Also applies to: 2656-2657


2080-2082: q_scale documentation matches runtime usage

The new q_scale description correctly reflects how it’s applied (folded into sm_scale for FP8 BMM1). This keeps the public API understandable for FP8 users.


2268-2273: Conditional v_scale application is correct and avoids unnecessary work

Applying v_scale only when it’s not None and not 1.0 preserves previous behavior while saving a redundant multiply on the common default. The FP8 branch (cast to fp32, scale, cast back) is also consistent with typical scaling practice.

flashinfer/decode.py (6)

51-76: Backend auto-selection and FA2/FA3 plan args look consistent and safe

Importing and using determine_attention_backend to specialize self._backend when it is "auto" in the tensor-core path, and then branching FA2-only plan arguments (fixed_split_size, disable_split_kv, num_colocated_ctas) while keeping FA3 with the shorter signature, matches the expected separation of FA2/FA3 interfaces and mirrors the logic in fast_decode_plan. This keeps decode aligned with prefill and should enable FA3 on Hopper cleanly.

Also applies to: 1042-1050, 2635-2660
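
As a rough illustration of the selection rule being described (the real determine_attention_backend() in flashinfer/utils.py considers more inputs; this stand-in only shows the shape of the logic):

def _select_backend_sketch(requested: str, device_is_sm90a: bool) -> str:
    """Illustrative stand-in for the auto-selection rule, not the real flashinfer logic."""
    if requested != "auto":
        return requested
    # FA3 kernels target Hopper (SM90a); other architectures fall back to FA2
    return "fa3" if device_is_sm90a else "fa2"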


720-733: Passing backend into gen_customize_batch_prefill_module is the right direction

Wiring backend explicitly into gen_customize_batch_prefill_module for the tensor-core JIT path aligns decode with the prefill side and makes backend selection explicit at module generation time. Assuming the JIT generator’s signature matches this order, this is a straightforward and correct extension of the existing JIT flow.


838-840: o_data_type threading and validation are coherent

The new o_data_type parameter is:

  • Defaulted to q_data_type when not provided, then canonicalized.
  • Cached as _cached_o_data_type and threaded into get_trtllm_gen_decode_module, get_batch_prefill_module, and get_batch_decode_module.
  • Used at run time for both allocation and check_shape_dtype_device validation of out.

This is consistent with the rest of the dtype handling and gives callers explicit control over output dtype (including FP8-input → FP16/BF16-output scenarios) without breaking older call sites that omit o_data_type.

Also applies to: 886-889, 964-977, 985-988, 1025-1035, 1095-1104, 1306-1313


1341-1361: FP8 scale wiring and v_scale optimization look correct

Extracting fp8_scale_q/k/v from *args only when q is float8 and only for the non-JIT tensor-core path keeps the existing API surface intact while enabling FA3/FA2 FP8 usage. Passing these explicitly into paged_run matches the extended kernel signature, and the updated v_scale guard (v_scale is not None and v_scale != 1.0) plus is_float8(out)-based cast behavior are sensible micro-optimizations that preserve numerical behavior.

Also applies to: 1413-1418


998-1005: trtllm-gen decode integration with o_data_type remains consistent

Adding o_data_type to the get_trtllm_gen_decode_module cache key while still using the same paged_run argument layout (including workspace_size, block_tables, kv_lens_buffer, etc.) keeps the trtllm-gen path coherent with FA2/FA3 and ensures different output dtypes don’t collide in the module cache. The subsequent _paged_run call continues to receive the expected scales and workspace sizing.

Also applies to: 1364-1373


2523-2590: fast_decode_plan argument construction now mirrors main plan

The tensor-core branch in fast_decode_plan now builds the same base args list as BatchDecodeWithPagedKVCacheWrapper.plan and appends FA2-only arguments under self._backend == "fa2". This keeps the “fast” path in sync with the standard planner and reduces the risk of FA3/FA2 divergence for multistep decode.

Also applies to: 2635-2660

@nvpohanh
Contributor Author

@yzh119 Could you review this PR? Thanks!

Let me know if you think I should add some tests for this. If so, please point me to the test file where I should add/extend the tests. Thanks!

Collaborator

@bkryu bkryu left a comment

Hi @nvpohanh, I'd say the appropriate place to put unit tests should be test_hopper.py or test_hopper_fp8_attention.py.

Since we are adding fa3 as a backend for the first time in BatchDecodeWithPagedKVCacheWrapper, we may need to write new tests analogous to the prefill ones in there.

@nvpohanh nvpohanh force-pushed the dev-nvpohanh-hopper-fp8-attention branch from 09a1ece to 1bb22d5 on December 17, 2025 13:29
@nvpohanh nvpohanh requested a review from jimmyzho as a code owner December 17, 2025 13:29
Contributor

@coderabbitai coderabbitai bot left a comment

Actionable comments posted: 0

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (1)
flashinfer/prefill.py (1)

2279-2284: Fix tensor comparison in v_scale condition.

The condition v_scale != 1.0 will fail when v_scale is a torch.Tensor with multiple elements (e.g., per-head scaling). Tensor comparisons return boolean tensors, which cannot be used directly in if statements and will raise: RuntimeError: Boolean value of Tensor with more than one element is ambiguous.

The v_scale parameter accepts Optional[Union[float, torch.Tensor]] per the type annotation, and the codebase documentation shows per-head tensor scales with shape [num_kv_heads] are supported in other functions. Additionally, q_scale and k_scale (same type) are handled differently at lines 2140-2143 without this comparison.

Apply this fix to handle both scalar and tensor v_scale correctly:

-            if v_scale is not None and v_scale != 1.0:
+            # Skip scaling only for the scalar 1.0; tensor scales are always applied
+            if v_scale is not None and not (
+                isinstance(v_scale, (int, float)) and v_scale == 1.0
+            ):
                 # TODO(Zihao): fused into kernel
                 if is_float8(out):
                     out = (out.to(torch.float32) * v_scale).to(out.dtype)
                 else:
                     out *= v_scale

Alternatively, add explicit validation in the function that documents whether a tensor-valued v_scale is supported or whether it must be a scalar.
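
A quick illustration of the failure mode being described (standard PyTorch behavior, independent of this repository):

import torch

v_scale = torch.tensor([0.5, 2.0])   # e.g. per-head scales
mask = v_scale != 1.0                # elementwise comparison -> tensor([True, True])
try:
    if mask:                          # truthiness of a multi-element tensor
        pass
except RuntimeError as err:
    print(err)  # "Boolean value of Tensor with more than one element is ambiguous..."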

🧹 Nitpick comments (2)
flashinfer/jit/attention/modules.py (1)

987-993: Consider using explicit ValueError instead of assert for input validation.

While assertions work for catching invalid inputs during development, they can be disabled with python -O in production environments. For input validation that should always run, explicit if statements with ValueError are more reliable:

-    assert backend in ["fa2", "fa3"], (
-        f"backend must be fa2 or fa3 in gen_batch_prefill_module(), got: {backend}"
-    )
-    assert dtype_o not in [torch.float8_e4m3fn, torch.float8_e5m2], (
-        "FP8 output is not supported in fa2/fa3 backends yet"
-    )
+    if backend not in ["fa2", "fa3"]:
+        raise ValueError(
+            f"backend must be fa2 or fa3 in gen_batch_prefill_module(), got: {backend}"
+        )
+    if dtype_o in [torch.float8_e4m3fn, torch.float8_e5m2]:
+        raise ValueError(
+            "FP8 output is not supported in fa2/fa3 backends yet"
+        )
flashinfer/decode.py (1)

709-712: Clarify backend parameter documentation to mention FA3.

The past review comment by bkryu is still relevant. While the backend parameter accepts "auto", "fa2", or "trtllm-gen", the docstring should clarify that "auto" may internally select FA3 on supported hardware (as implemented in lines 1042-1050). This would help users understand the full range of backends that may be used.

Apply this diff to improve the documentation:

         backend : str
-            The implementation backend, could be ``auto``/``fa2`` or ``trtllm-gen``. Defaults to ``auto``.
-            If set to ``auto``, the wrapper will automatically choose the backend based on the
-            device architecture and kernel availability.
+            The implementation backend, could be ``auto``/``fa2`` or ``trtllm-gen``. Defaults to ``auto``.
+            If set to ``auto``, the wrapper will automatically choose the backend based on the
+            device architecture and kernel availability (may select FA2 or FA3 for tensor cores).
📜 Review details

Configuration used: defaults

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 09a1ece and 1bb22d5.

📒 Files selected for processing (3)
  • flashinfer/decode.py (17 hunks)
  • flashinfer/jit/attention/modules.py (1 hunks)
  • flashinfer/prefill.py (4 hunks)
🧰 Additional context used
🧬 Code graph analysis (1)
flashinfer/decode.py (1)
flashinfer/utils.py (4)
  • determine_attention_backend (455-497)
  • canonicalize_torch_dtype (246-254)
  • PosEncodingMode (32-35)
  • is_float8 (158-159)
🪛 Ruff (0.14.8)
flashinfer/prefill.py

2118-2120: Avoid specifying long messages outside the exception class

(TRY003)

⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
  • GitHub Check: Deploy Docs
🔇 Additional comments (16)
flashinfer/jit/attention/modules.py (1)

994-1011: LGTM! Expanded FA2 backend tensor support.

The addition of new tensor names (mask_indptr, prefix_len_ptr, token_pos_in_items_ptr, max_item_len_ptr) with their corresponding dtypes appropriately extends the FA2 backend capabilities to support the FP8 attention features.

flashinfer/prefill.py (3)

1705-1705: LGTM! Helpful FP8 guidance added to docstrings.

The added guidance for o_data_type with FP8 inputs is clear and informative, helping users understand that they should typically use torch.float16 or torch.bfloat16 as the output dtype when working with FP8 attention.


2116-2121: LGTM! Important runtime dtype validation.

The validation ensures consistency between the output tensor's dtype and the o_data_type specified during planning, preventing potential dtype mismatches that could cause runtime errors or incorrect results.

Note: The static analysis hint (TRY003) about message length is a minor style preference. The current message provides helpful diagnostic information and is acceptable.


2667-2667: LGTM! Consistent FP8 guidance.

Same helpful FP8 guidance as added to the paged variant's docstring (line 1705), ensuring consistent documentation across both BatchPrefillWithPagedKVCacheWrapper and BatchPrefillWithRaggedKVCacheWrapper.

flashinfer/decode.py (12)

66-66: LGTM: Import necessary for auto backend selection.

The determine_attention_backend import is required for the auto backend selection logic introduced later in the plan method (lines 1042-1050).


838-838: LGTM: o_data_type parameter addition aligns with FP8 support.

The addition of the o_data_type parameter and its documentation clearly supports FP8 attention workflows, where output dtype often differs from input dtype. The default behavior (falling back to q_data_type) maintains backward compatibility.

Also applies to: 886-888


974-977: LGTM: o_data_type handling is consistent with existing patterns.

The canonicalization and caching of o_data_type follows the same pattern as q_data_type and kv_data_type, ensuring consistency across the codebase.

Also applies to: 987-987


1024-1036: LGTM: o_data_type properly threaded through all backend paths.

The o_data_type parameter is correctly passed to all three module creation functions (get_trtllm_gen_decode_module, get_batch_prefill_module, and get_batch_decode_module), ensuring consistent output dtype handling across FA2, FA3, and TRTLLM-gen backends.

Also applies to: 1051-1063, 1094-1104


1042-1050: LGTM: Auto backend selection correctly implemented.

The auto backend selection logic properly delegates to determine_attention_backend with appropriate parameters. The backend is selected based on device capabilities and input dtypes, which is the correct approach since backend capability is primarily determined by what inputs it can process.


1065-1089: LGTM: Conditional plan arguments correctly handle FA2/FA3 differences.

The conditional logic properly handles the different plan signatures for FA2 (19 arguments) and FA3 (16 arguments) backends. FA2-specific parameters (fixed_split_size, disable_split_kv, num_colocated_ctas) are only appended when using the FA2 backend.

Note: When jit_module is provided (line 1039), the auto backend selection at lines 1042-1050 is skipped. Ensure that when using custom JIT modules with backend="auto", users are aware that they need to specify the correct backend explicitly.


1306-1313: LGTM: Output tensor handling correctly uses o_data_type.

The output allocation and validation logic properly uses _cached_o_data_type with a sensible fallback to q.dtype. This ensures that FP8 workflows can produce outputs in the desired dtype (typically FP16/BF16) while maintaining backward compatibility for existing code.


1341-1363: LGTM: FP8 scale tensor extraction enables FP8 attention support.

The extraction and propagation of FP8 scale tensors (fp8_scale_q, fp8_scale_k, fp8_scale_v) from *args correctly enables FP8 attention workflows. The conditional extraction (only when is_float8(q) and sufficient args are provided) is appropriate and aligns with the PR objective.


1413-1418: LGTM: v_scale optimization avoids unnecessary computation.

The refined condition if v_scale is not None and v_scale != 1.0 appropriately skips the no-op scaling when v_scale is 1.0, improving performance while maintaining correctness.


2635-2660: LGTM: fast_decode_plan correctly mirrors regular plan logic.

The conditional FA2/FA3 argument handling in fast_decode_plan correctly mirrors the logic in the main plan method (lines 1065-1089), ensuring consistency between the fast and standard plan implementations. The comment at line 2635 clearly documents the argument count difference.


1341-1363: Verify FP8 scale tensor ordering consistency with prefill.py implementation.

Confirm that the FP8 scale tensor extraction and ordering (lines 1345-1347: fp8_scale_q = args[0], fp8_scale_k = args[1], fp8_scale_v = args[2]) and their positions in run_args (lines 1358-1360) match the prefill.py implementation. This ensures consistency between decode and prefill operations for maintainability and correctness.


720-732: No backend parameter needed for decode module generation.

After examining the function implementations, gen_customize_batch_decode_module intentionally does not accept a backend parameter because it has only a single implementation path without backend-specific variations. In contrast, gen_customize_batch_prefill_module accepts backend because it includes conditional logic for multiple backend implementations (fa2, fa3, etc.). The code correctly passes backend only to the prefill module. This is by design, not an inconsistency.

Contributor

@coderabbitai coderabbitai bot left a comment

Actionable comments posted: 0

🧹 Nitpick comments (1)
tests/utils/test_jit_example.py (1)

254-260: Explicit backend specification improves test clarity.

The addition of backend="fa2" makes the backend selection explicit, which aligns well with the wrapper configuration and the SM80 variant declaration used in this test.

However, consider adding a corresponding SM90/FA3 test for batch decode (similar to test_batch_prefill_sm90_flash_sigmoid at line 459) to validate the PR's main objective of enabling Hopper FA3 FP8 attention in decode.py. Currently, decode operations only have FA2 test coverage, while prefill has both FA2 and FA3 tests.

Would you like me to help generate a test_batch_decode_sm90_flash_sigmoid test that uses the FA3 backend with the SM90 variant declaration to fully validate the FA3 FP8 functionality in decode.py?

📜 Review details

Configuration used: defaults

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 1bb22d5 and 4192c18.

📒 Files selected for processing (1)
  • tests/utils/test_jit_example.py (1 hunks)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
  • GitHub Check: Deploy Docs

@nvpohanh nvpohanh force-pushed the dev-nvpohanh-hopper-fp8-attention branch 2 times, most recently from 7c59a4c to f841c9c on December 18, 2025 05:24
@nvpohanh
Contributor Author

added a test to test_hopper_fp8_attention.py

@nvpohanh nvpohanh requested a review from bkryu December 18, 2025 05:25
Contributor

@coderabbitai coderabbitai bot left a comment

Actionable comments posted: 1

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (1)
tests/attention/test_hopper_fp8_attention.py (1)

1-1: Critical: Fix ruff-format errors before merge

The pre-commit ruff-format hook failed, indicating that the file needs formatting. Please run the following command locally to fix formatting issues:

pre-commit run ruff-format --files tests/attention/test_hopper_fp8_attention.py

Or format all files:

pre-commit run --all-files
📜 Review details

Configuration used: defaults

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 4192c18 and 7c59a4c.

📒 Files selected for processing (4)
  • flashinfer/decode.py (18 hunks)
  • flashinfer/jit/attention/modules.py (1 hunks)
  • flashinfer/prefill.py (4 hunks)
  • tests/attention/test_hopper_fp8_attention.py (2 hunks)
🧰 Additional context used
🧬 Code graph analysis (3)
tests/attention/test_hopper_fp8_attention.py (4)
flashinfer/utils.py (1)
  • is_sm90a_supported (531-533)
flashinfer/decode.py (9)
  • BatchDecodeWithPagedKVCacheWrapper (593-1456)
  • use_tensor_cores (792-793)
  • use_tensor_cores (1623-1624)
  • plan (824-1128)
  • plan (1651-1764)
  • run (1158-1171)
  • run (1174-1187)
  • run (1190-1420)
  • run (1767-1891)
flashinfer/prefill.py (8)
  • plan (1595-1988)
  • plan (2570-2867)
  • run (2019-2031)
  • run (2034-2046)
  • run (2049-2285)
  • run (2897-2907)
  • run (2910-2920)
  • run (2923-3098)
benchmarks/bench_hopper_fp8_attention.py (1)
  • per_head_symmetric_quant (27-38)
flashinfer/decode.py (1)
flashinfer/utils.py (5)
  • determine_attention_backend (455-497)
  • canonicalize_torch_dtype (246-254)
  • PosEncodingMode (32-35)
  • check_shape_dtype_device (565-583)
  • is_float8 (158-159)
flashinfer/prefill.py (1)
flashinfer/logits_processor/types.py (1)
  • dtype (126-130)
🪛 GitHub Actions: pre-commit
tests/attention/test_hopper_fp8_attention.py

[error] 1-1: pre-commit ruff-format hook failed: 1 file reformatted. Run 'pre-commit run --all-files' locally to apply formatting changes.

🪛 Ruff (0.14.8)
flashinfer/prefill.py

2118-2120: Avoid specifying long messages outside the exception class

(TRY003)

⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
  • GitHub Check: Deploy Docs
🔇 Additional comments (20)
flashinfer/jit/attention/modules.py (1)

987-992: Validation checks look good.

The backend and output dtype validations are well-placed at the function entry point, ensuring invalid configurations are caught early. The error messages are clear and informative.

flashinfer/prefill.py (4)

1705-1705: Good documentation improvement for FP8 attention.

The guidance clarifies that FP8 inputs should typically use float16 or bfloat16 output, which aligns with the PR's FP8 attention support objectives.


2116-2120: Good runtime validation for output dtype consistency.

This check ensures that when a user provides an output tensor, its dtype matches what was specified in the plan() call, preventing subtle bugs. The error message is clear and includes the actual dtype values for easy debugging.

Note: The static analysis tool flags a style preference (TRY003) about the exception message length, but this is a minor nitpick and the current implementation prioritizes clarity.


2667-2667: Consistent documentation improvement.

Same helpful guidance as line 1705, applied consistently to the ragged KV cache wrapper's plan method.


2279-2284: The code is correct. The actual implementation uses isinstance(v_scale, float) to guard the comparison, properly handling the case where v_scale can be either a float or a Tensor. The review comment's code snippet does not match the repository.

Likely an incorrect or invalid review comment.

flashinfer/decode.py (10)

66-66: LGTM: FA3 backend support additions

The import of determine_attention_backend, docstring update mentioning FA3, and passing the backend parameter to the customized module are all correct and necessary for FA3 support.

Also applies to: 710-710, 725-725


838-838: LGTM: Output data type parameter added to public API

The addition of o_data_type parameter to the plan() signature is well-documented and follows the same pattern as kv_data_type. The documentation correctly indicates that for FP8 inputs, the output should typically be torch.float16 or torch.bfloat16.

Also applies to: 886-889


974-977: LGTM: Output data type canonicalization and caching

The canonicalization and defaulting logic for o_data_type follows the same pattern as kv_data_type, and caching it in _cached_o_data_type is consistent with the existing caching pattern.

Also applies to: 987-987


1042-1050: LGTM: Backend auto-selection logic

The automatic backend determination for use_tensor_cores=True paths correctly uses determine_attention_backend to choose between FA2 and FA3 based on device capabilities and data types. The non-tensor-cores path doesn't require backend selection, so this implementation is correct.


1027-1027: LGTM: Output data type propagation to modules

The o_data_type parameter is correctly threaded through to all three module types (trtllm-gen, batch prefill, and batch decode), ensuring that the output data type is properly configured across different backends.

Also applies to: 1052-1052, 1055-1055, 1097-1097


1065-1089: LGTM: Backend-specific plan argument handling

The conditional argument handling correctly differentiates between FA2 (which requires fixed_split_size, disable_split_kv, and num_colocated_ctas) and FA3 (which doesn't support these parameters). The base argument list has 16 parameters, with FA2 adding 3 more for a total of 19.


1307-1313: LGTM: Output tensor allocation with cached data type

The output tensor allocation and validation correctly use _cached_o_data_type (with safe fallback to q.dtype) to ensure the output tensor has the planned data type. This is consistent with the pattern used in prefill.py.


1341-1360: LGTM: FP8 scale tensor extraction for FA3 path

The extraction of FP8 scale tensors from *args when q is FP8 is correctly guarded by is_float8(q) and len(args) >= 3 to avoid indexing errors. The scales default to None if not provided, which the kernel should handle appropriately.


1413-1413: LGTM: Optimize v_scale application

The additional check v_scale != 1.0 correctly skips the scaling operation when no scaling is needed, avoiding unnecessary work. This optimization aligns with the similar change in prefill.py.


2635-2660: LGTM: Consistent FA2/FA3 argument handling in fast_decode_plan

The fast_decode_plan function correctly mirrors the backend-specific argument handling from the main plan method, ensuring FA2 receives its three additional parameters while FA3 uses only the base 16 arguments.

tests/attention/test_hopper_fp8_attention.py (5)

667-691: LGTM: Test function setup

The test parameters and setup are appropriate for validating FP8 batch decode with paged KV cache. The GQA scenarios (different num_qo_heads and num_kv_heads combinations) are particularly valuable for testing head mapping logic.


693-736: LGTM: Paged KV cache construction

The test correctly constructs paged KV cache with per-head varying scales (using create_per_head_varying_kv), which helps reveal head offset bugs. The query shape [batch_size, num_qo_heads, head_dim] is appropriate for decode (no sequence length dimension).


738-754: LGTM: FP16 reference computation

The FP16 reference computation correctly uses BatchDecodeWithPagedKVCacheWrapper with use_tensor_cores=True and backend="fa3", which is appropriate since decode uses prefill kernels under the hood when tensor cores are enabled.


786-788: Verify MSE threshold is intentional

The MSE threshold is 0.01, which is 100× stricter than the 1.0 threshold used in other FP8 tests (e.g., test_batch_prefill_paged at line 524). This stricter threshold might be intentional since decode has simpler computation (single query per request) with less accumulation error. However, please verify this is the intended threshold and not a typo.
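
For reference, the comparison in such tests is typically of this form (names are hypothetical, not taken from the test file):

# o_fp8_run is the FP16 output of the FP8 path; o_fp16_ref is the FP16 reference.
mse = torch.mean((o_fp8_run.float() - o_fp16_ref.float()) ** 2).item()
assert mse < 1e-2, f"FP8 decode MSE {mse:.4f} exceeds threshold"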


951-959: LGTM: Test execution added to main block

The new decode test is correctly integrated into the __main__ execution block with a reasonable subset of parameters for quick local validation.

@nvpohanh nvpohanh force-pushed the dev-nvpohanh-hopper-fp8-attention branch 2 times, most recently from e3e9279 to 6ecd187 on December 18, 2025 05:28
@nvpohanh
Contributor Author

@bkryu could you review again? thanks!

@nvpohanh nvpohanh force-pushed the dev-nvpohanh-hopper-fp8-attention branch from 6ecd187 to 92c42b1 on December 22, 2025 01:44
Contributor

@coderabbitai coderabbitai bot left a comment

Actionable comments posted: 0

🧹 Nitpick comments (3)
tests/attention/test_hopper_fp8_attention.py (1)

667-785: Decode FP8 paged‑KV GQA test is well‑structured; consider adding e5m2 coverage

The new test_batch_decode_paged mirrors the FP8 paged‑prefill GQA tests, exercises BatchDecodeWithPagedKVCacheWrapper under FA3 with per‑head‑varying KV, and uses a tight MSE threshold — this should reliably catch head‑offset or backend wiring issues in the new decode path.

You might optionally extend dtype to include torch.float8_e5m2 here as well, to mirror the broader FP8 coverage used in the prefill tests.

flashinfer/prefill.py (1)

2116-2120: o_data_type caching and v_scale handling align plan/run behavior for FP8 paths

  • Using self._cached_o_data_type both to validate a user‑provided out tensor and to choose the dtype when allocating out ensures consistency with the o_data_type passed to plan(). This is especially important now that FP8 inputs often request FP16/BF16 outputs.
  • The explicit check that out.dtype matches the planned o_data_type will fail fast on configuration mistakes instead of silently running kernels with a mismatched dtype.
  • Applying v_scale only when it is a Python float (and not equal to 1.0) prevents accidental double‑scaling when FP8 per‑head/per‑tensor scales are passed as tensors; tensor scales are already consumed inside the kernel via the JIT parameters.

If you want to appease Ruff’s TRY003 warning, you could shorten or slightly rephrase the ValueError message, but that’s stylistic rather than functional.

Also applies to: 2158-2168, 2279-2284

flashinfer/decode.py (1)

2635-2660: Consider adding o_data_type support to fast_decode_plan.

The backend-specific argument handling is correctly updated to match the plan method. However, fast_decode_plan doesn't include the o_data_type parameter that was added to plan. If this function needs to support FP8 outputs in the future, consider adding the o_data_type parameter.

📜 Review details

Configuration used: defaults

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 7c59a4c and 92c42b1.

📒 Files selected for processing (5)
  • flashinfer/decode.py
  • flashinfer/jit/attention/modules.py
  • flashinfer/prefill.py
  • tests/attention/test_hopper_fp8_attention.py
  • tests/utils/test_jit_example.py
🧰 Additional context used
🧬 Code graph analysis (3)
tests/attention/test_hopper_fp8_attention.py (3)
flashinfer/utils.py (1)
  • is_sm90a_supported (531-533)
flashinfer/decode.py (7)
  • BatchDecodeWithPagedKVCacheWrapper (593-1456)
  • plan (824-1128)
  • plan (1651-1764)
  • run (1158-1171)
  • run (1174-1187)
  • run (1190-1420)
  • run (1767-1891)
benchmarks/bench_hopper_fp8_attention.py (1)
  • per_head_symmetric_quant (27-38)
flashinfer/prefill.py (1)
flashinfer/logits_processor/types.py (1)
  • dtype (126-130)
flashinfer/decode.py (1)
flashinfer/utils.py (5)
  • determine_attention_backend (455-497)
  • canonicalize_torch_dtype (246-254)
  • PosEncodingMode (32-35)
  • check_shape_dtype_device (565-583)
  • is_float8 (158-159)
🪛 Ruff (0.14.8)
flashinfer/prefill.py

2118-2120: Avoid specifying long messages outside the exception class

(TRY003)

⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
  • GitHub Check: Deploy Docs
🔇 Additional comments (13)
tests/utils/test_jit_example.py (1)

254-260: Explicit fa2 backend looks correct here

Pinning backend="fa2" for the SM80 flash‑sigmoid JIT module avoids auto accidentally picking fa3 on newer GPUs and makes the test behavior deterministic across architectures.

flashinfer/jit/attention/modules.py (1)

985-993: Backend / dtype guards are consistent with FA2/FA3 FP8 design

The new assertions restricting backend to {"fa2","fa3"} and rejecting FP8 dtype_o align with how get_batch_prefill_module is used and with the FP8 tests (FP8 inputs, non‑FP8 outputs). This should prevent misconfigured JIT builds from other paths.

flashinfer/decode.py (11)

66-66: LGTM: Import added for backend auto-selection.

The determine_attention_backend import is necessary for the new auto-selection logic added in lines 1042-1050.


710-710: LGTM: Documentation updated to reflect FA3 backend support.

The docstring correctly documents that fa3 is now a supported backend option, aligning with the PR objective.


838-838: LGTM: Output data type parameter added with clear documentation.

The o_data_type parameter is properly documented and follows the same pattern as q_data_type and kv_data_type, with helpful guidance for FP8 usage scenarios.

Also applies to: 886-889


974-977: LGTM: Data type canonicalization follows established patterns.

The o_data_type canonicalization and caching logic is consistent with the existing handling of q_data_type and kv_data_type, with appropriate defaulting to q_data_type when not specified.

Also applies to: 987-987


1042-1050: LGTM: Backend auto-selection properly implemented.

The auto-selection logic correctly uses determine_attention_backend to choose between FA2 and FA3 based on device capabilities and configuration, with appropriate parameters passed.


1027-1027: LGTM: Output data type consistently threaded through module creation.

The o_data_type parameter is correctly passed to all module creation functions across different code paths (trtllm-gen, tensor cores with batch prefill, and standard batch decode), ensuring consistent output type handling.

Also applies to: 1051-1063, 1094-1097


1065-1089: LGTM: Backend-specific plan arguments properly handled.

The code correctly handles FA2-specific parameters (fixed_split_size, disable_split_kv, num_colocated_ctas) by only appending them when backend is "fa2", avoiding passing unsupported arguments to FA3 backend.


1307-1313: LGTM: Output tensor handling respects planned data type.

The output tensor creation and validation correctly uses _cached_o_data_type when available (set during plan), with appropriate fallback to q.dtype for backward compatibility.


725-725: LGTM: Backend parameter threaded through JIT module creation.

The backend parameter is correctly passed to gen_customize_batch_prefill_module when using custom JIT arguments, ensuring consistent backend selection across all code paths.


1413-1418: The optimization is safe and correctly constrained to float scales.

The condition change from if v_scale is not None to if isinstance(v_scale, float) and v_scale != 1.0 is valid because v_scale is defined as Optional[float] in the public API. This optimization avoids unnecessary scaling operations when v_scale equals 1.0 (identity scaling). Tensor v_scale values are not supported by the API, so the isinstance check is appropriate and safe. No changes needed.


1341-1360: Ensure FP8 scale handling is consistent across both tensor-core and non-tensor-core code paths.

FP8 scale tensors (fp8_scale_q, fp8_scale_k, fp8_scale_v) are extracted from *args and passed to the kernel for the tensor-core path (lines 1341-1348), but the non-tensor-core path (lines 1404-1410) omits this extraction and passing. If FP8 is intended to be supported in both paths, add the same extraction and passing logic to the non-tensor-core branch; otherwise, document why FP8 is tensor-core-only.

@nvpohanh
Contributor Author

@bkryu could you review again? thanks!

@bkryu
Collaborator

bkryu commented Dec 22, 2025

/bot run

@flashinfer-bot
Collaborator

GitLab MR !211 has been created, and the CI pipeline #40635547 is currently running. I'll report back once the pipeline job completes.

Collaborator

@bkryu bkryu left a comment

@nvpohanh, the Hopper unit tests are passing but the non-sm90 arch unit tests are failing. See the UT pipeline's #40635547 Blackwell jobs -- they are failing test_trtllm_gen_attention.py.

I suspect there is some unintended change to the kernel launch or logic, or it may be something as trivial as needing a rebase. Can you see if you can fix the unit test failures? After that, the PR seems good to go, as it passes on Hopper.

@nvpohanh nvpohanh force-pushed the dev-nvpohanh-hopper-fp8-attention branch from 92c42b1 to 121a73e on December 23, 2025 06:48
@nvpohanh
Contributor Author

The CI failure is indeed caused by my change. I have pushed a fix now. Let's see how the new CI goes.

Contributor

@coderabbitai coderabbitai bot left a comment

Actionable comments posted: 0

🧹 Nitpick comments (4)
flashinfer/prefill.py (1)

2117-2120: New out dtype check against planned o_data_type is a good safety net

Validating out.dtype against the cached o_data_type from plan() will catch mismatches early and avoids harder-to-diagnose JIT/kernel failures later. The error message is explicit; if you care about Ruff’s TRY003 warning, you could trim or factor the message, but it’s fine functionally as-is.

flashinfer/decode.py (3)

1065-1089: Consider adding an explanatory comment for backend branching.

The backend-dependent argument list construction is correct—FA2 requires 19 plan arguments (including fixed_split_size, disable_split_kv, num_colocated_ctas) while FA3 requires only 16. However, a brief inline comment explaining this difference would improve maintainability.

Suggested comment
             args = [
                 self._float_workspace_buffer,
                 # ... (base args)
                 window_left,
             ]
+            # FA2 backend requires additional split-kv configuration parameters
             if self._backend == "fa2":
                 args.append(fixed_split_size)
                 args.append(disable_split_kv)
                 args.append(0)  # num_colocated_ctas
             self._plan_info = self._cached_module.plan(
                 *args,
             )

1307-1313: Consider validating that plan() was called before run().

The output tensor creation correctly uses _cached_o_data_type with a fallback to q.dtype, but if plan() was never called, the fallback might produce unexpected output dtypes (especially for FP8 inputs where output should typically be float16/bfloat16).

Consider adding an assertion or warning if _cached_o_data_type is not set, to catch incorrect API usage early.

Example validation
+        if not hasattr(self, "_cached_o_data_type"):
+            raise RuntimeError(
+                "plan() must be called before run(). "
+                "The output data type has not been initialized."
+            )
         if out is None:
-            out_dtype = getattr(self, "_cached_o_data_type", None) or q.dtype
+            out_dtype = self._cached_o_data_type
             out = torch.empty(
                 q.shape[:-1] + v_cache.shape[-1:], dtype=out_dtype, device=q.device
             )
         else:
-            out_dtype = getattr(self, "_cached_o_data_type", None) or q.dtype
+            out_dtype = self._cached_o_data_type
             check_shape_dtype_device(out, q.shape, out_dtype, q.device, "out")

1341-1360: Consider validating FP8 scale tensor types.

The FP8 scale extraction assumes that when q is FP8 and len(args) >= 3, the first three args are scale tensors. While this follows the established convention, adding type or shape validation would improve robustness and catch incorrect API usage.

Example validation
                 # Extract FP8 scale tensors from *args if q is FP8
                 fp8_scale_q = None
                 fp8_scale_k = None
                 fp8_scale_v = None
                 if is_float8(q) and len(args) >= 3:
                     fp8_scale_q = args[0]
                     fp8_scale_k = args[1]
                     fp8_scale_v = args[2]
+                    # Validate that extracted scales are tensors (optional but safer)
+                    for scale, name in [(fp8_scale_q, "scale_q"), (fp8_scale_k, "scale_k"), (fp8_scale_v, "scale_v")]:
+                        if scale is not None and not isinstance(scale, torch.Tensor):
+                            raise TypeError(f"FP8 {name} must be a torch.Tensor, got {type(scale)}")
📜 Review details

Configuration used: defaults

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 92c42b1 and 121a73e.

📒 Files selected for processing (5)
  • flashinfer/decode.py
  • flashinfer/jit/attention/modules.py
  • flashinfer/prefill.py
  • tests/attention/test_hopper_fp8_attention.py
  • tests/utils/test_jit_example.py
🚧 Files skipped from review as they are similar to previous changes (1)
  • tests/utils/test_jit_example.py
🧰 Additional context used
🧬 Code graph analysis (3)
tests/attention/test_hopper_fp8_attention.py (3)
flashinfer/utils.py (1)
  • is_sm90a_supported (531-533)
flashinfer/decode.py (7)
  • BatchDecodeWithPagedKVCacheWrapper (593-1458)
  • plan (824-1128)
  • plan (1653-1766)
  • run (1158-1171)
  • run (1174-1187)
  • run (1190-1422)
  • run (1769-1893)
benchmarks/bench_hopper_fp8_attention.py (1)
  • per_head_symmetric_quant (27-38)
flashinfer/decode.py (1)
flashinfer/utils.py (4)
  • determine_attention_backend (455-497)
  • canonicalize_torch_dtype (246-254)
  • PosEncodingMode (32-35)
  • is_float8 (158-159)
flashinfer/prefill.py (2)
flashinfer/logits_processor/types.py (1)
  • dtype (126-130)
flashinfer/utils.py (1)
  • is_float8 (158-159)
🪛 Ruff (0.14.10)
flashinfer/prefill.py

2118-2120: Avoid specifying long messages outside the exception class

(TRY003)

⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
  • GitHub Check: Deploy Docs
🔇 Additional comments (11)
tests/attention/test_hopper_fp8_attention.py (2)

667-785: FA3 FP8 decode paged test wiring looks correct and consistent with prefill tests

The new test_batch_decode_paged mirrors the prefill-paged FP8 setup (indptr/indices, paged KV layout, per-head quantization, and GQA head configs) and exercises BatchDecodeWithPagedKVCacheWrapper with FA3+FP8 end‑to‑end. Shapes and arguments to plan/run look consistent with the existing prefill tests; the tighter MSE threshold is reasonable here.


948-956: Manual __main__ decode loop is consistent with existing debug harness

The new __main__ loop for test_batch_decode_paged follows the same pattern as the prefill/debug loops and is fine as a local debugging aid.

flashinfer/jit/attention/modules.py (1)

987-993: Backend and FP8-output guards in gen_batch_prefill_module are appropriate

The new asserts cleanly constrain backend to fa2/fa3 and forbid FP8 dtype_o, which matches current backend capabilities and will fail fast on misconfiguration without affecting valid callers.

flashinfer/prefill.py (2)

1705-1705: FP8 o_data_type guidance in docs is accurate and helpful

The added note that FP8 inputs should typically use fp16 or bf16 o_data_type matches how the kernels and tests are wired and makes the API expectations clearer.

Also applies to: 2671-2671


2280-2287: Refined v_scale post-scaling avoids redundant work for the common 1.0 case

Conditioning the post-kernel multiply on v_scale being not-None and not the scalar 1.0 preserves existing semantics while skipping unnecessary computation and dtype casts in the default case; tensor scales are still fully supported.

flashinfer/decode.py (6)

66-66: LGTM: Import addition supports FA3 backend selection.

The determine_attention_backend import is correctly added and used later for automatic backend selection based on device capabilities.


838-838: LGTM: Output data type parameter properly threaded through the API.

The o_data_type parameter is well-documented, correctly defaulted to q_data_type, and consistently propagated through module creation paths. The canonicalization logic follows established patterns for other data type parameters.

Also applies to: 886-889, 974-977, 987-987


1042-1050: LGTM: Backend auto-selection properly implemented.

The automatic backend determination correctly delegates to determine_attention_backend with appropriate parameters, enabling dynamic selection between FA2 and FA3 based on device capabilities and dtype constraints.


1414-1420: LGTM: v_scale optimization correctly skips redundant scaling.

The check isinstance(v_scale, float) and v_scale == 1.0 appropriately skips the scaling operation when it would have no effect. The FP8 output handling correctly converts to float32, scales, and converts back to maintain precision.


2664-2689: LGTM: fast_decode_plan correctly mirrors backend-specific argument handling.

The backend-dependent plan argument construction in fast_decode_plan is consistent with the main plan() method, correctly handling the FA2 vs FA3 signature differences. The explicit comment about argument counts is helpful.


710-710: LGTM: Backend documentation and parameter passing are correct.

The documentation properly lists fa3 as a supported backend alongside fa2 and trtllm-gen. The backend parameter is correctly threaded through to the customized module when JIT args are provided.

Also applies to: 725-725

@flashinfer-bot
Collaborator

[FAILED] Pipeline #40635547: 5/20 passed

@nvpohanh
Contributor Author

/bot run

@flashinfer-bot
Collaborator

@nvpohanh is not authorized to trigger this CI job. cc: @yzh119, @sricketts, @yongwww

@yongwww
Member

yongwww commented Dec 23, 2025

/bot run

@flashinfer-bot
Collaborator

GitLab MR !211 has been updated with latest changes, and the CI pipeline #40698807 is currently running. I'll report back once the pipeline job completes.

Collaborator

@bkryu bkryu left a comment

Can confirm that the unit test failures are unrelated to the current PR.

Changes LGTM.

@flashinfer-bot
Collaborator

[FAILED] Pipeline #40698807: 10/20 passed

@nvpohanh nvpohanh force-pushed the dev-nvpohanh-hopper-fp8-attention branch from 121a73e to 1477978 on December 30, 2025 02:38
@nvpohanh
Contributor Author

@bkryu @yongwww could you trigger a pipeline for me again? thanks!

Contributor

@coderabbitai coderabbitai bot left a comment

Actionable comments posted: 0

🧹 Nitpick comments (3)
flashinfer/prefill.py (1)

2280-2281: Optimization doesn't handle scalar tensors containing 1.0.

The optimization correctly avoids multiplication when v_scale is a float equal to 1.0. However, since v_scale can also be a torch.Tensor (per the type hint at line 2056), the check isinstance(v_scale, float) will miss scalar tensors containing 1.0, and unnecessary multiplication will still occur.

Consider extending the optimization:

🔎 Proposed enhancement
-            is_float_one = isinstance(v_scale, float) and v_scale == 1.0
-            if v_scale is not None and not is_float_one:
+            is_one = (isinstance(v_scale, float) and v_scale == 1.0) or \
+                     (isinstance(v_scale, torch.Tensor) and v_scale.numel() == 1 and v_scale.item() == 1.0)
+            if v_scale is not None and not is_one:
flashinfer/decode.py (2)

1065-1089: Consider extracting backend-specific argument logic to reduce duplication.

The backend-specific argument handling (appending fixed_split_size, disable_split_kv, num_colocated_ctas for fa2) is duplicated between the plan() method (lines 1065-1089) and fast_decode_plan() (lines 2665-2689). Consider extracting this into a helper function to reduce maintenance burden.

🔎 Example refactor to eliminate duplication
def _build_plan_args_for_tensor_cores(
    self,
    qo_indptr_host,
    indptr_host,
    kv_lens_arr_host,
    batch_size,
    num_qo_heads,
    num_kv_heads,
    page_size,
    head_dim,
    window_left,
    fixed_split_size,
    disable_split_kv,
):
    """Build plan arguments for tensor core decode."""
    args = [
        self._float_workspace_buffer,
        self._int_workspace_buffer,
        self._pin_memory_int_workspace_buffer,
        qo_indptr_host,
        indptr_host,
        kv_lens_arr_host,
        batch_size,
        batch_size,
        num_qo_heads,
        num_kv_heads,
        page_size,
        self.is_cuda_graph_enabled,
        head_dim,
        head_dim,
        False,  # causal
        window_left,
    ]
    if self._backend == "fa2":
        args.extend([fixed_split_size, disable_split_kv, 0])  # num_colocated_ctas
    return args

Then use: args = self._build_plan_args_for_tensor_cores(...)


1341-1360: *Add comment documenting expected args format for FP8 inputs.

The FP8 scale extraction logic assumes that when q is FP8, the first three elements of *args are fp8_scale_q, fp8_scale_k, and fp8_scale_v (lines 1345-1348). This implicit contract is fragile and could lead to bugs if the calling convention changes.

🔎 Add clarifying comment
             if self._jit_module is not None:
                 run_args.extend(list(args))
             else:
                 # Extract FP8 scale tensors from *args if q is FP8
+                # Expected format: args[0]=fp8_scale_q, args[1]=fp8_scale_k, args[2]=fp8_scale_v
                 fp8_scale_q = None
                 fp8_scale_k = None
                 fp8_scale_v = None
                 if is_float8(q) and len(args) >= 3:
                     fp8_scale_q = args[0]
                     fp8_scale_k = args[1]
                     fp8_scale_v = args[2]
📜 Review details

Configuration used: defaults

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 121a73e and 1477978.

📒 Files selected for processing (5)
  • flashinfer/decode.py
  • flashinfer/jit/attention/modules.py
  • flashinfer/prefill.py
  • tests/attention/test_hopper_fp8_attention.py
  • tests/utils/test_jit_example.py
🚧 Files skipped from review as they are similar to previous changes (1)
  • tests/utils/test_jit_example.py
🧰 Additional context used
🧬 Code graph analysis (2)
tests/attention/test_hopper_fp8_attention.py (3)
flashinfer/utils.py (1)
  • is_sm90a_supported (531-533)
flashinfer/decode.py (7)
  • BatchDecodeWithPagedKVCacheWrapper (593-1458)
  • plan (824-1128)
  • plan (1653-1766)
  • run (1158-1171)
  • run (1174-1187)
  • run (1190-1422)
  • run (1769-1893)
benchmarks/bench_hopper_fp8_attention.py (1)
  • per_head_symmetric_quant (27-38)
flashinfer/prefill.py (2)
flashinfer/logits_processor/types.py (1)
  • dtype (126-130)
flashinfer/utils.py (1)
  • is_float8 (158-159)
🪛 Ruff (0.14.10)
flashinfer/prefill.py

2118-2120: Avoid specifying long messages outside the exception class

(TRY003)

⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
  • GitHub Check: Deploy Docs
🔇 Additional comments (13)
flashinfer/jit/attention/modules.py (1)

987-992: LGTM: Good defensive validation for backend and output dtype.

The assertions appropriately prevent invalid configurations from reaching the backend-specific code generation logic. The error messages are clear and actionable.

flashinfer/prefill.py (2)

1705-1705: LGTM: Documentation correctly expanded to include bfloat16.

The updated guidance appropriately reflects that both torch.float16 and torch.bfloat16 are valid output data types when using FP8 inputs.


2116-2120: Good dtype consistency check between plan and run phases.

The validation ensures the output tensor dtype matches what was specified during planning, which helps catch configuration errors early. The error message clearly indicates the mismatch.

Note: The static analysis hint (TRY003) about message length is a minor style suggestion and can be safely ignored here since the error context is specific to the runtime values.

flashinfer/decode.py (6)

886-888: LGTM! Clear documentation for FP8 output dtype handling.

The documentation clearly explains that o_data_type defaults to q_data_type and provides guidance for FP8 inputs (use torch.float16 or torch.bfloat16). This is a well-designed API addition.


974-987: LGTM! Proper o_data_type handling and caching.

The code correctly canonicalizes o_data_type and defaults it to q_data_type when None, consistent with the documentation. Caching in _cached_o_data_type enables the run() method to use the planned output dtype.


1042-1050: LGTM! Backend auto-selection logic is sound.

The code appropriately calls determine_attention_backend to resolve backend="auto" to either "fa2" or "fa3" based on device capabilities and encoding mode. The resolved backend is stored in self._backend for subsequent use.


1307-1313: LGTM! Safe output tensor creation with fallback.

The output tensor creation correctly uses _cached_o_data_type if available (set by plan()), with a safe fallback to q.dtype. This ensures FP8 inputs can produce fp16/bf16 outputs as specified.


1414-1420: LGTM! Efficient v_scale handling skips unnecessary multiplication.

The optimization to skip v_scale multiplication when v_scale == 1.0 avoids unnecessary computation. The exact equality check is safe for the common case of v_scale=1.0.


720-732: The concern about backend="auto" being passed with jit_args is valid but already properly handled. The gen_customize_batch_prefill_module function includes explicit validation that raises a clear ValueError with the message "backend should not be auto when jit_args is provided" if "auto" is passed. This is intentional behavior: JIT modules require an explicit backend specification (fa2, fa3, etc.) rather than automatic resolution. No changes needed.

tests/attention/test_hopper_fp8_attention.py (4)

667-684: LGTM! Well-documented test for FP8 decode with paged KV cache.

The test function is clearly documented and uses a GQA configuration to thoroughly test head mapping logic with FP8 quantized KV cache. The comment explaining that BatchDecodeWithPagedKVCacheWrapper uses the prefill backend under the hood is helpful context.


735-781: LGTM! Proper FP8 validation against FP16 reference.

The test correctly establishes an FP16 baseline and compares it against the FP8 output. Both wrappers use consistent configuration (backend="fa3", use_tensor_cores=True), ensuring a fair comparison. The per-head varying KV data helps detect head offset bugs in the FP8 kernel.
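
For readers unfamiliar with the helper, a rough sketch of what per-head symmetric FP8 quantization does (the real per_head_symmetric_quant lives in benchmarks/bench_hopper_fp8_attention.py; the implementation and the [..., num_heads, head_dim] layout assumed below are illustrative, not a copy):

import torch

def per_head_symmetric_quant_sketch(x: torch.Tensor, heads_axis: int = -2):
    """Quantize x to FP8 e4m3 with one symmetric scale per head (illustrative only)."""
    fp8_max = torch.finfo(torch.float8_e4m3fn).max
    heads_axis = heads_axis % x.dim()
    reduce_dims = [d for d in range(x.dim()) if d != heads_axis]
    amax = x.abs().amax(dim=reduce_dims, keepdim=True).clamp(min=1e-6)
    scale = (amax / fp8_max).to(torch.float32)
    x_fp8 = (x / scale).to(torch.float8_e4m3fn)
    return x_fp8, scale.flatten()  # one scale entry per head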


948-956: LGTM! Test properly invoked in main block.

The test invocation correctly exercises the decode path with appropriate parameters (GQA configuration with num_qo_heads=8, num_kv_heads=2).


783-785: Verify the strict MSE threshold doesn't cause test flakiness.

The MSE threshold of 0.01 is 100x stricter than the 1.0 used in prefill tests (e.g., line 524). While decode operations with single query tokens may have lower quantization error, such a tight threshold could cause flakiness across different hardware (A100 vs H100), CUDA versions, or random seeds.

Consider whether 0.01 is reliably achievable or if it should be relaxed to 0.1 or 0.5 for better test stability.

#!/bin/bash
# Description: Run the test multiple times to check for flakiness
# Expected: Consistent pass rate across runs

for i in {1..10}; do
  echo "Run $i:"
  python -m pytest tests/attention/test_hopper_fp8_attention.py::test_batch_decode_paged -v
done

@bkryu
Collaborator

bkryu commented Dec 30, 2025

/bot run

@flashinfer-bot
Collaborator

GitLab MR !211 has been updated with latest changes, and the CI pipeline #40946323 is currently running. I'll report back once the pipeline job completes.

@flashinfer-bot
Collaborator

[SUCCESS] Pipeline #40946323: 12/20 passed

@nvpohanh
Contributor Author

nvpohanh commented Jan 2, 2026

@bkryu could you merge this? thanks!
