Enable Hopper FA3 FP8 attention in decode.py #2148
base: main
Conversation
📝 Walkthrough
Adds an explicit output data-type parameter (o_data_type) to the decode planning and run paths.
Estimated code review effort: 🎯 3 (Moderate) | ⏱️ ~20 minutes
Pre-merge checks and finishing touches: ❌ Failed checks (1 warning), ✅ Passed checks (2 passed)
Please refer to #2111, where we refactored FA3 and exposed the FP8 interface to Python.
Force-pushed ede67a3 to a8d9e6a
Force-pushed ee77217 to 83cfce9
Force-pushed 83cfce9 to 09a1ece
Actionable comments posted: 0
🧹 Nitpick comments (2)
flashinfer/prefill.py (1)
2110-2114: Stronger out-dtype validation; consider tightening the error message
The explicit check that `out.dtype` matches the planned `o_data_type` enforces the plan/run contract and prevents silent dtype mismatches when callers reuse an `out` buffer. This is a solid safety improvement. Ruff's TRY003 warning about long exception messages could be addressed by shortening the message slightly or moving it into a shared constant/helper, but that's stylistic and not functionally required.
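A minimal sketch of the contract being enforced (function and variable names here are illustrative, not the wrapper's actual code):

```python
import torch

def check_out_dtype(out: torch.Tensor, planned_o_dtype: torch.dtype) -> None:
    # Fail fast when a caller reuses an `out` buffer whose dtype no longer
    # matches the o_data_type fixed at plan() time.
    if out.dtype != planned_o_dtype:
        raise ValueError(
            f"out dtype {out.dtype} does not match planned o_data_type {planned_o_dtype}"
        )
```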
flashinfer/decode.py (1)
1306-1313: Minor suggestion: align `out` shape check with allocation expression
Right now `out` is allocated with `q.shape[:-1] + v_cache.shape[-1:]` but validated against `q.shape`. Those are equal for today's kernels (q and v share head_dim), but if q/v head dims ever diverge, the validation would become inconsistent. Consider using the same expression in both places for future-proofing.
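A sketch of the suggested future-proofing, reusing the names from the snippet above (an assumed helper, not the repository's code):

```python
from typing import Optional
import torch

def alloc_or_check_out(
    q: torch.Tensor,
    v_cache: torch.Tensor,
    out: Optional[torch.Tensor],
    out_dtype: torch.dtype,
) -> torch.Tensor:
    # One shape expression for both allocation and validation, so they cannot
    # drift apart if q/v head dims ever diverge.
    expected_shape = q.shape[:-1] + v_cache.shape[-1:]
    if out is None:
        return torch.empty(expected_shape, dtype=out_dtype, device=q.device)
    assert out.shape == expected_shape and out.dtype == out_dtype
    return out
```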
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (3)
flashinfer/decode.py (17 hunks), flashinfer/jit/attention/modules.py (1 hunks), flashinfer/prefill.py (5 hunks)
🧰 Additional context used
🧬 Code graph analysis (1)
flashinfer/decode.py (5)
flashinfer/utils.py (3): determine_attention_backend (450-492), canonicalize_torch_dtype (241-249), is_float8 (158-159)
flashinfer/logits_processor/types.py (2): dtype (126-130), device (119-123)
include/flashinfer/trtllm/common.h (1): device (83-90)
flashinfer/attention.py (1): plan (71-136)
flashinfer/pod.py (2): plan (265-434), plan (800-1014)
🪛 Ruff (0.14.8)
flashinfer/prefill.py
2112-2114: Avoid specifying long messages outside the exception class
(TRY003)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
- GitHub Check: Deploy Docs
🔇 Additional comments (10)
flashinfer/jit/attention/modules.py (1)
987-992: Good preflight validation for backend and FP8 output dtype
The explicit `backend` whitelist and `dtype_o` FP8 guard make the FA2/FA3 path fail fast on unsupported configurations and align with the FP8 constraints used elsewhere in prefill/decode. Looks correct and consistent with the intended Hopper FA3 FP8 support.
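A minimal sketch of this kind of preflight guard (a standalone illustration; the real checks live inside `gen_batch_prefill_module`):

```python
import torch

def validate_prefill_module_config(backend: str, dtype_o: torch.dtype) -> None:
    # Reject unsupported backends and FP8 outputs before any JIT work happens.
    assert backend in ("fa2", "fa3"), f"backend must be fa2 or fa3, got: {backend}"
    assert dtype_o not in (torch.float8_e4m3fn, torch.float8_e5m2), (
        "FP8 output is not supported in fa2/fa3 backends yet"
    )
```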
flashinfer/prefill.py (3)
1704-1705: FP8 output dtype guidance in docstrings is clear and aligned
The added note that `o_data_type` for FP8 inputs should typically be `torch.float16` or `torch.bfloat16` matches the actual capabilities and helps users avoid unsupported FP8 outputs. No further changes needed here.
Also applies to: 2656-2657
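A small sketch of the documented defaulting rule (illustrative helper only, not the wrapper's actual code path):

```python
import torch

def resolve_o_data_type(q_data_type: torch.dtype, o_data_type=None) -> torch.dtype:
    # o_data_type falls back to q_data_type; FP8 inputs should request a
    # float16/bfloat16 output explicitly, since FP8 outputs are not supported.
    if o_data_type is None:
        o_data_type = q_data_type
    if q_data_type in (torch.float8_e4m3fn, torch.float8_e5m2):
        assert o_data_type in (torch.float16, torch.bfloat16)
    return o_data_type

# e.g. resolve_o_data_type(torch.float8_e4m3fn, torch.bfloat16) -> torch.bfloat16
```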
2080-2082: q_scale documentation matches runtime usage
The new `q_scale` description correctly reflects how it's applied (folded into `sm_scale` for FP8 BMM1). This keeps the public API understandable for FP8 users.
2268-2273: Conditional v_scale application is correct and avoids unnecessary work
Applying `v_scale` only when it's not `None` and not `1.0` preserves previous behavior while saving a redundant multiply on the common default. The FP8 branch (cast to fp32, scale, cast back) is also consistent with typical scaling practice.
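The described behavior as a standalone sketch (the helper name is mine; the real code operates inline on `out`):

```python
import torch

def apply_v_scale(out: torch.Tensor, v_scale) -> torch.Tensor:
    # Skip the multiply for the common default; for FP8 outputs, scale in
    # float32 and cast back rather than multiplying FP8 values directly.
    if v_scale is None or (isinstance(v_scale, float) and v_scale == 1.0):
        return out
    if out.dtype in (torch.float8_e4m3fn, torch.float8_e5m2):
        return (out.to(torch.float32) * v_scale).to(out.dtype)
    return out * v_scale
```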
flashinfer/decode.py (6)
51-76: Backend auto-selection and FA2/FA3 plan args look consistent and safe
Importing and using `determine_attention_backend` to specialize `self._backend` when it is `"auto"` in the tensor-core path, and then branching FA2-only plan arguments (`fixed_split_size`, `disable_split_kv`, `num_colocated_ctas`) while keeping FA3 with the shorter signature, matches the expected separation of FA2/FA3 interfaces and mirrors the logic in `fast_decode_plan`. This keeps decode aligned with prefill and should enable FA3 on Hopper cleanly.
Also applies to: 1042-1050, 2635-2660
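A toy stand-in for the auto-selection described above (the real `determine_attention_backend` also weighs dtype and positional-encoding constraints; this only captures the SM90-vs-older split):

```python
import torch

def resolve_backend(backend: str, device: torch.device) -> str:
    # "auto" resolves to fa3 on SM90 (Hopper), otherwise fa2; explicit choices pass through.
    if backend != "auto":
        return backend
    major, _ = torch.cuda.get_device_capability(device)
    return "fa3" if major == 9 else "fa2"
```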
720-733: Passing `backend` into `gen_customize_batch_prefill_module` is the right direction
Wiring `backend` explicitly into `gen_customize_batch_prefill_module` for the tensor-core JIT path aligns decode with the prefill side and makes backend selection explicit at module generation time. Assuming the JIT generator's signature matches this order, this is a straightforward and correct extension of the existing JIT flow.
838-840: `o_data_type` threading and validation are coherent
The new `o_data_type` parameter is:
- Defaulted to `q_data_type` when not provided, then canonicalized.
- Cached as `_cached_o_data_type` and threaded into `get_trtllm_gen_decode_module`, `get_batch_prefill_module`, and `get_batch_decode_module`.
- Used at run time for both allocation and `check_shape_dtype_device` validation of `out`.
This is consistent with the rest of the dtype handling and gives callers explicit control over output dtype (including FP8-input → FP16/BF16-output scenarios) without breaking older call sites that omit `o_data_type`.
Also applies to: 886-889, 964-977, 985-988, 1025-1035, 1095-1104, 1306-1313
1341-1361: FP8 scale wiring and v_scale optimization look correct
Extracting `fp8_scale_q/k/v` from `*args` only when `q` is float8 and only for the non-JIT tensor-core path keeps the existing API surface intact while enabling FA3/FA2 FP8 usage. Passing these explicitly into `paged_run` matches the extended kernel signature, and the updated `v_scale` guard (`v_scale is not None and v_scale != 1.0`) plus `is_float8(out)`-based cast behavior are sensible micro-optimizations that preserve numerical behavior.
Also applies to: 1413-1418
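The extraction pattern as an isolated sketch (assumed positional order scale_q, scale_k, scale_v, matching the snippets quoted in this review):

```python
import torch

def extract_fp8_scales(q: torch.Tensor, args: tuple):
    # Only FP8 queries consume positional scale tensors; otherwise *args is
    # left untouched and the scales stay None.
    if q.dtype in (torch.float8_e4m3fn, torch.float8_e5m2) and len(args) >= 3:
        return args[0], args[1], args[2]
    return None, None, None
```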
998-1005: trtllm-gen decode integration with `o_data_type` remains consistent
Adding `o_data_type` to the `get_trtllm_gen_decode_module` cache key while still using the same `paged_run` argument layout (including `workspace_size`, `block_tables`, `kv_lens_buffer`, etc.) keeps the trtllm-gen path coherent with FA2/FA3 and ensures different output dtypes don't collide in the module cache. The subsequent `_paged_run` call continues to receive the expected scales and workspace sizing.
Also applies to: 1364-1373
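Why the output dtype belongs in the cache key, shown with a toy cached factory (the stand-in returns a tuple instead of building a real JIT module):

```python
from functools import lru_cache
import torch

@lru_cache(maxsize=None)
def get_decode_module(q_dtype: torch.dtype, kv_dtype: torch.dtype,
                      o_dtype: torch.dtype, head_dim: int):
    # With o_dtype in the key, an fp16-output plan and a bf16-output plan get
    # distinct cached modules instead of silently sharing one compiled artifact.
    return ("decode-module", q_dtype, kv_dtype, o_dtype, head_dim)
```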
2523-2590: `fast_decode_plan` argument construction now mirrors main `plan`
The tensor-core branch in `fast_decode_plan` now builds the same base `args` list as `BatchDecodeWithPagedKVCacheWrapper.plan` and appends FA2-only arguments under `self._backend == "fa2"`. This keeps the "fast" path in sync with the standard planner and reduces the risk of FA3/FA2 divergence for multistep decode.
Also applies to: 2635-2660
@yzh119 Could you review this PR? Let me know if you think I should add some tests for this; if so, please point me to the test file where I should add or extend them. Thanks!
bkryu left a comment
Hi @nvpohanh, I'd say the appropriate place to put unit tests should be test_hopper.py or test_hopper_fp8_attention.py.
Since we are adding fa3 as a backend for the first time in BatchDecodeWithPagedKVCacheWrapper, we may need to write new tests analogous to the prefill ones in there.
Force-pushed 09a1ece to 1bb22d5
Actionable comments posted: 0
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (1)
flashinfer/prefill.py (1)
2279-2284: Fix tensor comparison in v_scale condition.
The condition `v_scale != 1.0` will fail when `v_scale` is a `torch.Tensor` with multiple elements (e.g., per-head scaling). Tensor comparisons return boolean tensors, which cannot be used directly in `if` statements and will raise: `RuntimeError: The truth value of a Tensor with more than one element is ambiguous.` The `v_scale` parameter accepts `Optional[Union[float, torch.Tensor]]` per the type annotation, and the codebase documentation shows per-head tensor scales with shape `[num_kv_heads]` are supported in other functions. Additionally, `q_scale` and `k_scale` (same type) are handled differently at lines 2140-2143 without this comparison.
Apply this fix to handle both scalar and tensor v_scale correctly:
```diff
-    if v_scale is not None and v_scale != 1.0:
+    if v_scale is not None:
+        # Skip scaling if v_scale is scalar 1.0
+        if isinstance(v_scale, (int, float)) and v_scale == 1.0:
+            pass
+        else:
             # TODO(Zihao): fused into kernel
             if is_float8(out):
                 out = (out.to(torch.float32) * v_scale).to(out.dtype)
             else:
                 out *= v_scale
```
Alternatively, add explicit validation in the function to document whether v_scale must be a scalar.
🧹 Nitpick comments (2)
flashinfer/jit/attention/modules.py (1)
987-993: Consider using explicit ValueError instead of assert for input validation.
While assertions work for catching invalid inputs during development, they can be disabled with `python -O` in production environments. For input validation that should always run, explicit `if` statements with `ValueError` are more reliable:
```diff
-    assert backend in ["fa2", "fa3"], (
-        f"backend must be fa2 or fa3 in gen_batch_prefill_module(), got: {backend}"
-    )
-    assert dtype_o not in [torch.float8_e4m3fn, torch.float8_e5m2], (
-        "FP8 output is not supported in fa2/fa3 backends yet"
-    )
+    if backend not in ["fa2", "fa3"]:
+        raise ValueError(
+            f"backend must be fa2 or fa3 in gen_batch_prefill_module(), got: {backend}"
+        )
+    if dtype_o in [torch.float8_e4m3fn, torch.float8_e5m2]:
+        raise ValueError(
+            "FP8 output is not supported in fa2/fa3 backends yet"
+        )
```
flashinfer/decode.py (1)
709-712: Clarify backend parameter documentation to mention FA3.
The past review comment by bkryu is still relevant. While the `backend` parameter accepts `"auto"`, `"fa2"`, or `"trtllm-gen"`, the docstring should clarify that `"auto"` may internally select FA3 on supported hardware (as implemented in lines 1042-1050). This would help users understand the full range of backends that may be used.
Apply this diff to improve the documentation:
```diff
     backend : str
-        The implementation backend, could be ``auto``/``fa2`` or ``trtllm-gen``. Defaults to ``auto``.
-        If set to ``auto``, the wrapper will automatically choose the backend based on the
-        device architecture and kernel availability.
+        The implementation backend, could be ``auto``/``fa2`` or ``trtllm-gen``. Defaults to ``auto``.
+        If set to ``auto``, the wrapper will automatically choose the backend based on the
+        device architecture and kernel availability (may select FA2 or FA3 for tensor cores).
```
📜 Review details
Configuration used: defaults
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (3)
flashinfer/decode.py (17 hunks), flashinfer/jit/attention/modules.py (1 hunks), flashinfer/prefill.py (4 hunks)
🧰 Additional context used
🧬 Code graph analysis (1)
flashinfer/decode.py (1)
flashinfer/utils.py (4): determine_attention_backend (455-497), canonicalize_torch_dtype (246-254), PosEncodingMode (32-35), is_float8 (158-159)
🪛 Ruff (0.14.8)
flashinfer/prefill.py
2118-2120: Avoid specifying long messages outside the exception class
(TRY003)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
- GitHub Check: Deploy Docs
🔇 Additional comments (16)
flashinfer/jit/attention/modules.py (1)
994-1011: LGTM! Expanded FA2 backend tensor support.
The addition of new tensor names (`mask_indptr`, `prefix_len_ptr`, `token_pos_in_items_ptr`, `max_item_len_ptr`) with their corresponding dtypes appropriately extends the FA2 backend capabilities to support the FP8 attention features.
flashinfer/prefill.py (3)
1705-1705: LGTM! Helpful FP8 guidance added to docstrings.
The added guidance for `o_data_type` with FP8 inputs is clear and informative, helping users understand that they should typically use `torch.float16` or `torch.bfloat16` as the output dtype when working with FP8 attention.
2116-2121: LGTM! Important runtime dtype validation.
The validation ensures consistency between the output tensor's dtype and the `o_data_type` specified during planning, preventing potential dtype mismatches that could cause runtime errors or incorrect results.
Note: The static analysis hint (TRY003) about message length is a minor style preference. The current message provides helpful diagnostic information and is acceptable.
2667-2667: LGTM! Consistent FP8 guidance.
Same helpful FP8 guidance as added to the paged variant's docstring (line 1705), ensuring consistent documentation across both `BatchPrefillWithPagedKVCacheWrapper` and `BatchPrefillWithRaggedKVCacheWrapper`.
flashinfer/decode.py (12)
66-66: LGTM: Import necessary for auto backend selection.
The `determine_attention_backend` import is required for the auto backend selection logic introduced later in the `plan` method (lines 1042-1050).
838-838: LGTM: o_data_type parameter addition aligns with FP8 support.
The addition of the `o_data_type` parameter and its documentation clearly supports FP8 attention workflows, where output dtype often differs from input dtype. The default behavior (falling back to `q_data_type`) maintains backward compatibility.
Also applies to: 886-888
974-977: LGTM: o_data_type handling is consistent with existing patterns.
The canonicalization and caching of `o_data_type` follows the same pattern as `q_data_type` and `kv_data_type`, ensuring consistency across the codebase.
Also applies to: 987-987
1024-1036: LGTM: o_data_type properly threaded through all backend paths.
The `o_data_type` parameter is correctly passed to all three module creation functions (`get_trtllm_gen_decode_module`, `get_batch_prefill_module`, and `get_batch_decode_module`), ensuring consistent output dtype handling across FA2, FA3, and TRTLLM-gen backends.
Also applies to: 1051-1063, 1094-1104
1042-1050: LGTM: Auto backend selection correctly implemented.
The auto backend selection logic properly delegates to `determine_attention_backend` with appropriate parameters. The backend is selected based on device capabilities and input dtypes, which is the correct approach since backend capability is primarily determined by what inputs it can process.
1065-1089: LGTM: Conditional plan arguments correctly handle FA2/FA3 differences.
The conditional logic properly handles the different plan signatures for FA2 (19 arguments) and FA3 (16 arguments) backends. FA2-specific parameters (`fixed_split_size`, `disable_split_kv`, `num_colocated_ctas`) are only appended when using the FA2 backend.
Note: When `jit_module` is provided (line 1039), the auto backend selection at lines 1042-1050 is skipped. Ensure that when using custom JIT modules with `backend="auto"`, users are aware that they need to specify the correct backend explicitly.
1306-1313: LGTM: Output tensor handling correctly uses o_data_type.
The output allocation and validation logic properly uses `_cached_o_data_type` with a sensible fallback to `q.dtype`. This ensures that FP8 workflows can produce outputs in the desired dtype (typically FP16/BF16) while maintaining backward compatibility for existing code.
1341-1363: LGTM: FP8 scale tensor extraction enables FP8 attention support.
The extraction and propagation of FP8 scale tensors (`fp8_scale_q`, `fp8_scale_k`, `fp8_scale_v`) from `*args` correctly enables FP8 attention workflows. The conditional extraction (only when `is_float8(q)` and sufficient args are provided) is appropriate and aligns with the PR objective.
1413-1418: LGTM: v_scale optimization avoids unnecessary computation.
The refined condition `if v_scale is not None and v_scale != 1.0` appropriately skips the no-op scaling when `v_scale` is 1.0, improving performance while maintaining correctness.
2635-2660: LGTM: fast_decode_plan correctly mirrors regular plan logic.
The conditional FA2/FA3 argument handling in `fast_decode_plan` correctly mirrors the logic in the main `plan` method (lines 1065-1089), ensuring consistency between the fast and standard plan implementations. The comment at line 2635 clearly documents the argument count difference.
1341-1363: Verify FP8 scale tensor ordering consistency with prefill.py implementation.
Confirm that the FP8 scale tensor extraction and ordering (lines 1345-1347: `fp8_scale_q = args[0]`, `fp8_scale_k = args[1]`, `fp8_scale_v = args[2]`) and their positions in `run_args` (lines 1358-1360) match the prefill.py implementation. This ensures consistency between decode and prefill operations for maintainability and correctness.
720-732: No backend parameter needed for decode module generation.
After examining the function implementations, `gen_customize_batch_decode_module` intentionally does not accept a `backend` parameter because it has only a single implementation path without backend-specific variations. In contrast, `gen_customize_batch_prefill_module` accepts `backend` because it includes conditional logic for multiple backend implementations (fa2, fa3, etc.). The code correctly passes `backend` only to the prefill module. This is by design, not an inconsistency.
Actionable comments posted: 0
🧹 Nitpick comments (1)
tests/utils/test_jit_example.py (1)
254-260: Explicit backend specification improves test clarity.
The addition of `backend="fa2"` makes the backend selection explicit, which aligns well with the wrapper configuration and the SM80 variant declaration used in this test.
However, consider adding a corresponding SM90/FA3 test for batch decode (similar to `test_batch_prefill_sm90_flash_sigmoid` at line 459) to validate the PR's main objective of enabling Hopper FA3 FP8 attention in decode.py. Currently, decode operations only have FA2 test coverage, while prefill has both FA2 and FA3 tests.
Would you like me to help generate a `test_batch_decode_sm90_flash_sigmoid` test that uses the FA3 backend with the SM90 variant declaration to fully validate the FA3 FP8 functionality in decode.py?
📜 Review details
Configuration used: defaults
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (1)
tests/utils/test_jit_example.py (1 hunks)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
- GitHub Check: Deploy Docs
Force-pushed 7c59a4c to f841c9c
Added a test to test_hopper_fp8_attention.py.
Actionable comments posted: 1
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (1)
tests/attention/test_hopper_fp8_attention.py (1)
1-1: Critical: Fix ruff-format errors before merge
The pre-commit `ruff-format` hook failed, indicating that the file needs formatting. Please run the following command locally to fix formatting issues:
```bash
pre-commit run ruff-format --files tests/attention/test_hopper_fp8_attention.py
```
Or format all files:
```bash
pre-commit run --all-files
```
📜 Review details
Configuration used: defaults
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (4)
flashinfer/decode.py (18 hunks), flashinfer/jit/attention/modules.py (1 hunks), flashinfer/prefill.py (4 hunks), tests/attention/test_hopper_fp8_attention.py (2 hunks)
🧰 Additional context used
🧬 Code graph analysis (3)
tests/attention/test_hopper_fp8_attention.py (4)
flashinfer/utils.py (1): is_sm90a_supported (531-533)
flashinfer/decode.py (9): BatchDecodeWithPagedKVCacheWrapper (593-1456), use_tensor_cores (792-793), use_tensor_cores (1623-1624), plan (824-1128), plan (1651-1764), run (1158-1171), run (1174-1187), run (1190-1420), run (1767-1891)
flashinfer/prefill.py (8): plan (1595-1988), plan (2570-2867), run (2019-2031), run (2034-2046), run (2049-2285), run (2897-2907), run (2910-2920), run (2923-3098)
benchmarks/bench_hopper_fp8_attention.py (1): per_head_symmetric_quant (27-38)
flashinfer/decode.py (1)
flashinfer/utils.py (5): determine_attention_backend (455-497), canonicalize_torch_dtype (246-254), PosEncodingMode (32-35), check_shape_dtype_device (565-583), is_float8 (158-159)
flashinfer/prefill.py (1)
flashinfer/logits_processor/types.py (1): dtype (126-130)
🪛 GitHub Actions: pre-commit
tests/attention/test_hopper_fp8_attention.py
[error] 1-1: pre-commit ruff-format hook failed: 1 file reformatted. Run 'pre-commit run --all-files' locally to apply formatting changes.
🪛 Ruff (0.14.8)
flashinfer/prefill.py
2118-2120: Avoid specifying long messages outside the exception class
(TRY003)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
- GitHub Check: Deploy Docs
🔇 Additional comments (20)
flashinfer/jit/attention/modules.py (1)
987-992: Validation checks look good.
The backend and output dtype validations are well-placed at the function entry point, ensuring invalid configurations are caught early. The error messages are clear and informative.
flashinfer/prefill.py (4)
1705-1705: Good documentation improvement for FP8 attention.
The guidance clarifies that FP8 inputs should typically use float16 or bfloat16 output, which aligns with the PR's FP8 attention support objectives.
2116-2120: Good runtime validation for output dtype consistency.
This check ensures that when a user provides an output tensor, its dtype matches what was specified in the `plan()` call, preventing subtle bugs. The error message is clear and includes the actual dtype values for easy debugging.
Note: The static analysis tool flags a style preference (TRY003) about the exception message length, but this is a minor nitpick and the current implementation prioritizes clarity.
2667-2667: Consistent documentation improvement.
Same helpful guidance as line 1705, applied consistently to the ragged KV cache wrapper's plan method.
2279-2284: The code is correct.
The actual implementation uses `isinstance(v_scale, float)` to guard the comparison, properly handling the case where `v_scale` can be either a float or a Tensor. The review comment's code snippet does not match the repository.
Likely an incorrect or invalid review comment.
flashinfer/decode.py (10)
66-66: LGTM: FA3 backend support additions
The import of `determine_attention_backend`, the docstring update mentioning FA3, and passing the backend parameter to the customized module are all correct and necessary for FA3 support.
Also applies to: 710-710, 725-725
838-838: LGTM: Output data type parameter added to public API
The addition of the `o_data_type` parameter to the `plan()` signature is well-documented and follows the same pattern as `kv_data_type`. The documentation correctly indicates that for FP8 inputs, the output should typically be `torch.float16` or `torch.bfloat16`.
Also applies to: 886-889
974-977: LGTM: Output data type canonicalization and caching
The canonicalization and defaulting logic for `o_data_type` follows the same pattern as `kv_data_type`, and caching it in `_cached_o_data_type` is consistent with the existing caching pattern.
Also applies to: 987-987
1042-1050: LGTM: Backend auto-selection logic
The automatic backend determination for `use_tensor_cores=True` paths correctly uses `determine_attention_backend` to choose between FA2 and FA3 based on device capabilities and data types. The non-tensor-cores path doesn't require backend selection, so this implementation is correct.
1027-1027: LGTM: Output data type propagation to modules
The `o_data_type` parameter is correctly threaded through to all three module types (trtllm-gen, batch prefill, and batch decode), ensuring that the output data type is properly configured across different backends.
Also applies to: 1052-1052, 1055-1055, 1097-1097
1065-1089: LGTM: Backend-specific plan argument handling
The conditional argument handling correctly differentiates between FA2 (which requires `fixed_split_size`, `disable_split_kv`, and `num_colocated_ctas`) and FA3 (which doesn't support these parameters). The base argument list has 16 parameters, with FA2 adding 3 more for a total of 19.
1307-1313: LGTM: Output tensor allocation with cached data type
The output tensor allocation and validation correctly use `_cached_o_data_type` (with safe fallback to `q.dtype`) to ensure the output tensor has the planned data type. This is consistent with the pattern used in prefill.py.
1341-1360: LGTM: FP8 scale tensor extraction for FA3 path
The extraction of FP8 scale tensors from `*args` when `q` is FP8 is correctly guarded by `is_float8(q) and len(args) >= 3` to avoid indexing errors. The scales default to `None` if not provided, which the kernel should handle appropriately.
1413-1413: LGTM: Optimize v_scale application
The additional check `v_scale != 1.0` correctly skips the scaling operation when no scaling is needed, avoiding unnecessary work. This optimization aligns with the similar change in prefill.py.
2635-2660: LGTM: Consistent FA2/FA3 argument handling in fast_decode_plan
The `fast_decode_plan` function correctly mirrors the backend-specific argument handling from the main `plan` method, ensuring FA2 receives its three additional parameters while FA3 uses only the base 16 arguments.
tests/attention/test_hopper_fp8_attention.py (5)
667-691: LGTM: Test function setup
The test parameters and setup are appropriate for validating FP8 batch decode with paged KV cache. The GQA scenarios (different `num_qo_heads` and `num_kv_heads` combinations) are particularly valuable for testing head mapping logic.
693-736: LGTM: Paged KV cache construction
The test correctly constructs paged KV cache with per-head varying scales (using `create_per_head_varying_kv`), which helps reveal head offset bugs. The query shape `[batch_size, num_qo_heads, head_dim]` is appropriate for decode (no sequence length dimension).
738-754: LGTM: FP16 reference computation
The FP16 reference computation correctly uses `BatchDecodeWithPagedKVCacheWrapper` with `use_tensor_cores=True` and `backend="fa3"`, which is appropriate since decode uses prefill kernels under the hood when tensor cores are enabled.
786-788: Verify MSE threshold is intentional
The MSE threshold is `0.01`, which is 100× stricter than the `1.0` threshold used in other FP8 tests (e.g., `test_batch_prefill_paged` at line 524). This stricter threshold might be intentional since decode has simpler computation (single query per request) with less accumulation error. However, please verify this is the intended threshold and not a typo.
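For reference, the comparison metric such tests typically compute (a sketch; the exact implementation in the test file may differ):

```python
import torch

def mse(a: torch.Tensor, b: torch.Tensor) -> float:
    # Compare in float32 so the metric itself is not distorted by FP8/FP16 rounding.
    return torch.mean((a.float() - b.float()) ** 2).item()

# e.g. assert mse(o_fp8, o_fp16_ref) < 0.01  # the decode test's tighter threshold
```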
951-959: LGTM: Test execution added to main block
The new decode test is correctly integrated into the `__main__` execution block with a reasonable subset of parameters for quick local validation.
Force-pushed e3e9279 to 6ecd187
@bkryu Could you review again? Thanks!
Force-pushed 6ecd187 to 92c42b1
Actionable comments posted: 0
🧹 Nitpick comments (3)
tests/attention/test_hopper_fp8_attention.py (1)
667-785: Decode FP8 paged-KV GQA test is well-structured; consider adding e5m2 coverage
The new `test_batch_decode_paged` mirrors the FP8 paged-prefill GQA tests, exercises `BatchDecodeWithPagedKVCacheWrapper` under FA3 with per-head-varying KV, and uses a tight MSE threshold. This should reliably catch head-offset or backend wiring issues in the new decode path.
You might optionally extend `dtype` to include `torch.float8_e5m2` here as well, to mirror the broader FP8 coverage used in the prefill tests.
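A hedged sketch of how that wider parametrization could look (whether e5m2 is actually supported by the decode FA3 path should be verified before enabling it):

```python
import pytest
import torch

@pytest.mark.parametrize("dtype", [torch.float8_e4m3fn, torch.float8_e5m2])
def test_fp8_dtype_parametrization(dtype):
    # Placeholder body; the real test would quantize Q/K/V to `dtype` and
    # compare against an FP16 reference, as test_batch_decode_paged does.
    assert torch.finfo(dtype).bits == 8
```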
flashinfer/prefill.py (1)
2116-2120: o_data_type caching and v_scale handling align plan/run behavior for FP8 paths
- Using `self._cached_o_data_type` both to validate a user-provided `out` tensor and to choose the dtype when allocating `out` ensures consistency with the `o_data_type` passed to `plan()`. This is especially important now that FP8 inputs often request FP16/BF16 outputs.
- The explicit check that `out.dtype` matches the planned `o_data_type` will fail fast on configuration mistakes instead of silently running kernels with a mismatched dtype.
- Applying `v_scale` only when it is a Python `float` (and not equal to 1.0) prevents accidental double-scaling when FP8 per-head/per-tensor scales are passed as tensors; tensor scales are already consumed inside the kernel via the JIT parameters.
If you want to appease Ruff's TRY003 warning, you could shorten or slightly rephrase the `ValueError` message, but that's stylistic rather than functional.
Also applies to: 2158-2168, 2279-2284
flashinfer/decode.py (1)
2635-2660: Consider adding o_data_type support to fast_decode_plan.
The backend-specific argument handling is correctly updated to match the `plan` method. However, `fast_decode_plan` doesn't include the `o_data_type` parameter that was added to `plan`. If this function needs to support FP8 outputs in the future, consider adding the `o_data_type` parameter.
📜 Review details
Configuration used: defaults
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (5)
flashinfer/decode.py, flashinfer/jit/attention/modules.py, flashinfer/prefill.py, tests/attention/test_hopper_fp8_attention.py, tests/utils/test_jit_example.py
🧰 Additional context used
🧬 Code graph analysis (3)
tests/attention/test_hopper_fp8_attention.py (3)
flashinfer/utils.py (1): is_sm90a_supported (531-533)
flashinfer/decode.py (7): BatchDecodeWithPagedKVCacheWrapper (593-1456), plan (824-1128), plan (1651-1764), run (1158-1171), run (1174-1187), run (1190-1420), run (1767-1891)
benchmarks/bench_hopper_fp8_attention.py (1): per_head_symmetric_quant (27-38)
flashinfer/prefill.py (1)
flashinfer/logits_processor/types.py (1): dtype (126-130)
flashinfer/decode.py (1)
flashinfer/utils.py (5): determine_attention_backend (455-497), canonicalize_torch_dtype (246-254), PosEncodingMode (32-35), check_shape_dtype_device (565-583), is_float8 (158-159)
🪛 Ruff (0.14.8)
flashinfer/prefill.py
2118-2120: Avoid specifying long messages outside the exception class
(TRY003)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
- GitHub Check: Deploy Docs
🔇 Additional comments (13)
tests/utils/test_jit_example.py (1)
254-260: Explicit fa2 backend looks correct here
Pinning `backend="fa2"` for the SM80 flash-sigmoid JIT module avoids `auto` accidentally picking fa3 on newer GPUs and makes the test behavior deterministic across architectures.
flashinfer/jit/attention/modules.py (1)
985-993: Backend / dtype guards are consistent with FA2/FA3 FP8 design
The new assertions restricting `backend` to `{"fa2", "fa3"}` and rejecting FP8 `dtype_o` align with how `get_batch_prefill_module` is used and with the FP8 tests (FP8 inputs, non-FP8 outputs). This should prevent misconfigured JIT builds from other paths.
flashinfer/decode.py (11)
66-66: LGTM: Import added for backend auto-selection.
The `determine_attention_backend` import is necessary for the new auto-selection logic added in lines 1042-1050.
710-710: LGTM: Documentation updated to reflect FA3 backend support.
The docstring correctly documents that `fa3` is now a supported backend option, aligning with the PR objective.
838-838: LGTM: Output data type parameter added with clear documentation.
The `o_data_type` parameter is properly documented and follows the same pattern as `q_data_type` and `kv_data_type`, with helpful guidance for FP8 usage scenarios.
Also applies to: 886-889
974-977: LGTM: Data type canonicalization follows established patterns.
The `o_data_type` canonicalization and caching logic is consistent with the existing handling of `q_data_type` and `kv_data_type`, with appropriate defaulting to `q_data_type` when not specified.
Also applies to: 987-987
1042-1050: LGTM: Backend auto-selection properly implemented.
The auto-selection logic correctly uses `determine_attention_backend` to choose between FA2 and FA3 based on device capabilities and configuration, with appropriate parameters passed.
1027-1027: LGTM: Output data type consistently threaded through module creation.
The `o_data_type` parameter is correctly passed to all module creation functions across different code paths (trtllm-gen, tensor cores with batch prefill, and standard batch decode), ensuring consistent output type handling.
Also applies to: 1051-1063, 1094-1097
1065-1089: LGTM: Backend-specific plan arguments properly handled.
The code correctly handles FA2-specific parameters (`fixed_split_size`, `disable_split_kv`, `num_colocated_ctas`) by only appending them when the backend is "fa2", avoiding passing unsupported arguments to the FA3 backend.
1307-1313: LGTM: Output tensor handling respects planned data type.
The output tensor creation and validation correctly uses `_cached_o_data_type` when available (set during `plan`), with appropriate fallback to `q.dtype` for backward compatibility.
725-725: LGTM: Backend parameter threaded through JIT module creation.
The backend parameter is correctly passed to `gen_customize_batch_prefill_module` when using custom JIT arguments, ensuring consistent backend selection across all code paths.
1413-1418: The optimization is safe and correctly constrained to float scales.
The condition change from `if v_scale is not None` to `if isinstance(v_scale, float) and v_scale != 1.0` is valid because v_scale is defined as `Optional[float]` in the public API. This optimization avoids unnecessary scaling operations when v_scale equals 1.0 (identity scaling). Tensor v_scale values are not supported by the API, so the isinstance check is appropriate and safe. No changes needed.
1341-1360: Ensure FP8 scale handling is consistent across both tensor-core and non-tensor-core code paths.
FP8 scale tensors (`fp8_scale_q`, `fp8_scale_k`, `fp8_scale_v`) are extracted from `*args` and passed to the kernel for the tensor-core path (lines 1341-1348), but the non-tensor-core path (lines 1404-1410) omits this extraction and passing. If FP8 is intended to be supported in both paths, add the same extraction and passing logic to the non-tensor-core branch; otherwise, document why FP8 is tensor-core-only.
@bkryu Could you review again? Thanks!
/bot run
bkryu left a comment
@nvpohanh, the Hopper unit tests are passing but the non-sm90 arch unit tests are failing. See the UT pipeline's #40635547 Blackwell jobs -- they are failing test_trtllm_gen_attention.py.
I suspect there is some unintended change to the kernel launch or logic, or it may be something as trivial as needing a rebase. Can you see if you can fix the unit tests? Afterwards, the PR seems good to go, as it passes on Hopper.
Force-pushed 92c42b1 to 121a73e
The CI failure is indeed caused by my change. I have pushed a fix now. Let's see how the new CI run goes.
Actionable comments posted: 0
🧹 Nitpick comments (4)
flashinfer/prefill.py (1)
2117-2120: New `out` dtype check against planned `o_data_type` is a good safety net
Validating `out.dtype` against the cached `o_data_type` from `plan()` will catch mismatches early and avoids harder-to-diagnose JIT/kernel failures later. The error message is explicit; if you care about Ruff's TRY003 warning, you could trim or factor the message, but it's fine functionally as-is.
flashinfer/decode.py (3)
1065-1089: Consider adding an explanatory comment for backend branching.
The backend-dependent argument list construction is correct: FA2 requires 19 plan arguments (including `fixed_split_size`, `disable_split_kv`, `num_colocated_ctas`) while FA3 requires only 16. However, a brief inline comment explaining this difference would improve maintainability.
Suggested comment:
```diff
     args = [
         self._float_workspace_buffer,
         # ... (base args)
         window_left,
     ]
+    # FA2 backend requires additional split-kv configuration parameters
     if self._backend == "fa2":
         args.append(fixed_split_size)
         args.append(disable_split_kv)
         args.append(0)  # num_colocated_ctas
     self._plan_info = self._cached_module.plan(
         *args,
     )
```
1307-1313: Consider validating that plan() was called before run().
The output tensor creation correctly uses `_cached_o_data_type` with a fallback to `q.dtype`, but if `plan()` was never called, the fallback might produce unexpected output dtypes (especially for FP8 inputs, where the output should typically be float16/bfloat16). Consider adding an assertion or warning if `_cached_o_data_type` is not set, to catch incorrect API usage early.
Example validation:
```diff
+    if not hasattr(self, "_cached_o_data_type"):
+        raise RuntimeError(
+            "plan() must be called before run(). "
+            "The output data type has not been initialized."
+        )
     if out is None:
-        out_dtype = getattr(self, "_cached_o_data_type", None) or q.dtype
+        out_dtype = self._cached_o_data_type
         out = torch.empty(
             q.shape[:-1] + v_cache.shape[-1:], dtype=out_dtype, device=q.device
         )
     else:
-        out_dtype = getattr(self, "_cached_o_data_type", None) or q.dtype
+        out_dtype = self._cached_o_data_type
         check_shape_dtype_device(out, q.shape, out_dtype, q.device, "out")
```
1341-1360: Consider validating FP8 scale tensor types.
The FP8 scale extraction assumes that when `q` is FP8 and `len(args) >= 3`, the first three args are scale tensors. While this follows the established convention, adding type or shape validation would improve robustness and catch incorrect API usage.
Example validation:
```diff
     # Extract FP8 scale tensors from *args if q is FP8
     fp8_scale_q = None
     fp8_scale_k = None
     fp8_scale_v = None
     if is_float8(q) and len(args) >= 3:
         fp8_scale_q = args[0]
         fp8_scale_k = args[1]
         fp8_scale_v = args[2]
+    # Validate that extracted scales are tensors (optional but safer)
+    for scale, name in [(fp8_scale_q, "scale_q"), (fp8_scale_k, "scale_k"), (fp8_scale_v, "scale_v")]:
+        if scale is not None and not isinstance(scale, torch.Tensor):
+            raise TypeError(f"FP8 {name} must be a torch.Tensor, got {type(scale)}")
```
📜 Review details
Configuration used: defaults
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (5)
flashinfer/decode.py, flashinfer/jit/attention/modules.py, flashinfer/prefill.py, tests/attention/test_hopper_fp8_attention.py, tests/utils/test_jit_example.py
🚧 Files skipped from review as they are similar to previous changes (1)
- tests/utils/test_jit_example.py
🧰 Additional context used
🧬 Code graph analysis (3)
tests/attention/test_hopper_fp8_attention.py (3)
flashinfer/utils.py (1): is_sm90a_supported (531-533)
flashinfer/decode.py (7): BatchDecodeWithPagedKVCacheWrapper (593-1458), plan (824-1128), plan (1653-1766), run (1158-1171), run (1174-1187), run (1190-1422), run (1769-1893)
benchmarks/bench_hopper_fp8_attention.py (1): per_head_symmetric_quant (27-38)
flashinfer/decode.py (1)
flashinfer/utils.py (4): determine_attention_backend (455-497), canonicalize_torch_dtype (246-254), PosEncodingMode (32-35), is_float8 (158-159)
flashinfer/prefill.py (2)
flashinfer/logits_processor/types.py (1): dtype (126-130)
flashinfer/utils.py (1): is_float8 (158-159)
🪛 Ruff (0.14.10)
flashinfer/prefill.py
2118-2120: Avoid specifying long messages outside the exception class
(TRY003)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
- GitHub Check: Deploy Docs
🔇 Additional comments (11)
tests/attention/test_hopper_fp8_attention.py (2)
667-785: FA3 FP8 decode paged test wiring looks correct and consistent with prefill tests
The new `test_batch_decode_paged` mirrors the prefill-paged FP8 setup (indptr/indices, paged KV layout, per-head quantization, and GQA head configs) and exercises `BatchDecodeWithPagedKVCacheWrapper` with FA3+FP8 end-to-end. Shapes and arguments to `plan`/`run` look consistent with the existing prefill tests; the tighter MSE threshold is reasonable here.
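For context, per-head symmetric FP8 quantization of the kind such tests rely on, as a sketch (this is not the repository's `per_head_symmetric_quant`; it assumes a `[tokens, num_heads, head_dim]` layout and e4m3 by default):

```python
import torch

def per_head_symmetric_quant_sketch(x: torch.Tensor, dtype=torch.float8_e4m3fn):
    # One symmetric scale per head; dequantize with x_q.float() * scale.
    fp8_max = torch.finfo(dtype).max
    amax = x.abs().amax(dim=(0, 2), keepdim=True).clamp(min=1e-6)  # [1, H, 1]
    scale = amax.float() / fp8_max
    x_q = (x / scale).clamp(-fp8_max, fp8_max).to(dtype)
    return x_q, scale.view(-1)  # scale: [num_heads]
```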
948-956: Manual `__main__` decode loop is consistent with existing debug harness
The new `__main__` loop for `test_batch_decode_paged` follows the same pattern as the prefill/debug loops and is fine as a local debugging aid.
flashinfer/jit/attention/modules.py (1)
987-993: Backend and FP8-output guards in `gen_batch_prefill_module` are appropriate
The new asserts cleanly constrain `backend` to fa2/fa3 and forbid FP8 `dtype_o`, which matches current backend capabilities and will fail fast on misconfiguration without affecting valid callers.
flashinfer/prefill.py (2)
1705-1705: FP8 `o_data_type` guidance in docs is accurate and helpful
The added note that FP8 inputs should typically use fp16 or bf16 `o_data_type` matches how the kernels and tests are wired and makes the API expectations clearer.
Also applies to: 2671-2671
2280-2287: Refined `v_scale` post-scaling avoids redundant work for the common 1.0 case
Conditioning the post-kernel multiply on `v_scale` being not-None and not the scalar `1.0` preserves existing semantics while skipping unnecessary computation and dtype casts in the default case; tensor scales are still fully supported.
flashinfer/decode.py (6)
66-66: LGTM: Import addition supports FA3 backend selection.
The `determine_attention_backend` import is correctly added and used later for automatic backend selection based on device capabilities.
838-838: LGTM: Output data type parameter properly threaded through the API.
The `o_data_type` parameter is well-documented, correctly defaulted to `q_data_type`, and consistently propagated through module creation paths. The canonicalization logic follows established patterns for other data type parameters.
Also applies to: 886-889, 974-977, 987-987
1042-1050: LGTM: Backend auto-selection properly implemented.
The automatic backend determination correctly delegates to `determine_attention_backend` with appropriate parameters, enabling dynamic selection between FA2 and FA3 based on device capabilities and dtype constraints.
1414-1420: LGTM: v_scale optimization correctly skips redundant scaling.
The check `isinstance(v_scale, float) and v_scale == 1.0` appropriately skips the scaling operation when it would have no effect. The FP8 output handling correctly converts to float32, scales, and converts back to maintain precision.
2664-2689: LGTM: fast_decode_plan correctly mirrors backend-specific argument handling.
The backend-dependent plan argument construction in `fast_decode_plan` is consistent with the main `plan()` method, correctly handling the FA2 vs FA3 signature differences. The explicit comment about argument counts is helpful.
710-710: LGTM: Backend documentation and parameter passing are correct.
The documentation properly lists `fa3` as a supported backend alongside `fa2` and `trtllm-gen`. The backend parameter is correctly threaded through to the customized module when JIT args are provided.
Also applies to: 725-725
[FAILED] Pipeline #40635547: 5/20 passed
/bot run
@nvpohanh is not authorized to trigger this CI job. cc: @yzh119, @sricketts, @yongwww
/bot run
bkryu left a comment
Can confirm that the unit test failures are unrelated to the current PR.
Changes LGTM.
[FAILED] Pipeline #40698807: 10/20 passed
Signed-off-by: Po-Han Huang <[email protected]>
Force-pushed 121a73e to 1477978
Actionable comments posted: 0
🧹 Nitpick comments (3)
flashinfer/prefill.py (1)
2280-2281: Optimization doesn't handle scalar tensors containing 1.0.
The optimization correctly avoids multiplication when `v_scale` is a float equal to 1.0. However, since `v_scale` can also be a `torch.Tensor` (per the type hint at line 2056), the check `isinstance(v_scale, float)` will miss scalar tensors containing 1.0, and unnecessary multiplication will still occur. Consider extending the optimization:
🔎 Proposed enhancement
```diff
-    is_float_one = isinstance(v_scale, float) and v_scale == 1.0
-    if v_scale is not None and not is_float_one:
+    is_one = (isinstance(v_scale, float) and v_scale == 1.0) or \
+        (isinstance(v_scale, torch.Tensor) and v_scale.numel() == 1 and v_scale.item() == 1.0)
+    if v_scale is not None and not is_one:
```
flashinfer/decode.py (2)
1065-1089: Consider extracting backend-specific argument logic to reduce duplication.
The backend-specific argument handling (appending `fixed_split_size`, `disable_split_kv`, `num_colocated_ctas` for fa2) is duplicated between the `plan()` method (lines 1065-1089) and `fast_decode_plan()` (lines 2665-2689). Consider extracting this into a helper function to reduce maintenance burden.
🔎 Example refactor to eliminate duplication:
```python
def _build_plan_args_for_tensor_cores(
    self,
    qo_indptr_host,
    indptr_host,
    kv_lens_arr_host,
    batch_size,
    num_qo_heads,
    num_kv_heads,
    page_size,
    head_dim,
    window_left,
    fixed_split_size,
    disable_split_kv,
):
    """Build plan arguments for tensor core decode."""
    args = [
        self._float_workspace_buffer,
        self._int_workspace_buffer,
        self._pin_memory_int_workspace_buffer,
        qo_indptr_host,
        indptr_host,
        kv_lens_arr_host,
        batch_size,
        batch_size,
        num_qo_heads,
        num_kv_heads,
        page_size,
        self.is_cuda_graph_enabled,
        head_dim,
        head_dim,
        False,  # causal
        window_left,
    ]
    if self._backend == "fa2":
        args.extend([fixed_split_size, disable_split_kv, 0])  # num_colocated_ctas
    return args
```
Then use:
```python
args = self._build_plan_args_for_tensor_cores(...)
```
1341-1360: Add a comment documenting the expected args format for FP8 inputs.
The FP8 scale extraction logic assumes that when `q` is FP8, the first three elements of `*args` are `fp8_scale_q`, `fp8_scale_k`, and `fp8_scale_v` (lines 1345-1348). This implicit contract is fragile and could lead to bugs if the calling convention changes.
🔎 Add clarifying comment:
```diff
     if self._jit_module is not None:
         run_args.extend(list(args))
     else:
         # Extract FP8 scale tensors from *args if q is FP8
+        # Expected format: args[0]=fp8_scale_q, args[1]=fp8_scale_k, args[2]=fp8_scale_v
         fp8_scale_q = None
         fp8_scale_k = None
         fp8_scale_v = None
         if is_float8(q) and len(args) >= 3:
             fp8_scale_q = args[0]
             fp8_scale_k = args[1]
             fp8_scale_v = args[2]
```
📜 Review details
Configuration used: defaults
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (5)
flashinfer/decode.py, flashinfer/jit/attention/modules.py, flashinfer/prefill.py, tests/attention/test_hopper_fp8_attention.py, tests/utils/test_jit_example.py
🚧 Files skipped from review as they are similar to previous changes (1)
- tests/utils/test_jit_example.py
🧰 Additional context used
🧬 Code graph analysis (2)
tests/attention/test_hopper_fp8_attention.py (3)
flashinfer/utils.py (1): is_sm90a_supported (531-533)
flashinfer/decode.py (7): BatchDecodeWithPagedKVCacheWrapper (593-1458), plan (824-1128), plan (1653-1766), run (1158-1171), run (1174-1187), run (1190-1422), run (1769-1893)
benchmarks/bench_hopper_fp8_attention.py (1): per_head_symmetric_quant (27-38)
flashinfer/prefill.py (2)
flashinfer/logits_processor/types.py (1): dtype (126-130)
flashinfer/utils.py (1): is_float8 (158-159)
🪛 Ruff (0.14.10)
flashinfer/prefill.py
2118-2120: Avoid specifying long messages outside the exception class
(TRY003)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
- GitHub Check: Deploy Docs
🔇 Additional comments (13)
flashinfer/jit/attention/modules.py (1)
987-992: LGTM: Good defensive validation for backend and output dtype.
The assertions appropriately prevent invalid configurations from reaching the backend-specific code generation logic. The error messages are clear and actionable.
flashinfer/prefill.py (2)
1705-1705: LGTM: Documentation correctly expanded to include bfloat16.
The updated guidance appropriately reflects that both `torch.float16` and `torch.bfloat16` are valid output data types when using FP8 inputs.
2116-2120: Good dtype consistency check between plan and run phases.
The validation ensures the output tensor dtype matches what was specified during planning, which helps catch configuration errors early. The error message clearly indicates the mismatch.
Note: The static analysis hint (TRY003) about message length is a minor style suggestion and can be safely ignored here since the error context is specific to the runtime values.
flashinfer/decode.py (6)
886-888: LGTM! Clear documentation for FP8 output dtype handling.
The documentation clearly explains that `o_data_type` defaults to `q_data_type` and provides guidance for FP8 inputs (use `torch.float16` or `torch.bfloat16`). This is a well-designed API addition.
974-987: LGTM! Proper o_data_type handling and caching.
The code correctly canonicalizes `o_data_type` and defaults it to `q_data_type` when None, consistent with the documentation. Caching in `_cached_o_data_type` enables the `run()` method to use the planned output dtype.
1042-1050: LGTM! Backend auto-selection logic is sound.
The code appropriately calls `determine_attention_backend` to resolve `backend="auto"` to either `"fa2"` or `"fa3"` based on device capabilities and encoding mode. The resolved backend is stored in `self._backend` for subsequent use.
1307-1313: LGTM! Safe output tensor creation with fallback.
The output tensor creation correctly uses `_cached_o_data_type` if available (set by `plan()`), with a safe fallback to `q.dtype`. This ensures FP8 inputs can produce fp16/bf16 outputs as specified.
1414-1420: LGTM! Efficient v_scale handling skips unnecessary multiplication.
The optimization to skip v_scale multiplication when `v_scale == 1.0` avoids unnecessary computation. The exact equality check is safe for the common case of `v_scale=1.0`.
720-732: The concern about `backend="auto"` being passed with `jit_args` is valid but already properly handled.
The `gen_customize_batch_prefill_module` function includes explicit validation that raises a clear `ValueError` with the message "backend should not be auto when jit_args is provided" if "auto" is passed. This is intentional behavior: JIT modules require an explicit backend specification (fa2, fa3, etc.) rather than automatic resolution. No changes needed.
tests/attention/test_hopper_fp8_attention.py (4)
667-684: LGTM! Well-documented test for FP8 decode with paged KV cache.
The test function is clearly documented and uses a GQA configuration to thoroughly test head mapping logic with FP8 quantized KV cache. The comment explaining that `BatchDecodeWithPagedKVCacheWrapper` uses the prefill backend under the hood is helpful context.
735-781: LGTM! Proper FP8 validation against FP16 reference.
The test correctly establishes an FP16 baseline and compares it against the FP8 output. Both wrappers use consistent configuration (`backend="fa3"`, `use_tensor_cores=True`), ensuring a fair comparison. The per-head varying KV data helps detect head offset bugs in the FP8 kernel.
948-956: LGTM! Test properly invoked in main block.
The test invocation correctly exercises the decode path with appropriate parameters (GQA configuration with `num_qo_heads=8`, `num_kv_heads=2`).
783-785: Verify the strict MSE threshold doesn't cause test flakiness.
The MSE threshold of `0.01` is 100x stricter than the `1.0` used in prefill tests (e.g., line 524). While decode operations with single query tokens may have lower quantization error, such a tight threshold could cause flakiness across different hardware (A100 vs H100), CUDA versions, or random seeds. Consider whether `0.01` is reliably achievable or if it should be relaxed to `0.1` or `0.5` for better test stability.
```bash
#!/bin/bash
# Description: Run the test multiple times to check for flakiness
# Expected: Consistent pass rate across runs
for i in {1..10}; do
  echo "Run $i:"
  python -m pytest tests/attention/test_hopper_fp8_attention.py::test_batch_decode_paged -v
done
```
/bot run
[SUCCESS] Pipeline #40946323: 12/20 passed
@bkryu Could you merge this? Thanks!
📌 Description
#2111 already enabled Hopper FA3 FP8 attention in `prefill.py`. This is a follow-up PR that makes the same change in `decode.py`, because `decode.py` actually uses prefill kernels.
🔍 Related Issues
🚀 Pull Request Checklist
Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete.
✅ Pre-commit Checks
- I have installed `pre-commit` by running `pip install pre-commit` (or used your preferred method).
- I have installed the hooks with `pre-commit install`.
- I have run the hooks manually with `pre-commit run --all-files` and fixed any reported issues.
- Tests have been added or updated as needed (unittest, etc.).
Summary by CodeRabbit
New Features
Improvements
Documentation
Tests