
Conversation

@yzh119 (Collaborator) commented Dec 20, 2025

📌 Description

We currently don't define __all__ in our Python modules. Adding it will make public APIs explicit and ensure documentation tools index all intended exports.
This PR also adds missing APIs to the documentation index.
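
For context, a minimal sketch of the pattern this PR applies (the module and names below are illustrative, not actual FlashInfer code):

# mymodule.py (illustrative example, not a FlashInfer file)
def public_fn():
    """Part of the intended public API."""
    return 42

def _private_helper():  # conventionally private, but not enforced
    return -1

# Explicitly declare the public surface. Documentation tools (e.g. Sphinx
# autosummary) and `from mymodule import *` respect this list.
__all__ = ["public_fn"]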

🔍 Related Issues

🚀 Pull Request Checklist

Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete.

✅ Pre-commit Checks

  • I have installed pre-commit by running pip install pre-commit (or via your preferred method).
  • I have installed the hooks with pre-commit install.
  • I have run the hooks manually with pre-commit run --all-files and fixed any reported issues.

If you are unsure about how to set up pre-commit, see the pre-commit documentation.

🧪 Tests

  • Tests have been added or updated as needed.
  • All tests are passing (unittest, etc.).

Reviewer Notes

Summary by CodeRabbit

  • Documentation

    • Added POD (Prefix-Only Decode), Unified AllReduce Fusion, FP8 quantization, expanded GEMM and normalization docs; documented new POD and AllReduce API entries.
  • New Features

    • Documented new POD and AllReduce APIs and expanded kernel/GEMM and low-latency GEMM surface.
  • Chores

    • Broadened explicit public API exposure across many modules; conditional CuTe-DSL exports; renamed a documented type to QuantizationSFLayout.


@coderabbitai bot (Contributor) commented Dec 20, 2025

Note

Other AI code review bot(s) detected

CodeRabbit has detected other AI code review bot(s) in this pull request and will avoid duplicating their findings in the review comments. This may lead to a less comprehensive review.

Walkthrough

This PR adds explicit __all__ exports across many modules, updates docs to document POD, GEMM, FP8, and AllReduce fusion APIs, and adds conditional CuTe-DSL GEMM imports/exports. Two POD wrapper classes are documented and exported; no runtime behavior changes.

Changes

Cohort / File(s) Change Summary
Documentation updates
docs/api/attention.rst, docs/api/comm.rst, docs/api/gemm.rst, docs/api/norm.rst, docs/api/quantization.rst
Added autosummary entries and new sections for POD (flashinfer.pod), CuTe-DSL / low-latency GEMM, FP8 quantization, and Unified AllReduce Fusion; new documented public symbols.
Top-level conditional GEMM exports
flashinfer/__init__.py, flashinfer/gemm/__init__.py
Added CuTe-DSL import guard and conditional exports for grouped_gemm_nt_masked and Sm100BlockScaledPersistentDenseGemmKernel.
POD & Decoding
flashinfer/pod.py, flashinfer/decode.py
Added __all__ exporting PODWithPagedKVCacheWrapper and BatchPODWithPagedKVCacheWrapper; added fast_decode_plan to decode exports and docs.
Subpackage re-exports
flashinfer/comm/__init__.py, flashinfer/jit/__init__.py, flashinfer/logits_processor/__init__.py, flashinfer/testing/__init__.py
Introduced comprehensive __all__ lists to explicitly define public surfaces and re-exports for these subpackages.
Attention / MLA / Sparse / XQA
flashinfer/attention.py, flashinfer/mla.py, flashinfer/sparse.py, flashinfer/xqa.py
Added __all__ declarations exposing wrapper classes and entrypoints (Batch* wrappers, xqa, xqa_mla).
Quantization / FP4 / FP8 / Norm
flashinfer/fp4_quantization.py, flashinfer/fp8_quantization.py, flashinfer/quantization.py, flashinfer/norm.py, docs/api/quantization.rst
Added __all__ lists exposing quantize/dequantize functions, SfLayout, FP4/FP8 helpers, and new norm-related symbols; docs updated for FP8.
Utilities & core modules
flashinfer/activation.py, flashinfer/aot.py, flashinfer/api_logging.py, flashinfer/artifacts.py, flashinfer/autotuner.py, flashinfer/cascade.py, flashinfer/compilation_context.py, flashinfer/concat_ops.py, flashinfer/cuda_utils.py, flashinfer/deep_gemm.py, flashinfer/green_ctx.py, flashinfer/page.py, flashinfer/prefill.py, flashinfer/rope.py, flashinfer/sampling.py, flashinfer/topk.py, flashinfer/trtllm_low_latency_gemm.py, flashinfer/utils.py, flashinfer/version.py, flashinfer/tllm_utils.py
Added __all__ declarations across many utility and core modules to make functions and wrapper classes explicitly exportable. Several files contain duplicate __all__ insertions that should be consolidated.

Estimated code review effort

🎯 3 (Moderate) | ⏱️ ~25 minutes

  • Areas to focus:
    • Verify each name in every __all__ exists and is spelled correctly (a sketch of an automated check follows this list).
    • Inspect conditional CuTe-DSL import logic and ensure safe failure/visibility.
    • Consolidate or confirm intentional duplicate __all__ blocks in files (e.g., aot.py, green_ctx.py, rope.py, utils.py, fp8_quantization.py).
    • Ensure docs entries (POD, GEMM, FP8, AllReduce fusion) match exported symbols.
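
A minimal sketch of how the first check above could be automated (the module list is illustrative; any missing or misspelled name raises an AssertionError):

import importlib

# Illustrative subset of the modules touched by this PR; extend as needed.
MODULES = ["flashinfer.pod", "flashinfer.norm", "flashinfer.gemm"]

for mod_name in MODULES:
    mod = importlib.import_module(mod_name)
    for name in getattr(mod, "__all__", []):
        # Fails loudly on typos or names that no longer exist.
        assert hasattr(mod, name), f"{mod_name}.__all__ lists unknown name {name!r}"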

Possibly related PRs

Suggested reviewers

  • aleozlx
  • cyx-6
  • djmmoss
  • wenscarl
  • nvmbreughe
  • bkryu

Poem

🐰 I hopped through modules, tidy and spry,

Names in __all__ now reach for the sky.
POD and GEMM take their place in the light,
Optional kernels wait patient and bright.
A tiny carrot of docs — what a delight.

Pre-merge checks and finishing touches

❌ Failed checks (1 warning)
  • Docstring Coverage ⚠️ Warning: Docstring coverage is 75.00%, which is below the required threshold of 80.00%. To resolve, run @coderabbitai generate docstrings to improve coverage.
✅ Passed checks (2 passed)
  • Title check ✅ Passed: The title 'chore: add __all__ exports to Python modules and document missing APIs' clearly and concisely summarizes the main change: adding __all__ exports and documenting missing APIs.
  • Description check ✅ Passed: The PR description explains the purpose ('Adding __all__ will make public APIs explicit and ensure documentation tools index all intended exports') and mentions the documentation updates, but the pre-commit and test checklist items are left unchecked.

@gemini-code-assist bot (Contributor) commented

Summary of Changes

Hello @yzh119, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request significantly enhances the discoverability and clarity of the FlashInfer library's public interface. By systematically adding __all__ declarations to Python modules, it explicitly delineates which components are part of the public API. Concurrently, the documentation has been updated to reflect these public APIs, ensuring comprehensive indexing and improved accessibility for users and documentation tools.

Highlights

  • Explicit API Exports: Added __all__ lists to numerous Python modules across the flashinfer library. This explicitly defines the public API of each module, improving code clarity, maintainability, and preventing unintended imports.
  • Enhanced Documentation Coverage: Updated reStructuredText (.rst) documentation files to include previously missing APIs. This ensures that all intended functions, classes, and modules are properly indexed and accessible through the generated documentation, making the library easier to understand and use.


@gemini-code-assist bot left a comment

Code Review

This pull request systematically adds __all__ exports to various Python modules to make the public API explicit, and also updates the documentation to include missing APIs. The changes are well-structured and improve the maintainability of the library. I have a few minor suggestions to improve consistency and correctness of the __all__ definitions in some modules.

Comment on lines 1187 to 1189
__all__ = [
"next_positive_power_of_2",
]

Severity: medium

The __all__ list seems incomplete. It only exports next_positive_power_of_2. Many other functions and classes in this module appear to be part of the public API, such as PosEncodingMode, MaskMode, TensorLayout, device_support_pdl, get_compute_capability, and FP4Tensor. Please review and add other public symbols to __all__ to make the public API explicit as intended by this PR.

@coderabbitai bot left a comment

Actionable comments posted: 2

🧹 Nitpick comments (8)
flashinfer/concat_ops.py (1)

85-88: Sort __all__ in alphabetical order.

Ruff (RUF022) flags that the items in __all__ should be sorted alphabetically per isort conventions.

🔎 Proposed fix to sort `__all__`
 __all__ = [
+    "concat_mla_k",
     "get_concat_mla_module",
-    "concat_mla_k",
 ]
flashinfer/artifacts.py (1)

253-265: Sort __all__ alphabetically to comply with linting rules.

The __all__ list is complete and correctly exports all public APIs; however, static analysis (Ruff RUF022) flags that __all__ should be sorted alphabetically. If Ruff is enforced in your CI pipeline, this may trigger a linting failure.

🔎 Proposed alphabetically sorted `__all__`
 __all__ = [
-    # Classes
     "ArtifactPath",
     "CheckSumHash",
-    # Functions
+    "clear_cubin",
+    "download_artifacts",
+    "get_artifacts_status",
     "get_available_cubin_files",
     "get_checksums",
     "get_subdir_file_list",
-    "download_artifacts",
-    "get_artifacts_status",
-    "clear_cubin",
+    "temp_env_var",
 ]

Confirm whether Ruff (0.14.8) is enforced as a pre-commit or CI check in your project.

flashinfer/green_ctx.py (1)

298-307: Sort __all__ alphabetically.

The static analysis tool flagged that __all__ is not sorted alphabetically, which is a common convention for Python modules. The current order mixes function name prefixes rather than maintaining strict alphabetical order.

🔎 Proposed sorted `__all__` list
 __all__ = [
+    "create_green_ctx_streams",
+    "get_cudevice",
+    "get_device_resource",
     "get_sm_count_constraint",
-    "get_cudevice",
-    "get_device_resource",
-    "split_resource",
-    "split_resource_by_sm_count",
-    "create_green_ctx_streams",
     "split_device_green_ctx",
     "split_device_green_ctx_by_sm_count",
+    "split_resource",
+    "split_resource_by_sm_count",
 ]
flashinfer/testing/__init__.py (1)

32-45: Consider alphabetically sorting __all__ for consistency.

The static analysis tool flags that __all__ is not sorted (Ruff RUF022). While the current grouping (attention functions → bench functions → utilities) is logical, alphabetical sorting is a common convention that some linters enforce.

🔎 Suggested sorted version
 __all__ = [
     "attention_flops",
     "attention_flops_with_actual_seq_lens",
     "attention_tb_per_sec",
     "attention_tb_per_sec_with_actual_seq_lens",
     "attention_tflops_per_sec",
     "attention_tflops_per_sec_with_actual_seq_lens",
     "bench_gpu_time",
-    "bench_gpu_time_with_cupti",
     "bench_gpu_time_with_cuda_event",
     "bench_gpu_time_with_cudagraph",
+    "bench_gpu_time_with_cupti",
     "set_seed",
     "sleep_after_kernel_run",
 ]
flashinfer/deep_gemm.py (1)

1614-1622: Consider using English for inline comments.

The comments # Classes - 实际使用的 and # Functions - 实际使用的 (the Chinese phrase means "actually used") are in Chinese. For consistency with the rest of the codebase (which uses English comments and docstrings), consider translating these to English or removing them entirely since the categories are self-evident.

🔎 Suggested fix
 __all__ = [
-    # Classes - 实际使用的
+    # Classes
     "KernelMap",
-    # Functions - 实际使用的
+    # Functions
     "load",
     "load_all",
     "m_grouped_fp8_gemm_nt_contiguous",
     "m_grouped_fp8_gemm_nt_masked",
 ]
flashinfer/tllm_utils.py (1)

15-18: Consider sorting __all__ alphabetically.

Static analysis (RUF022) suggests sorting __all__ for consistency. The alphabetically sorted order would be ["delay_kernel", "get_trtllm_utils_module"].

🔎 Proposed fix
 __all__ = [
+    "delay_kernel",
     "get_trtllm_utils_module",
-    "delay_kernel",
 ]
flashinfer/fp4_quantization.py (1)

1004-1018: FP4/NVFP4/MXFP4 exports look correct; consider sorting for Ruff

The __all__ contents match the main public FP4/NVFP4/MXFP4 helpers defined in this module and provide a clear, consolidated quantization surface.

Ruff is flagging __all__ as unsorted (RUF022). If you care about keeping Ruff clean, you could alphabetize the entries; otherwise this is functionally fine.

Optional: alphabetize __all__ to satisfy RUF022
-__all__ = [
-    "SfLayout",
-    "block_scale_interleave",
-    "nvfp4_block_scale_interleave",
-    "e2m1_and_ufp8sf_scale_to_float",
-    "fp4_quantize",
-    "mxfp4_dequantize_host",
-    "mxfp4_dequantize",
-    "mxfp4_quantize",
-    "nvfp4_quantize",
-    "nvfp4_batched_quantize",
-    "shuffle_matrix_a",
-    "shuffle_matrix_sf_a",
-    "scaled_fp4_grouped_quantize",
-]
+__all__ = [
+    "SfLayout",
+    "block_scale_interleave",
+    "e2m1_and_ufp8sf_scale_to_float",
+    "fp4_quantize",
+    "mxfp4_dequantize",
+    "mxfp4_dequantize_host",
+    "mxfp4_quantize",
+    "nvfp4_batched_quantize",
+    "nvfp4_block_scale_interleave",
+    "nvfp4_quantize",
+    "scaled_fp4_grouped_quantize",
+    "shuffle_matrix_a",
+    "shuffle_matrix_sf_a",
+]
flashinfer/gemm/__init__.py (1)

49-53: Consider sorting __all__ for consistency.

The conditional extension of __all__ is functionally correct. However, for maintainability and consistency with Python conventions, consider sorting the entire __all__ list alphabetically.

🔎 Optional refactor to sort __all__
 __all__ = [
+    "batch_deepgemm_fp8_nt_groupwise",
+    "bmm_fp8",
+    "fp8_blockscale_gemm_sm90",
+    "gemm_fp8_nt_blockscaled",
+    "gemm_fp8_nt_groupwise",
+    "group_deepgemm_fp8_nt_groupwise",
+    "group_gemm_fp8_nt_groupwise",
+    "group_gemm_mxfp4_nt_groupwise",
+    "mm_M1_16_K7168_N256",
+    "mm_fp4",
+    "mm_fp8",
     "SegmentGEMMWrapper",
-    "bmm_fp8",
-    "mm_fp4",
-    "mm_fp8",
     "tgv_gemm_sm100",
-    "group_gemm_mxfp4_nt_groupwise",
-    "batch_deepgemm_fp8_nt_groupwise",
-    "group_deepgemm_fp8_nt_groupwise",
-    "gemm_fp8_nt_blockscaled",
-    "gemm_fp8_nt_groupwise",
-    "group_gemm_fp8_nt_groupwise",
-    "fp8_blockscale_gemm_sm90",
-    "mm_M1_16_K7168_N256",
 ]
 
 if _CUTE_DSL_AVAILABLE:
     __all__ += [
         "grouped_gemm_nt_masked",
         "Sm100BlockScaledPersistentDenseGemmKernel",
     ]
📜 Review details

Configuration used: defaults

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 519671d and 0478a48.

📒 Files selected for processing (41)
  • docs/api/attention.rst (2 hunks)
  • docs/api/comm.rst (2 hunks)
  • docs/api/gemm.rst (2 hunks)
  • docs/api/norm.rst (1 hunks)
  • docs/api/quantization.rst (1 hunks)
  • flashinfer/__init__.py (1 hunks)
  • flashinfer/activation.py (1 hunks)
  • flashinfer/aot.py (1 hunks)
  • flashinfer/api_logging.py (1 hunks)
  • flashinfer/artifacts.py (1 hunks)
  • flashinfer/attention.py (1 hunks)
  • flashinfer/autotuner.py (1 hunks)
  • flashinfer/cascade.py (1 hunks)
  • flashinfer/comm/__init__.py (1 hunks)
  • flashinfer/compilation_context.py (1 hunks)
  • flashinfer/concat_ops.py (1 hunks)
  • flashinfer/cuda_utils.py (1 hunks)
  • flashinfer/decode.py (1 hunks)
  • flashinfer/deep_gemm.py (1 hunks)
  • flashinfer/fp4_quantization.py (1 hunks)
  • flashinfer/fp8_quantization.py (1 hunks)
  • flashinfer/gemm/__init__.py (2 hunks)
  • flashinfer/green_ctx.py (1 hunks)
  • flashinfer/jit/__init__.py (1 hunks)
  • flashinfer/logits_processor/__init__.py (1 hunks)
  • flashinfer/mla.py (1 hunks)
  • flashinfer/norm.py (1 hunks)
  • flashinfer/page.py (1 hunks)
  • flashinfer/pod.py (1 hunks)
  • flashinfer/prefill.py (1 hunks)
  • flashinfer/quantization.py (1 hunks)
  • flashinfer/rope.py (1 hunks)
  • flashinfer/sampling.py (1 hunks)
  • flashinfer/sparse.py (1 hunks)
  • flashinfer/testing/__init__.py (1 hunks)
  • flashinfer/tllm_utils.py (1 hunks)
  • flashinfer/topk.py (1 hunks)
  • flashinfer/trtllm_low_latency_gemm.py (1 hunks)
  • flashinfer/utils.py (1 hunks)
  • flashinfer/version.py (1 hunks)
  • flashinfer/xqa.py (1 hunks)
🧰 Additional context used
🧬 Code graph analysis (2)
flashinfer/gemm/__init__.py (1)
flashinfer/cute_dsl/blockscaled_gemm.py (2)
  • grouped_gemm_nt_masked (2945-3046)
  • Sm100BlockScaledPersistentDenseGemmKernel (464-2449)
flashinfer/__init__.py (2)
csrc/nv_internal/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/kernel/moe_cutlass_kernel.h (1)
  • gemm (41-675)
flashinfer/cute_dsl/blockscaled_gemm.py (2)
  • grouped_gemm_nt_masked (2945-3046)
  • Sm100BlockScaledPersistentDenseGemmKernel (464-2449)
🪛 Ruff (0.14.8)
flashinfer/artifacts.py

253-265: __all__ is not sorted

Apply an isort-style sorting to __all__

(RUF022)

flashinfer/gemm/__init__.py

50-53: __all__ is not sorted

Apply an isort-style sorting to __all__

(RUF022)

flashinfer/jit/__init__.py

96-158: __all__ is not sorted

Apply an isort-style sorting to __all__

(RUF022)

flashinfer/fp4_quantization.py

1004-1018: __all__ is not sorted

Apply an isort-style sorting to __all__

(RUF022)

flashinfer/logits_processor/__init__.py

36-62: __all__ is not sorted

Apply an isort-style sorting to __all__

(RUF022)

flashinfer/concat_ops.py

85-88: __all__ is not sorted

Apply an isort-style sorting to __all__

(RUF022)

flashinfer/green_ctx.py

298-307: __all__ is not sorted

Apply an isort-style sorting to __all__

(RUF022)

flashinfer/version.py

26-29: __all__ is not sorted

Apply an isort-style sorting to __all__

(RUF022)

flashinfer/tllm_utils.py

15-18: __all__ is not sorted

Apply an isort-style sorting to __all__

(RUF022)

flashinfer/testing/__init__.py

32-45: __all__ is not sorted

Apply an isort-style sorting to __all__

(RUF022)

flashinfer/comm/__init__.py

70-123: __all__ is not sorted

Apply an isort-style sorting to __all__

(RUF022)

⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
  • GitHub Check: Deploy Docs
🔇 Additional comments (34)
flashinfer/cuda_utils.py (1)

64-66: The __all__ definition is correct.

The checkCudaErrors function is the only public API that should be exposed. The imported CUDA modules (driver, runtime, cudart, nvrtc) are internal implementation dependencies and are not imported from this module elsewhere in the codebase, so they should remain unexposed.

flashinfer/topk.py (1)

422-428: LGTM!

The __all__ export correctly includes the three main public functions. The topk alias is intentionally omitted (it's just an alias for top_k), and can_implement_filtered_topk appears to be a utility function that may not need public exposure.

flashinfer/sparse.py (1)

1169-1174: LGTM!

The __all__ correctly exports the two main wrapper classes that constitute the public API for block-sparse attention. The helper function convert_bsr_mask_layout is intentionally kept internal.

flashinfer/xqa.py (1)

530-535: LGTM!

The __all__ correctly exports the two public API functions (xqa and xqa_mla) while keeping the internal module getters (get_xqa_module, get_xqa_module_mla) as implementation details.

flashinfer/attention.py (1)

282-285: LGTM!

The __all__ export correctly lists both public classes (BatchAttention and BatchAttentionWithAttentionSinkWrapper) defined in this module.

flashinfer/compilation_context.py (1)

71-73: LGTM!

The __all__ export correctly lists the CompilationContext class, which is the sole public API of this module.

flashinfer/quantization.py (1)

142-145: LGTM!

The __all__ export correctly lists the two public functions (packbits and segment_packbits) while appropriately excluding private helpers prefixed with _.

flashinfer/utils.py (1)

1187-1189: Verify the intended public API scope.

The __all__ exports only next_positive_power_of_2, but this module contains many other utilities that appear public (e.g., FP4Tensor, get_compute_capability, determine_attention_backend, PosEncodingMode, MaskMode, TensorLayout, LogLevel, set_log_level, supported_compute_capability, backend_requirement).

Is this minimal export intentional, or should additional utilities be included in __all__ for documentation tooling to index them?

flashinfer/api_logging.py (1)

568-570: LGTM!

The __all__ export correctly exposes the flashinfer_api decorator as the sole public API, appropriately excluding internal helpers.

flashinfer/autotuner.py (1)

794-796: Verify if additional classes should be exported.

The __all__ exports only the autotune context manager. However, users implementing custom autotuning may need access to classes like TunableRunner, TuningConfig, DynamicTensorSpec, ConstraintSpec, and OptimizationProfile.

Is this minimal export intentional to keep the public API surface small, or should these supporting types be included for documentation and explicit import support?

flashinfer/page.py (1)

384-389: LGTM!

The __all__ declaration correctly exposes the four public APIs defined in this module. All exported functions are properly decorated with @flashinfer_api and align with the module's intended public surface.

flashinfer/version.py (1)

26-29: LGTM!

The __all__ declaration correctly exposes the version metadata. Note that the list is already alphabetically sorted, so the Ruff RUF022 warning can be safely ignored as a false positive.

flashinfer/norm.py (1)

415-423: LGTM with conditional exports noted.

The __all__ declaration correctly exposes the public norm APIs. Note that rmsnorm_fp4quant and add_rmsnorm_fp4quant are conditionally imported (lines 408-413) and will be None if CuTe-DSL is not available. This is intentional design for optional dependencies, as indicated by the type ignore comments.

docs/api/norm.rst (1)

18-19: LGTM!

The documentation update correctly adds the new CuTe-DSL based RMSNorm APIs to the autosummary, aligning with the code changes in flashinfer/norm.py.

flashinfer/activation.py (1)

227-232: LGTM!

The __all__ declaration correctly exposes the four activation function APIs. All exported functions are properly decorated with @flashinfer_api and represent the intended public surface of this module.

flashinfer/comm/__init__.py (1)

70-123: LGTM! Semantic grouping preferred over alphabetical sorting.

The __all__ declaration comprehensively exposes the communication module's public API with clear logical grouping (CUDA IPC, DLPack, TensorRT-LLM AllReduce, etc.). The semantic organization with comments significantly improves maintainability compared to alphabetical sorting, making the Ruff RUF022 warning safely ignorable in this context.

flashinfer/prefill.py (1)

3759-3764: LGTM!

The __all__ declaration correctly exposes the prefill module's public API, including both wrapper classes and standalone functions. The exports align with the module's documented interfaces and intended public surface.

flashinfer/logits_processor/__init__.py (1)

36-62: LGTM! Semantic grouping enhances discoverability.

The __all__ declaration clearly organizes the logits processor module's public API by component type (Compiler, Processors, Types, etc.). This semantic grouping with inline comments provides better developer experience than alphabetical sorting, making the Ruff RUF022 warning appropriately ignorable.

flashinfer/decode.py (1)

2682-2689: Confirm intended public surface for decode-related helpers

The new __all__ only exposes the wrapper classes, cudnn_batch_decode_with_kv_cache, fast_decode_plan, and single_decode_with_kv_cache, but leaves out other public-looking functions in this module such as trtllm_batch_decode_with_kv_cache and xqa_batch_decode_with_kv_cache.

If Sphinx and from flashinfer.decode import * are meant to show those helpers as part of the public API, consider adding them here as well or confirming they’re intentionally kept internal.
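
For reference, only names listed in __all__ are bound by a star import, so helpers left out of the list become invisible to it (a sketch, assuming flashinfer is installed):

# Star imports bind only the names listed in the module's __all__.
from flashinfer.decode import *  # noqa: F403

print("single_decode_with_kv_cache" in dir())        # True: listed in __all__
print("trtllm_batch_decode_with_kv_cache" in dir())  # False unless added to __all__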

flashinfer/cascade.py (1)

1083-1090: Explicit cascade API exports look consistent

The __all__ set cleanly matches the user-facing cascade wrappers and merge helpers defined in this module. This should work well for tooling and from flashinfer.cascade import *.

flashinfer/fp8_quantization.py (1)

211-214: FP8 quantization exports are minimal and focused

Exporting only mxfp8_quantize and mxfp8_dequantize_host matches the intended public FP8 API and avoids leaking internal helper symbols.

flashinfer/aot.py (1)

884-886: Validate that register_default_modules is the only intended public AOT entry

__all__ now exposes only register_default_modules, while leaving helpers like compile_and_package_modules and main non-exported. This is a sensible minimal surface; just confirm that you don’t expect users to import any of the build helpers directly from flashinfer.aot.

flashinfer/sampling.py (1)

1590-1603: Sampling public surface is clearly defined

The __all__ list captures the full set of documented sampling utilities and keeps legacy aliases internal. This should play nicely with docs and from flashinfer.sampling import *.

docs/api/comm.rst (1)

49-49: Ensure documented comm symbols are exported and resolvable

The docs now reference QuantizationSFLayout and the Unified AllReduce Fusion API (AllReduceFusionWorkspace, TRTLLMAllReduceFusionWorkspace, allreduce_fusion, create_allreduce_fusion_workspace, plus MNNVLAllReduceFusionWorkspace under flashinfer.comm.trtllm_mnnvl_ar).

Please double‑check that:

  • These symbols actually exist under the documented modules, and
  • They’re included in the corresponding __all__ so Sphinx autosummary can import and render them without warnings.

Also applies to: 97-116

flashinfer/rope.py (1)

1674-1685: LGTM! Clean public API definition.

The __all__ declaration correctly exposes all public RoPE functions, making the module's API surface explicit for documentation tools.

docs/api/quantization.rst (1)

15-27: LGTM! Well-structured documentation addition.

The FP8 quantization section properly documents the previously missing APIs using appropriate Sphinx directives.

flashinfer/pod.py (1)

1204-1207: LGTM! Correct public API exposure.

The __all__ declaration properly exposes the two POD wrapper classes, making them discoverable by documentation tools.

docs/api/gemm.rst (2)

25-25: LGTM! Completes the FP8 GEMM documentation.

Adding mm_fp8 to the autosummary properly documents this previously missing API.


48-75: LGTM! Well-organized new GEMM sections.

The documentation properly exposes new GEMM variants:

  • Blackwell GEMM for SM100 architecture
  • TensorRT-LLM low-latency GEMM utilities
  • CuTe-DSL GEMM kernels (conditionally available)

The structure is clear and uses appropriate Sphinx directives.

flashinfer/trtllm_low_latency_gemm.py (1)

227-229: LGTM! Proper public API exposure.

The __all__ declaration correctly exports the weight preparation utility, making it accessible for documentation and user consumption.

flashinfer/gemm/__init__.py (1)

22-31: LGTM! Proper handling of optional dependency.

The conditional import pattern correctly handles CuTe-DSL availability, gracefully degrading when the optional dependency is not present. The _CUTE_DSL_AVAILABLE flag provides a clear way to track availability.
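
For reference, the guarded pattern is roughly the following (a sketch of flashinfer/gemm/__init__.py; names mirror the PR, but the exact file contents may differ):

# Optional-dependency guard: CuTe-DSL kernels are exported only when importable.
try:
    from .cute_dsl.blockscaled_gemm import (
        Sm100BlockScaledPersistentDenseGemmKernel,
        grouped_gemm_nt_masked,
    )

    _CUTE_DSL_AVAILABLE = True
except ImportError:
    # CuTe-DSL is optional; the rest of the package keeps working without it.
    _CUTE_DSL_AVAILABLE = False

__all__ = ["bmm_fp8", "mm_fp4"]  # abridged

if _CUTE_DSL_AVAILABLE:
    __all__ += [
        "grouped_gemm_nt_masked",
        "Sm100BlockScaledPersistentDenseGemmKernel",
    ]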

docs/api/attention.rst (2)

28-28: LGTM! Documents missing decode planning API.

Adding fast_decode_plan to the autosummary properly exposes this utility function.


114-129: LGTM! Well-structured POD attention documentation.

The new POD (Prefix-Only Decode) section properly documents the wrapper classes with appropriate Sphinx directives, aligning with the __all__ exports in flashinfer/pod.py.

flashinfer/__init__.py (1)

92-100: LGTM! Robust handling of optional CuTe-DSL dependency.

The conditional import pattern properly handles environments where CuTe-DSL is unavailable, allowing the package to function gracefully in both scenarios. This aligns with the availability check in flashinfer/gemm/__init__.py.

Comment on lines +96 to +158
__all__ = [
# Submodules
"cubin_loader",
"env",
# Activation
"gen_act_and_mul_module",
"get_act_and_mul_cu_str",
# Attention
"gen_cudnn_fmha_module",
"gen_batch_attention_module",
"gen_batch_decode_mla_module",
"gen_batch_decode_module",
"gen_batch_mla_module",
"gen_batch_prefill_module",
"gen_customize_batch_decode_module",
"gen_customize_batch_prefill_module",
"gen_customize_single_decode_module",
"gen_customize_single_prefill_module",
"gen_fmha_cutlass_sm100a_module",
"gen_batch_pod_module",
"gen_pod_module",
"gen_single_decode_module",
"gen_single_prefill_module",
"get_batch_attention_uri",
"get_batch_decode_mla_uri",
"get_batch_decode_uri",
"get_batch_mla_uri",
"get_batch_prefill_uri",
"get_pod_uri",
"get_single_decode_uri",
"get_single_prefill_uri",
"gen_trtllm_gen_fmha_module",
"get_trtllm_fmha_v2_module",
# Core
"JitSpec",
"JitSpecStatus",
"JitSpecRegistry",
"jit_spec_registry",
"build_jit_specs",
"clear_cache_dir",
"gen_jit_spec",
"MissingJITCacheError",
"sm90a_nvcc_flags",
"sm100a_nvcc_flags",
"sm100f_nvcc_flags",
"sm103a_nvcc_flags",
"sm110a_nvcc_flags",
"sm120a_nvcc_flags",
"sm121a_nvcc_flags",
"current_compilation_context",
# Cubin Loader
"setup_cubin_loader",
# Comm
"gen_comm_alltoall_module",
"gen_trtllm_mnnvl_comm_module",
"gen_trtllm_comm_module",
"gen_vllm_comm_module",
"gen_nvshmem_module",
"gen_moe_alltoall_module",
# DSv3 Optimizations
"gen_dsv3_router_gemm_module",
"gen_dsv3_fused_routing_module",
]

🛠️ Refactor suggestion | 🟠 Major

Sort __all__ alphabetically per RUF022 rule.

The __all__ list is currently organized by category, but the static analysis tool expects isort-style alphabetical sorting. While categorical organization is more readable, this should be aligned with the project's linting rules.

🔎 Proposed alphabetically sorted `__all__`
 __all__ = [
+    "build_jit_specs",
+    "clear_cache_dir",
+    "current_compilation_context",
+    "cubin_loader",
+    "env",
+    "gen_act_and_mul_module",
+    "gen_batch_allreduce_module",
+    "gen_batch_attention_module",
+    "gen_batch_decode_mla_module",
+    "gen_batch_decode_module",
+    "gen_batch_mla_module",
+    "gen_batch_pod_module",
+    "gen_batch_prefill_module",
+    "gen_comm_alltoall_module",
+    "gen_cudnn_fmha_module",
+    "gen_customize_batch_decode_module",
+    "gen_customize_batch_prefill_module",
+    "gen_customize_single_decode_module",
+    "gen_customize_single_prefill_module",
+    "gen_dsv3_fused_routing_module",
+    "gen_dsv3_router_gemm_module",
+    "gen_fmha_cutlass_sm100a_module",
+    "gen_moe_alltoall_module",
+    "gen_nvshmem_module",
+    "gen_pod_module",
+    "gen_single_decode_module",
+    "gen_single_prefill_module",
+    "gen_trtllm_comm_module",
+    "gen_trtllm_gen_fmha_module",
+    "gen_trtllm_mnnvl_comm_module",
+    "gen_vllm_comm_module",
+    "get_act_and_mul_cu_str",
+    "get_batch_attention_uri",
+    "get_batch_decode_mla_uri",
+    "get_batch_decode_uri",
+    "get_batch_mla_uri",
+    "get_batch_prefill_uri",
+    "get_pod_uri",
+    "get_single_decode_uri",
+    "get_single_prefill_uri",
+    "get_trtllm_fmha_v2_module",
+    "JitSpec",
+    "JitSpecRegistry",
+    "JitSpecStatus",
+    "jit_spec_registry",
+    "MissingJITCacheError",
+    "setup_cubin_loader",
+    "sm100a_nvcc_flags",
+    "sm100f_nvcc_flags",
+    "sm103a_nvcc_flags",
+    "sm110a_nvcc_flags",
+    "sm120a_nvcc_flags",
+    "sm121a_nvcc_flags",
+    "sm90a_nvcc_flags",
 ]

Note: Capitalized entries (JitSpec*, MissingJITCacheError) and those starting with lowercase will naturally sort together; Python's default sort treats uppercase before lowercase. If case-insensitive grouping is preferred, consider using a custom sort key.
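
For example, a case-insensitive grouping can be produced with a custom sort key (illustrative):

names = ["JitSpec", "build_jit_specs", "MissingJITCacheError", "env"]

# Default sort: uppercase letters order before lowercase ones.
print(sorted(names))                 # ['JitSpec', 'MissingJITCacheError', 'build_jit_specs', 'env']

# A case-insensitive key interleaves names regardless of case.
print(sorted(names, key=str.lower))  # ['build_jit_specs', 'env', 'JitSpec', 'MissingJITCacheError']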

Committable suggestion skipped: line range outside the PR's diff.

🧰 Tools
🪛 Ruff (0.14.8)

96-158: __all__ is not sorted

Apply an isort-style sorting to __all__

(RUF022)

🤖 Prompt for AI Agents
In flashinfer/jit/__init__.py around lines 96 to 158, the __all__ list is
grouped by category but must be alphabetically sorted to satisfy
RUF022/isort-style linting; reorder all string entries in the list into a single
alphabetical sequence (case-sensitive default sort is fine), keeping each name
as a quoted string with commas and preserving the list formatting (multiline)
and trailing comma.

Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
@coderabbitai bot left a comment

Actionable comments posted: 0

♻️ Duplicate comments (1)
flashinfer/norm.py (1)

415-428: Fix critical syntax error in __all__ definition.

The code has a duplicate and malformed __all__ definition causing syntax errors that break pre-commit checks. Lines 415-417 start two separate __all__ assignments, with the first one unclosed.

A past review comment already suggested the correct pattern (see lines 415-423 in past comments). Please follow that suggestion.

🔎 Proposed fix
-__all__ = [
-    "add_rmsnorm_fp4quant",
-__all__ = [
+__all__ = [
     "fused_add_rmsnorm",
     "gemma_fused_add_rmsnorm",
     "gemma_rmsnorm",
     "layernorm",
     "rmsnorm",
 ]
 
 if add_rmsnorm_fp4quant is not None:
     __all__.append("add_rmsnorm_fp4quant")
 if rmsnorm_fp4quant is not None:
     __all__.append("rmsnorm_fp4quant")
🧹 Nitpick comments (1)
flashinfer/norm.py (1)

418-423: Consider adding rmsnorm_quant and fused_add_rmsnorm_quant to __all__.

Both rmsnorm_quant (line 97) and fused_add_rmsnorm_quant (line 198) are decorated with @flashinfer_api, indicating they're part of the public API. They should be included in __all__ for consistency and to support wildcard imports.

🔎 Suggested addition
 __all__ = [
     "fused_add_rmsnorm",
+    "fused_add_rmsnorm_quant",
     "gemma_fused_add_rmsnorm",
     "gemma_rmsnorm",
     "layernorm",
     "rmsnorm",
+    "rmsnorm_quant",
 ]
📜 Review details

Configuration used: defaults

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between fd99f6b and c350bf3.

📒 Files selected for processing (1)
  • flashinfer/norm.py (1 hunks)
🧰 Additional context used
🧬 Code graph analysis (1)
flashinfer/norm.py (2)
flashinfer/cute_dsl/add_rmsnorm_fp4quant.py (1)
  • add_rmsnorm_fp4quant (2219-2327)
flashinfer/cute_dsl/rmsnorm_fp4quant.py (1)
  • rmsnorm_fp4quant (1790-1895)
🪛 GitHub Actions: pre-commit
flashinfer/norm.py

[error] 415-429: Syntax errors detected in flashinfer/norm.py (e.g., duplicate __all__ definition, an unclosed bracket, and invalid-syntax causing mypy/ruff to fail). Please fix the Python syntax to resolve the pre-commit checks.

🪛 Ruff (0.14.8)
flashinfer/norm.py

417-417: Expected `,`, found `=`

(invalid-syntax)


425-425: Expected `else`, found `:`

(invalid-syntax)


427-427: Expected `else`, found `:`

(invalid-syntax)

⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
  • GitHub Check: Deploy Docs

@coderabbitai bot left a comment

Actionable comments posted: 0

🧹 Nitpick comments (3)
flashinfer/utils.py (2)

1246-1257: Consider whether internal helpers should be in __all__.

The underscore-prefixed functions are typically considered private/internal by Python convention. While the comment indicates they're "used across modules," this usually means internal package usage rather than public API exposure. If these are only for inter-module communication within flashinfer, they can be imported directly by other modules without being in __all__, keeping the public API surface smaller and more focused.

However, if these are intended as a stable, semi-public API for advanced users, then including them is acceptable.


1187-1257: Optional: Consider sorting __all__ for consistency.

The static analysis tool suggests applying isort-style sorting to the __all__ list. The current semantic grouping by category (Enums, Exceptions, etc.) has its advantages for readability, but alphabetical sorting would improve searchability and align with common Python conventions. This is a minor stylistic preference.

flashinfer/mla.py (1)

802-809: @flashinfer_api functions correctly added; clarify intent for module getter functions.

The inclusion of trtllm_batch_decode_with_kv_cache_mla and xqa_batch_decode_with_kv_cache_mla with @flashinfer_api decoration addresses the prior review feedback and correctly exposes these primary public APIs.

However, the three module getter functions (get_trtllm_gen_fmha_module, get_mla_module, get_batch_mla_module) are decorated only with @functools.cache, lack docstrings, and are used exclusively within mla.py for internal module caching. Clarify whether these are intended as public exports or if they should remain internal implementation details.
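
For illustration, the getter shape in question is roughly the following (a sketch, not the exact source):

import functools

@functools.cache
def get_batch_mla_module():
    # Sketch: the real getter JIT-compiles/loads a kernel module once;
    # functools.cache memoizes the handle for the rest of the process.
    module = object()  # placeholder for the compiled module handle
    return module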

📜 Review details

Configuration used: defaults

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 35f9bf9 and 856fb36.

📒 Files selected for processing (2)
  • flashinfer/mla.py (1 hunks)
  • flashinfer/utils.py (1 hunks)
🧰 Additional context used
🪛 Ruff (0.14.8)
flashinfer/utils.py

1187-1257: __all__ is not sorted

Apply an isort-style sorting to __all__

(RUF022)

flashinfer/mla.py

802-809: __all__ is not sorted

Apply an isort-style sorting to __all__

(RUF022)

⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
  • GitHub Check: Deploy Docs
🔇 Additional comments (1)
flashinfer/utils.py (1)

1187-1257: Comprehensive __all__ export list successfully addresses PR objective.

The addition of this explicit __all__ list makes the public API surface clear and helps documentation tools properly index the intended exports. The organization with category comments (Enums, Exceptions, Classes, etc.) enhances maintainability. This addresses the concerns raised in previous reviews about missing public symbols.

@yzh119 (Collaborator, Author) commented Dec 22, 2025

/bot run

@flashinfer-bot (Collaborator) commented

GitLab MR !209 has been created, and the CI pipeline #40597907 is currently running. I'll report back once the pipeline job completes.

@flashinfer-bot (Collaborator) commented

[SUCCESS] Pipeline #40597907: 12/20 passed
