chore: add __all__ exports to Python modules and document missing APIs #2251
base: main
Conversation
Note: Other AI code review bot(s) detected. CodeRabbit has detected other AI code review bot(s) in this pull request and will avoid duplicating their findings in the review comments. This may lead to a less comprehensive review.

Walkthrough

This PR adds explicit `__all__` exports across many modules, updates docs to document POD, GEMM, FP8, and AllReduce fusion APIs, and adds conditional CuTe-DSL GEMM imports/exports. Two POD wrapper classes are documented and exported; no runtime behavior changes.
Estimated code review effort: 🎯 3 (Moderate) | ⏱️ ~25 minutes
Pre-merge checks and finishing touches: ❌ Failed checks (1 warning) | ✅ Passed checks (2 passed)
Summary of Changes

Hello @yzh119, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed! This pull request significantly enhances the discoverability and clarity of the FlashInfer library's public interface. By systematically adding explicit `__all__` declarations and filling gaps in the documentation index, it makes the intended public API surface explicit.
Code Review
This pull request systematically adds `__all__` exports to various Python modules to make the public API explicit, and also updates the documentation to include missing APIs. The changes are well-structured and improve the maintainability of the library. I have a few minor suggestions to improve consistency and correctness of the `__all__` definitions in some modules.
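As background, a tiny illustration of what `__all__` buys you (a sketch, not code from this PR; `mypkg` and both helpers are invented): star-imports bind only the listed names, and many documentation tools index only those names.

```python
# mypkg/util.py — a hypothetical module
__all__ = ["public_helper"]  # the explicit public surface

def public_helper() -> int:
    """Intended public API; picked up by star-imports and doc indexers."""
    return 42

def internal_helper() -> int:
    """Still importable by name, but excluded from `import *`."""
    return -1

# elsewhere:
# from mypkg.util import *   # binds public_helper only
```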
```python
__all__ = [
    "next_positive_power_of_2",
]
```
The `__all__` list seems incomplete. It only exports `next_positive_power_of_2`. Many other functions and classes in this module appear to be part of the public API, such as `PosEncodingMode`, `MaskMode`, `TensorLayout`, `device_support_pdl`, `get_compute_capability`, and `FP4Tensor`. Please review and add other public symbols to `__all__` to make the public API explicit as intended by this PR.
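A sketch of what that expanded list could look like, assuming the symbols named above exist in `flashinfer/utils.py` as the comment claims (kept alphabetical, which also satisfies the RUF022 findings discussed below):

```python
__all__ = [
    "FP4Tensor",
    "MaskMode",
    "PosEncodingMode",
    "TensorLayout",
    "device_support_pdl",
    "get_compute_capability",
    "next_positive_power_of_2",
]
```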
Actionable comments posted: 2
🧹 Nitpick comments (8)
flashinfer/concat_ops.py (1)
85-88: Sort `__all__` in alphabetical order. Ruff (RUF022) flags that the items in `__all__` should be sorted alphabetically per isort conventions.

🔎 Proposed fix to sort `__all__`

```diff
 __all__ = [
+    "concat_mla_k",
     "get_concat_mla_module",
-    "concat_mla_k",
 ]
```

flashinfer/artifacts.py (1)
253-265: Sort `__all__` alphabetically to comply with linting rules. The `__all__` list is complete and correctly exports all public APIs; however, static analysis (Ruff RUF022) flags that `__all__` should be sorted alphabetically. If Ruff is enforced in your CI pipeline, this may trigger a linting failure.

🔎 Proposed alphabetically sorted `__all__`

```diff
 __all__ = [
-    # Classes
     "ArtifactPath",
     "CheckSumHash",
-    # Functions
+    "clear_cubin",
+    "download_artifacts",
+    "get_artifacts_status",
     "get_available_cubin_files",
     "get_checksums",
     "get_subdir_file_list",
-    "download_artifacts",
-    "get_artifacts_status",
-    "clear_cubin",
+    "temp_env_var",
 ]
```

Confirm whether Ruff (0.14.8) is enforced as a pre-commit or CI check in your project.
flashinfer/green_ctx.py (1)
298-307: Sort `__all__` alphabetically. The static analysis tool flagged that `__all__` is not sorted alphabetically, which is a common convention for Python modules. The current order mixes function name prefixes rather than maintaining strict alphabetical order.

🔎 Proposed sorted `__all__` list

```diff
 __all__ = [
+    "create_green_ctx_streams",
+    "get_cudevice",
+    "get_device_resource",
     "get_sm_count_constraint",
-    "get_cudevice",
-    "get_device_resource",
     "split_resource",
-    "split_resource_by_sm_count",
-    "create_green_ctx_streams",
     "split_device_green_ctx",
     "split_device_green_ctx_by_sm_count",
+    "split_resource_by_sm_count",
 ]
```

flashinfer/testing/__init__.py (1)
32-45: Consider alphabetically sorting `__all__` for consistency. The static analysis tool flags that `__all__` is not sorted (Ruff RUF022). While the current grouping (attention functions → bench functions → utilities) is logical, alphabetical sorting is a common convention that some linters enforce.

🔎 Suggested sorted version

```diff
 __all__ = [
     "attention_flops",
     "attention_flops_with_actual_seq_lens",
     "attention_tb_per_sec",
     "attention_tb_per_sec_with_actual_seq_lens",
     "attention_tflops_per_sec",
     "attention_tflops_per_sec_with_actual_seq_lens",
     "bench_gpu_time",
-    "bench_gpu_time_with_cupti",
     "bench_gpu_time_with_cuda_event",
     "bench_gpu_time_with_cudagraph",
+    "bench_gpu_time_with_cupti",
     "set_seed",
     "sleep_after_kernel_run",
 ]
```

flashinfer/deep_gemm.py (1)
1614-1622: Consider using English for inline comments. The comments `# Classes - 实际使用的` and `# Functions - 实际使用的` are in Chinese ("实际使用的" means "actually used"). For consistency with the rest of the codebase (which uses English comments and docstrings), consider translating these to English or removing them entirely since the categories are self-evident.

🔎 Suggested fix

```diff
 __all__ = [
-    # Classes - 实际使用的
+    # Classes
     "KernelMap",
-    # Functions - 实际使用的
+    # Functions
     "load",
     "load_all",
     "m_grouped_fp8_gemm_nt_contiguous",
     "m_grouped_fp8_gemm_nt_masked",
 ]
```

flashinfer/tllm_utils.py (1)
15-18: Consider sorting `__all__` alphabetically. Static analysis (RUF022) suggests sorting `__all__` for consistency. The alphabetically sorted order would be `["delay_kernel", "get_trtllm_utils_module"]`.

🔎 Proposed fix

```diff
 __all__ = [
+    "delay_kernel",
     "get_trtllm_utils_module",
-    "delay_kernel",
 ]
```

flashinfer/fp4_quantization.py (1)
1004-1018: FP4/NVFP4/MXFP4 exports look correct; consider sorting for Ruff. The `__all__` contents match the main public FP4/NVFP4/MXFP4 helpers defined in this module and provide a clear, consolidated quantization surface.

Ruff is flagging `__all__` as unsorted (RUF022). If you care about keeping Ruff clean, you could alphabetize the entries; otherwise this is functionally fine.

Optional: alphabetize `__all__` to satisfy RUF022

```diff
-__all__ = [
-    "SfLayout",
-    "block_scale_interleave",
-    "nvfp4_block_scale_interleave",
-    "e2m1_and_ufp8sf_scale_to_float",
-    "fp4_quantize",
-    "mxfp4_dequantize_host",
-    "mxfp4_dequantize",
-    "mxfp4_quantize",
-    "nvfp4_quantize",
-    "nvfp4_batched_quantize",
-    "shuffle_matrix_a",
-    "shuffle_matrix_sf_a",
-    "scaled_fp4_grouped_quantize",
-]
+__all__ = [
+    "SfLayout",
+    "block_scale_interleave",
+    "e2m1_and_ufp8sf_scale_to_float",
+    "fp4_quantize",
+    "mxfp4_dequantize",
+    "mxfp4_dequantize_host",
+    "mxfp4_quantize",
+    "nvfp4_batched_quantize",
+    "nvfp4_block_scale_interleave",
+    "nvfp4_quantize",
+    "scaled_fp4_grouped_quantize",
+    "shuffle_matrix_a",
+    "shuffle_matrix_sf_a",
+]
```

flashinfer/gemm/__init__.py (1)
49-53: Consider sorting `__all__` for consistency. The conditional extension of `__all__` is functionally correct. However, for maintainability and consistency with Python conventions, consider sorting the entire `__all__` list alphabetically.

🔎 Optional refactor to sort `__all__`

```diff
 __all__ = [
+    "batch_deepgemm_fp8_nt_groupwise",
+    "bmm_fp8",
+    "fp8_blockscale_gemm_sm90",
+    "gemm_fp8_nt_blockscaled",
+    "gemm_fp8_nt_groupwise",
+    "group_deepgemm_fp8_nt_groupwise",
+    "group_gemm_fp8_nt_groupwise",
+    "group_gemm_mxfp4_nt_groupwise",
+    "mm_M1_16_K7168_N256",
+    "mm_fp4",
+    "mm_fp8",
     "SegmentGEMMWrapper",
-    "bmm_fp8",
-    "mm_fp4",
-    "mm_fp8",
     "tgv_gemm_sm100",
-    "group_gemm_mxfp4_nt_groupwise",
-    "batch_deepgemm_fp8_nt_groupwise",
-    "group_deepgemm_fp8_nt_groupwise",
-    "gemm_fp8_nt_blockscaled",
-    "gemm_fp8_nt_groupwise",
-    "group_gemm_fp8_nt_groupwise",
-    "fp8_blockscale_gemm_sm90",
-    "mm_M1_16_K7168_N256",
 ]

 if _CUTE_DSL_AVAILABLE:
     __all__ += [
+        "Sm100BlockScaledPersistentDenseGemmKernel",
         "grouped_gemm_nt_masked",
-        "Sm100BlockScaledPersistentDenseGemmKernel",
     ]
```
📜 Review details
Configuration used: defaults
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (41)
docs/api/attention.rst (2 hunks), docs/api/comm.rst (2 hunks), docs/api/gemm.rst (2 hunks), docs/api/norm.rst (1 hunks), docs/api/quantization.rst (1 hunks), flashinfer/__init__.py (1 hunks), flashinfer/activation.py (1 hunks), flashinfer/aot.py (1 hunks), flashinfer/api_logging.py (1 hunks), flashinfer/artifacts.py (1 hunks), flashinfer/attention.py (1 hunks), flashinfer/autotuner.py (1 hunks), flashinfer/cascade.py (1 hunks), flashinfer/comm/__init__.py (1 hunks), flashinfer/compilation_context.py (1 hunks), flashinfer/concat_ops.py (1 hunks), flashinfer/cuda_utils.py (1 hunks), flashinfer/decode.py (1 hunks), flashinfer/deep_gemm.py (1 hunks), flashinfer/fp4_quantization.py (1 hunks), flashinfer/fp8_quantization.py (1 hunks), flashinfer/gemm/__init__.py (2 hunks), flashinfer/green_ctx.py (1 hunks), flashinfer/jit/__init__.py (1 hunks), flashinfer/logits_processor/__init__.py (1 hunks), flashinfer/mla.py (1 hunks), flashinfer/norm.py (1 hunks), flashinfer/page.py (1 hunks), flashinfer/pod.py (1 hunks), flashinfer/prefill.py (1 hunks), flashinfer/quantization.py (1 hunks), flashinfer/rope.py (1 hunks), flashinfer/sampling.py (1 hunks), flashinfer/sparse.py (1 hunks), flashinfer/testing/__init__.py (1 hunks), flashinfer/tllm_utils.py (1 hunks), flashinfer/topk.py (1 hunks), flashinfer/trtllm_low_latency_gemm.py (1 hunks), flashinfer/utils.py (1 hunks), flashinfer/version.py (1 hunks), flashinfer/xqa.py (1 hunks)
🧰 Additional context used
🧬 Code graph analysis (2)
flashinfer/gemm/__init__.py (1)
flashinfer/cute_dsl/blockscaled_gemm.py (2)
grouped_gemm_nt_masked (2945-3046), Sm100BlockScaledPersistentDenseGemmKernel (464-2449)
flashinfer/__init__.py (2)
csrc/nv_internal/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/kernel/moe_cutlass_kernel.h (1)
gemm (41-675)

flashinfer/cute_dsl/blockscaled_gemm.py (2)

grouped_gemm_nt_masked (2945-3046), Sm100BlockScaledPersistentDenseGemmKernel (464-2449)
🪛 Ruff (0.14.8)
flashinfer/artifacts.py
253-265: __all__ is not sorted
Apply an isort-style sorting to __all__
(RUF022)
flashinfer/gemm/__init__.py
50-53: __all__ is not sorted
Apply an isort-style sorting to __all__
(RUF022)
flashinfer/jit/__init__.py
96-158: __all__ is not sorted
Apply an isort-style sorting to __all__
(RUF022)
flashinfer/fp4_quantization.py
1004-1018: __all__ is not sorted
Apply an isort-style sorting to __all__
(RUF022)
flashinfer/logits_processor/__init__.py
36-62: __all__ is not sorted
Apply an isort-style sorting to __all__
(RUF022)
flashinfer/concat_ops.py
85-88: __all__ is not sorted
Apply an isort-style sorting to __all__
(RUF022)
flashinfer/green_ctx.py
298-307: __all__ is not sorted
Apply an isort-style sorting to __all__
(RUF022)
flashinfer/version.py
26-29: __all__ is not sorted
Apply an isort-style sorting to __all__
(RUF022)
flashinfer/tllm_utils.py
15-18: __all__ is not sorted
Apply an isort-style sorting to __all__
(RUF022)
flashinfer/testing/__init__.py
32-45: __all__ is not sorted
Apply an isort-style sorting to __all__
(RUF022)
flashinfer/comm/__init__.py
70-123: __all__ is not sorted
Apply an isort-style sorting to __all__
(RUF022)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
- GitHub Check: Deploy Docs
🔇 Additional comments (34)
flashinfer/cuda_utils.py (1)
64-66: The `__all__` definition is correct. The `checkCudaErrors` function is the only public API that should be exposed. The imported CUDA modules (`driver`, `runtime`, `cudart`, `nvrtc`) are internal implementation dependencies and are not imported from this module elsewhere in the codebase, so they should remain unexposed.

flashinfer/topk.py (1)
422-428: LGTM! The `__all__` export correctly includes the three main public functions. The `topk` alias is intentionally omitted (it's just an alias for `top_k`), and `can_implement_filtered_topk` appears to be a utility function that may not need public exposure.

flashinfer/sparse.py (1)
1169-1174: LGTM! The `__all__` correctly exports the two main wrapper classes that constitute the public API for block-sparse attention. The helper function `convert_bsr_mask_layout` is intentionally kept internal.

flashinfer/xqa.py (1)
530-535: LGTM! The `__all__` correctly exports the two public API functions (`xqa` and `xqa_mla`) while keeping the internal module getters (`get_xqa_module`, `get_xqa_module_mla`) as implementation details.

flashinfer/attention.py (1)
282-285: LGTM! The `__all__` export correctly lists both public classes (`BatchAttention` and `BatchAttentionWithAttentionSinkWrapper`) defined in this module.

flashinfer/compilation_context.py (1)
71-73: LGTM! The `__all__` export correctly lists the `CompilationContext` class, which is the sole public API of this module.

flashinfer/quantization.py (1)
142-145: LGTM! The `__all__` export correctly lists the two public functions (`packbits` and `segment_packbits`) while appropriately excluding private helpers prefixed with `_`.

flashinfer/utils.py (1)
1187-1189: Verify the intended public API scope. The `__all__` exports only `next_positive_power_of_2`, but this module contains many other utilities that appear public (e.g., `FP4Tensor`, `get_compute_capability`, `determine_attention_backend`, `PosEncodingMode`, `MaskMode`, `TensorLayout`, `LogLevel`, `set_log_level`, `supported_compute_capability`, `backend_requirement`).

Is this minimal export intentional, or should additional utilities be included in `__all__` for documentation tooling to index them?

flashinfer/api_logging.py (1)
568-570: LGTM! The `__all__` export correctly exposes the `flashinfer_api` decorator as the sole public API, appropriately excluding internal helpers.

flashinfer/autotuner.py (1)
794-796: Verify if additional classes should be exported. The `__all__` exports only the `autotune` context manager. However, users implementing custom autotuning may need access to classes like `TunableRunner`, `TuningConfig`, `DynamicTensorSpec`, `ConstraintSpec`, and `OptimizationProfile`.

Is this minimal export intentional to keep the public API surface small, or should these supporting types be included for documentation and explicit import support?
flashinfer/page.py (1)
384-389: LGTM! The `__all__` declaration correctly exposes the four public APIs defined in this module. All exported functions are properly decorated with `@flashinfer_api` and align with the module's intended public surface.

flashinfer/version.py (1)
26-29: LGTM! The `__all__` declaration correctly exposes the version metadata. Note that the list is already alphabetically sorted, so the Ruff RUF022 warning can be safely ignored as a false positive.

flashinfer/norm.py (1)
415-423: LGTM with conditional exports noted. The `__all__` declaration correctly exposes the public norm APIs. Note that `rmsnorm_fp4quant` and `add_rmsnorm_fp4quant` are conditionally imported (lines 408-413) and will be `None` if CuTe-DSL is not available. This is intentional design for optional dependencies, as indicated by the type-ignore comments.

docs/api/norm.rst (1)
18-19: LGTM! The documentation update correctly adds the new CuTe-DSL based RMSNorm APIs to the autosummary, aligning with the code changes in `flashinfer/norm.py`.

flashinfer/activation.py (1)
227-232: LGTM! The `__all__` declaration correctly exposes the four activation function APIs. All exported functions are properly decorated with `@flashinfer_api` and represent the intended public surface of this module.

flashinfer/comm/__init__.py (1)
70-123: LGTM! Semantic grouping preferred over alphabetical sorting. The `__all__` declaration comprehensively exposes the communication module's public API with clear logical grouping (CUDA IPC, DLPack, TensorRT-LLM AllReduce, etc.). The semantic organization with comments significantly improves maintainability compared to alphabetical sorting, making the Ruff RUF022 warning safely ignorable in this context.

flashinfer/prefill.py (1)
3759-3764: LGTM! The `__all__` declaration correctly exposes the prefill module's public API, including both wrapper classes and standalone functions. The exports align with the module's documented interfaces and intended public surface.

flashinfer/logits_processor/__init__.py (1)
36-62: LGTM! Semantic grouping enhances discoverability. The `__all__` declaration clearly organizes the logits processor module's public API by component type (Compiler, Processors, Types, etc.). This semantic grouping with inline comments provides better developer experience than alphabetical sorting, making the Ruff RUF022 warning appropriately ignorable.

flashinfer/decode.py (1)
2682-2689: Confirm intended public surface for decode-related helpers. The new `__all__` only exposes the wrapper classes, `cudnn_batch_decode_with_kv_cache`, `fast_decode_plan`, and `single_decode_with_kv_cache`, but leaves out other public-looking functions in this module such as `trtllm_batch_decode_with_kv_cache` and `xqa_batch_decode_with_kv_cache`.

If Sphinx and `from flashinfer.decode import *` are meant to show those helpers as part of the public API, consider adding them here as well or confirming they're intentionally kept internal.

flashinfer/cascade.py (1)
1083-1090: Explicit cascade API exports look consistent. The `__all__` set cleanly matches the user-facing cascade wrappers and merge helpers defined in this module. This should work well for tooling and `from flashinfer.cascade import *`.

flashinfer/fp8_quantization.py (1)
211-214: FP8 quantization exports are minimal and focused. Exporting only `mxfp8_quantize` and `mxfp8_dequantize_host` matches the intended public FP8 API and avoids leaking internal helper symbols.

flashinfer/aot.py (1)
884-886: Validate that `register_default_modules` is the only intended public AOT entry. `__all__` now exposes only `register_default_modules`, while leaving helpers like `compile_and_package_modules` and `main` non-exported. This is a sensible minimal surface; just confirm that you don't expect users to import any of the build helpers directly from `flashinfer.aot`.

flashinfer/sampling.py (1)
1590-1603: Sampling public surface is clearly defined. The `__all__` list captures the full set of documented sampling utilities and keeps legacy aliases internal. This should play nicely with docs and `from flashinfer.sampling import *`.

docs/api/comm.rst (1)
49-49: Ensure documented comm symbols are exported and resolvable. The docs now reference `QuantizationSFLayout` and the Unified AllReduce Fusion API (`AllReduceFusionWorkspace`, `TRTLLMAllReduceFusionWorkspace`, `allreduce_fusion`, `create_allreduce_fusion_workspace`, plus `MNNVLAllReduceFusionWorkspace` under `flashinfer.comm.trtllm_mnnvl_ar`).

Please double-check that:
- These symbols actually exist under the documented modules, and
- They're included in the corresponding `__all__` so Sphinx `autosummary` can import and render them without warnings.

Also applies to: 97-116
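For completeness, a quick self-check along the lines the reviewer suggests (a sketch using only the standard library; the module paths and symbol names are taken from the comment above, so treat them as assumptions):

```python
import importlib

# Symbols the docs reference, per the review comment above.
checks = {
    "flashinfer.comm": [
        "AllReduceFusionWorkspace",
        "TRTLLMAllReduceFusionWorkspace",
        "allreduce_fusion",
        "create_allreduce_fusion_workspace",
    ],
    "flashinfer.comm.trtllm_mnnvl_ar": ["MNNVLAllReduceFusionWorkspace"],
}

for module_name, symbols in checks.items():
    mod = importlib.import_module(module_name)
    for sym in symbols:
        assert hasattr(mod, sym), f"{module_name}.{sym} is missing"
        assert sym in getattr(mod, "__all__", ()), f"{sym} not in {module_name}.__all__"
```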
flashinfer/rope.py (1)
1674-1685: LGTM! Clean public API definition. The `__all__` declaration correctly exposes all public RoPE functions, making the module's API surface explicit for documentation tools.

docs/api/quantization.rst (1)
15-27: LGTM! Well-structured documentation addition. The FP8 quantization section properly documents the previously missing APIs using appropriate Sphinx directives.

flashinfer/pod.py (1)
1204-1207: LGTM! Correct public API exposure. The `__all__` declaration properly exposes the two POD wrapper classes, making them discoverable by documentation tools.

docs/api/gemm.rst (2)
25-25: LGTM! Completes the FP8 GEMM documentation. Adding `mm_fp8` to the autosummary properly documents this previously missing API.
48-75: LGTM! Well-organized new GEMM sections. The documentation properly exposes new GEMM variants:
- Blackwell GEMM for SM100 architecture
- TensorRT-LLM low-latency GEMM utilities
- CuTe-DSL GEMM kernels (conditionally available)
The structure is clear and uses appropriate Sphinx directives.
flashinfer/trtllm_low_latency_gemm.py (1)
227-229: LGTM! Proper public API exposure. The `__all__` declaration correctly exports the weight preparation utility, making it accessible for documentation and user consumption.

flashinfer/gemm/__init__.py (1)
22-31: LGTM! Proper handling of optional dependency. The conditional import pattern correctly handles CuTe-DSL availability, gracefully degrading when the optional dependency is not present. The `_CUTE_DSL_AVAILABLE` flag provides a clear way to track availability.
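For reference, a minimal sketch of this optional-dependency pattern (the symbol names come from the review context above; the exact import path is an assumption):

```python
# flashinfer/gemm/__init__.py-style optional import (illustrative sketch)
try:
    from ..cute_dsl.blockscaled_gemm import (
        Sm100BlockScaledPersistentDenseGemmKernel,
        grouped_gemm_nt_masked,
    )
    _CUTE_DSL_AVAILABLE = True
except ImportError:
    # CuTe-DSL is optional; the package still imports cleanly without it.
    _CUTE_DSL_AVAILABLE = False

__all__ = [
    "SegmentGEMMWrapper",
    "bmm_fp8",
    "mm_fp4",
    "mm_fp8",
]

if _CUTE_DSL_AVAILABLE:
    __all__ += [
        "Sm100BlockScaledPersistentDenseGemmKernel",
        "grouped_gemm_nt_masked",
    ]
```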
docs/api/attention.rst (2)

28-28: LGTM! Documents missing decode planning API. Adding `fast_decode_plan` to the autosummary properly exposes this utility function.
114-129: LGTM! Well-structured POD attention documentation. The new POD (Prefix-Only Decode) section properly documents the wrapper classes with appropriate Sphinx directives, aligning with the `__all__` exports in `flashinfer/pod.py`.

flashinfer/__init__.py (1)
92-100: LGTM! Robust handling of optional CuTe-DSL dependency. The conditional import pattern properly handles environments where CuTe-DSL is unavailable, allowing the package to function gracefully in both scenarios. This aligns with the availability check in `flashinfer/gemm/__init__.py`.
```python
__all__ = [
    # Submodules
    "cubin_loader",
    "env",
    # Activation
    "gen_act_and_mul_module",
    "get_act_and_mul_cu_str",
    # Attention
    "gen_cudnn_fmha_module",
    "gen_batch_attention_module",
    "gen_batch_decode_mla_module",
    "gen_batch_decode_module",
    "gen_batch_mla_module",
    "gen_batch_prefill_module",
    "gen_customize_batch_decode_module",
    "gen_customize_batch_prefill_module",
    "gen_customize_single_decode_module",
    "gen_customize_single_prefill_module",
    "gen_fmha_cutlass_sm100a_module",
    "gen_batch_pod_module",
    "gen_pod_module",
    "gen_single_decode_module",
    "gen_single_prefill_module",
    "get_batch_attention_uri",
    "get_batch_decode_mla_uri",
    "get_batch_decode_uri",
    "get_batch_mla_uri",
    "get_batch_prefill_uri",
    "get_pod_uri",
    "get_single_decode_uri",
    "get_single_prefill_uri",
    "gen_trtllm_gen_fmha_module",
    "get_trtllm_fmha_v2_module",
    # Core
    "JitSpec",
    "JitSpecStatus",
    "JitSpecRegistry",
    "jit_spec_registry",
    "build_jit_specs",
    "clear_cache_dir",
    "gen_jit_spec",
    "MissingJITCacheError",
    "sm90a_nvcc_flags",
    "sm100a_nvcc_flags",
    "sm100f_nvcc_flags",
    "sm103a_nvcc_flags",
    "sm110a_nvcc_flags",
    "sm120a_nvcc_flags",
    "sm121a_nvcc_flags",
    "current_compilation_context",
    # Cubin Loader
    "setup_cubin_loader",
    # Comm
    "gen_comm_alltoall_module",
    "gen_trtllm_mnnvl_comm_module",
    "gen_trtllm_comm_module",
    "gen_vllm_comm_module",
    "gen_nvshmem_module",
    "gen_moe_alltoall_module",
    # DSv3 Optimizations
    "gen_dsv3_router_gemm_module",
    "gen_dsv3_fused_routing_module",
]
```
🛠️ Refactor suggestion | 🟠 Major
Sort `__all__` alphabetically per RUF022 rule.

The `__all__` list is currently organized by category, but the static analysis tool expects isort-style alphabetical sorting. While categorical organization is more readable, this should be aligned with the project's linting rules.
🔎 Proposed alphabetically sorted `__all__`
```diff
 __all__ = [
+    "build_jit_specs",
+    "clear_cache_dir",
+    "cubin_loader",
+    "current_compilation_context",
+    "env",
+    "gen_act_and_mul_module",
+    "gen_batch_attention_module",
+    "gen_batch_decode_mla_module",
+    "gen_batch_decode_module",
+    "gen_batch_mla_module",
+    "gen_batch_pod_module",
+    "gen_batch_prefill_module",
+    "gen_comm_alltoall_module",
+    "gen_cudnn_fmha_module",
+    "gen_customize_batch_decode_module",
+    "gen_customize_batch_prefill_module",
+    "gen_customize_single_decode_module",
+    "gen_customize_single_prefill_module",
+    "gen_dsv3_fused_routing_module",
+    "gen_dsv3_router_gemm_module",
+    "gen_fmha_cutlass_sm100a_module",
+    "gen_jit_spec",
+    "gen_moe_alltoall_module",
+    "gen_nvshmem_module",
+    "gen_pod_module",
+    "gen_single_decode_module",
+    "gen_single_prefill_module",
+    "gen_trtllm_comm_module",
+    "gen_trtllm_gen_fmha_module",
+    "gen_trtllm_mnnvl_comm_module",
+    "gen_vllm_comm_module",
+    "get_act_and_mul_cu_str",
+    "get_batch_attention_uri",
+    "get_batch_decode_mla_uri",
+    "get_batch_decode_uri",
+    "get_batch_mla_uri",
+    "get_batch_prefill_uri",
+    "get_pod_uri",
+    "get_single_decode_uri",
+    "get_single_prefill_uri",
+    "get_trtllm_fmha_v2_module",
+    "JitSpec",
+    "JitSpecRegistry",
+    "JitSpecStatus",
+    "jit_spec_registry",
+    "MissingJITCacheError",
+    "setup_cubin_loader",
+    "sm100a_nvcc_flags",
+    "sm100f_nvcc_flags",
+    "sm103a_nvcc_flags",
+    "sm110a_nvcc_flags",
+    "sm120a_nvcc_flags",
+    "sm121a_nvcc_flags",
+    "sm90a_nvcc_flags",
 ]
```

Note: Capitalized entries (`JitSpec*`, `MissingJITCacheError`) and those starting with lowercase will naturally sort together; Python's default sort treats uppercase before lowercase. If case-insensitive grouping is preferred, consider using a custom sort key.
Committable suggestion skipped: line range outside the PR's diff.
🧰 Tools
🪛 Ruff (0.14.8)
96-158: __all__ is not sorted
Apply an isort-style sorting to __all__
(RUF022)
🤖 Prompt for AI Agents
In flashinfer/jit/__init__.py around lines 96 to 158, the __all__ list is
grouped by category but must be alphabetically sorted to satisfy
RUF022/isort-style linting; reorder all string entries in the list into a single
alphabetical sequence (case-sensitive default sort is fine), keeping each name
as a quoted string with commas and preserving the list formatting (multiline)
and trailing comma.
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
Actionable comments posted: 0
♻️ Duplicate comments (1)
flashinfer/norm.py (1)
415-428: Fix critical syntax error in `__all__` definition. The code has a duplicate and malformed `__all__` definition causing syntax errors that break pre-commit checks. Lines 415-417 start two separate `__all__` assignments, with the first one unclosed.

A past review comment already suggested the correct pattern (see lines 415-423 in past comments). Please follow that suggestion.
🔎 Proposed fix
```diff
-__all__ = [
-    "add_rmsnorm_fp4quant",
-__all__ = [
+__all__ = [
     "fused_add_rmsnorm",
     "gemma_fused_add_rmsnorm",
     "gemma_rmsnorm",
     "layernorm",
     "rmsnorm",
 ]

 if add_rmsnorm_fp4quant is not None:
     __all__.append("add_rmsnorm_fp4quant")
 if rmsnorm_fp4quant is not None:
     __all__.append("rmsnorm_fp4quant")
```
🧹 Nitpick comments (1)
flashinfer/norm.py (1)
418-423: Consider adding `rmsnorm_quant` and `fused_add_rmsnorm_quant` to `__all__`. Both `rmsnorm_quant` (line 97) and `fused_add_rmsnorm_quant` (line 198) are decorated with `@flashinfer_api`, indicating they're part of the public API. They should be included in `__all__` for consistency and to support wildcard imports.

🔎 Suggested addition
```diff
 __all__ = [
     "fused_add_rmsnorm",
+    "fused_add_rmsnorm_quant",
     "gemma_fused_add_rmsnorm",
     "gemma_rmsnorm",
     "layernorm",
     "rmsnorm",
+    "rmsnorm_quant",
 ]
```
📜 Review details
Configuration used: defaults
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (1)
flashinfer/norm.py (1 hunks)
🧰 Additional context used
🧬 Code graph analysis (1)
flashinfer/norm.py (2)
flashinfer/cute_dsl/add_rmsnorm_fp4quant.py (1)
add_rmsnorm_fp4quant (2219-2327)

flashinfer/cute_dsl/rmsnorm_fp4quant.py (1)

rmsnorm_fp4quant (1790-1895)
🪛 GitHub Actions: pre-commit
flashinfer/norm.py
[error] 415-429: Syntax errors detected in flashinfer/norm.py (e.g., duplicate `__all__` definition, an unclosed bracket, and invalid syntax causing mypy/ruff to fail). Please fix the Python syntax to resolve the pre-commit checks.
🪛 Ruff (0.14.8)
flashinfer/norm.py
417-417: Expected `,`, found `=`
(invalid-syntax)
425-425: Expected `else`, found `:`
(invalid-syntax)
427-427: Expected `else`, found `:`
(invalid-syntax)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
- GitHub Check: Deploy Docs
This reverts commit c350bf3.
Actionable comments posted: 0
🧹 Nitpick comments (3)
flashinfer/utils.py (2)
1246-1257: Consider whether internal helpers should be in `__all__`. The underscore-prefixed functions are typically considered private/internal by Python convention. While the comment indicates they're "used across modules," this usually means internal package usage rather than public API exposure. If these are only for inter-module communication within flashinfer, they can be imported directly by other modules without being in `__all__`, keeping the public API surface smaller and more focused.

However, if these are intended as a stable, semi-public API for advanced users, then including them is acceptable.
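To illustrate the distinction (a hypothetical sketch; these helpers are invented, not the actual flashinfer ones):

```python
# flashinfer/utils.py (hypothetical excerpt)
__all__ = ["next_positive_power_of_2"]  # public surface only

def next_positive_power_of_2(x: int) -> int:
    """Public: smallest power of two >= max(x, 1)."""
    return 1 if x <= 0 else 1 << (x - 1).bit_length()

def _check_shape(a, b) -> None:
    """Internal: shared across sibling modules without being exported."""
    if a.shape != b.shape:
        raise ValueError(f"shape mismatch: {a.shape} vs {b.shape}")

# flashinfer/decode.py (hypothetical) can still do:
#   from .utils import _check_shape
# without _check_shape ever appearing in __all__.
```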
1187-1257: Optional: Consider sorting `__all__` for consistency. The static analysis tool suggests applying isort-style sorting to the `__all__` list. The current semantic grouping by category (Enums, Exceptions, etc.) has its advantages for readability, but alphabetical sorting would improve searchability and align with common Python conventions. This is a minor stylistic preference.

flashinfer/mla.py (1)
802-809: `@flashinfer_api` functions correctly added; clarify intent for module getter functions. The inclusion of `trtllm_batch_decode_with_kv_cache_mla` and `xqa_batch_decode_with_kv_cache_mla` with `@flashinfer_api` decoration addresses the prior review feedback and correctly exposes these primary public APIs.

However, the three module getter functions (`get_trtllm_gen_fmha_module`, `get_mla_module`, `get_batch_mla_module`) are decorated only with `@functools.cache`, lack docstrings, and are used exclusively within mla.py for internal module caching. Clarify whether these are intended as public exports or if they should remain internal implementation details.
📜 Review details
Configuration used: defaults
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (2)
flashinfer/mla.py (1 hunks), flashinfer/utils.py (1 hunks)
🧰 Additional context used
🪛 Ruff (0.14.8)
flashinfer/utils.py
1187-1257: __all__ is not sorted
Apply an isort-style sorting to __all__
(RUF022)
flashinfer/mla.py
802-809: __all__ is not sorted
Apply an isort-style sorting to __all__
(RUF022)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
- GitHub Check: Deploy Docs
🔇 Additional comments (1)
flashinfer/utils.py (1)
1187-1257: Comprehensive `__all__` export list successfully addresses PR objective. The addition of this explicit `__all__` list makes the public API surface clear and helps documentation tools properly index the intended exports. The organization with category comments (Enums, Exceptions, Classes, etc.) enhances maintainability. This addresses the concerns raised in previous reviews about missing public symbols.
/bot run
[SUCCESS] Pipeline #40597907: 12/20 passed
📌 Description
We currently don't define `__all__` in our Python modules. Adding it will make public APIs explicit and ensure documentation tools index all intended exports.
This PR also adds missing APIs from documentation index.
🔍 Related Issues
🚀 Pull Request Checklist
Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete.
✅ Pre-commit Checks
- I have installed `pre-commit` by running `pip install pre-commit` (or used your preferred method).
- I have installed the hooks with `pre-commit install`.
- I have run the hooks manually with `pre-commit run --all-files` and fixed any reported issues.

🧪 Tests

- Tests have been added or updated as needed.
- All tests are passing (`unittest`, etc.).

Reviewer Notes
Summary by CodeRabbit
- Documentation
- New Features
- Chores