-
Notifications
You must be signed in to change notification settings - Fork 621
refactor: pull trtllm-gen batch-gemm/gemm headers from artifactory; update tma descriptor shape init #2235
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Conversation
|
Note Other AI code review bot(s) detectedCodeRabbit has detected other AI code review bot(s) in this pull request and will avoid duplicating their findings in the review comments. This may lead to a less comprehensive review. WalkthroughThis PR updates TensorRT-LLM GEMM integration by introducing valid dimension fields to problem dimensions, migrating header files from repository inclusion to dynamic download, updating artifact paths and checksums, and extending the CUBIN loader with header file management capabilities. Changes
Estimated code review effort🎯 4 (Complex) | ⏱️ ~60 minutes Areas requiring extra attention:
Possibly related PRs
Suggested reviewers
Poem
Pre-merge checks and finishing touches❌ Failed checks (2 warnings)
✅ Passed checks (1 passed)
✨ Finishing touches
🧪 Generate unit tests (beta)
Thanks for using CodeRabbit! It's free for OSS, and your support helps us grow. If you like it, consider giving us a shout-out. Comment |
Summary of ChangesHello @jimmyzho, I'm Gemini Code Assist1! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed! This pull request streamlines the management of TRTLLM-Gen headers by transitioning from bundled files to fetching them dynamically from Artifactory. It also includes crucial updates to how TMA problem dimensions are initialized in the CUDA GEMM runners, ensuring compatibility and correctness with the updated header fetching mechanism. The changes enhance the flexibility and maintainability of the artifact management system. Highlights
Using Gemini Code AssistThe full guide for Gemini Code Assist can be found on our documentation page, here are some quick tips. Invoking Gemini You can request assistance from Gemini at any point by creating a comment using either
Customization To customize Gemini Code Assist for GitHub experience, repository maintainers can create a configuration file and/or provide a custom code review style guide (such as PEP-8 for Python) by creating and adding files to a Limitations & Feedback Gemini Code Assist may make mistakes. Please leave feedback on any instances where its feedback is incorrect or counter productive. You can react with 👍 and 👎 on @gemini-code-assist comments. If you're interested in giving your feedback about your experience with Gemini Code Assist for Github and other Google products, sign up here. You can also get AI-powered code generation, chat, as well as code reviews directly in the IDE at no cost with the Gemini Code Assist IDE Extension. Footnotes
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Code Review
This pull request refactors the build process to download trt-llm headers from an artifactory during JIT compilation, which is a commendable improvement for dependency management. The changes also update TMA descriptor initializations, likely to align with a newer API.
My main feedback revolves around improving maintainability. There's significant code duplication for downloading headers and creating symlinks across flashinfer/jit/fused_moe.py and flashinfer/jit/gemm/core.py, which should be refactored into a shared helper function. Additionally, the lists of header files are hardcoded, creating a maintenance burden; a more dynamic approach would be better. There are also opportunities to reduce duplication between get_available_header_files and get_available_cubin_files and to clean up commented-out code.
| header_files = [ | ||
| "BatchedGemmEnums.h", | ||
| "BatchedGemmInterface.h", | ||
| "BatchedGemmOptions.h", | ||
| "Enums.h", | ||
| "GemmGatedActOptions.h", | ||
| "GemmOptions.h", | ||
| "KernelParams.h", | ||
| "KernelParamsDecl.h", | ||
| "KernelTraits.h", | ||
| "TmaDescriptor.h", | ||
| "trtllm/gen/CommonUtils.h", | ||
| "trtllm/gen/CudaArchDecl.h", | ||
| "trtllm/gen/CudaKernelLauncher.h", | ||
| "trtllm/gen/DtypeDecl.h", | ||
| "trtllm/gen/MmaDecl.h", | ||
| "trtllm/gen/SfLayoutDecl.h", | ||
| ] | ||
|
|
||
| header_path = f"{include_path}/trtllmGen_bmm_export" | ||
| for file in header_files: | ||
| uri_path = f"{header_path}/{file}" | ||
| file_hash = get_meta_hash(checksum, file) | ||
| file_path = jit_env.FLASHINFER_CUBIN_DIR / "trtllmGen_bmm_export" / file | ||
| get_file(uri_path, file_hash, file_path) | ||
| # Create directory flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export pointing to trtllmGen_bmm_export | ||
|
|
||
| symlink_parent = str( | ||
| jit_env.FLASHINFER_CUBIN_DIR / "flashinfer/trtllm/batched_gemm" | ||
| ) | ||
| make_symlink( | ||
| "../../../trtllmGen_bmm_export", symlink_parent, "trtllmGen_bmm_export" | ||
| ) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This block of code for downloading headers and creating symlinks is duplicated in flashinfer/jit/gemm/core.py in gen_trtllm_gen_gemm_module and gen_trtllm_low_latency_gemm_module. This introduces a significant maintenance burden.
To improve the code quality, please refactor this logic into a single, reusable helper function. This function could be placed in flashinfer/jit/cubin_loader.py and accept parameters such as the list of headers, paths, and checksums.
| def get_available_header_files( | ||
| source: str, retries: int = 3, delay: int = 5, timeout: int = 10 | ||
| ) -> tuple[str, ...]: | ||
| """ | ||
| Recursively navigates through child directories (e.g., include/) and finds | ||
| all *.h header files, returning them as a tuple of relative paths. | ||
| """ | ||
| result: list[str] = [] | ||
|
|
||
| def fetch_directory(url: str, prefix: str = "") -> None: | ||
| for attempt in range(1, retries + 1): | ||
| try: | ||
| response = requests.get(url, timeout=timeout) | ||
| response.raise_for_status() | ||
|
|
||
| # Find all .h header files in this directory | ||
| header_hrefs = re.findall(r'<a href="([^"]+\.h)">', response.text) | ||
| for h in header_hrefs: | ||
| result.append(prefix + h if prefix else h) | ||
|
|
||
| # Find all subdirectories (links ending with /) | ||
| dir_hrefs = re.findall(r'<a href="([^"]+/)">', response.text) | ||
| for d in dir_hrefs: | ||
| # Skip parent directory links | ||
| if d == "../" or d.startswith(".."): | ||
| continue | ||
| subdir_url = safe_urljoin(url, d) | ||
| subdir_prefix = prefix + d if prefix else d | ||
| fetch_directory(subdir_url, subdir_prefix) | ||
|
|
||
| return # Success, exit retry loop | ||
|
|
||
| except requests.exceptions.RequestException as e: | ||
| logger.warning( | ||
| f"Fetching available header files {url}: attempt {attempt} failed: {e}" | ||
| ) | ||
|
|
||
| if attempt < retries: | ||
| logger.info(f"Retrying in {delay} seconds...") | ||
| time.sleep(delay) | ||
|
|
||
| logger.error(f"Max retries reached for {url}. Fetch failed.") | ||
|
|
||
| fetch_directory(source) | ||
| logger.info(f"result: {result}") | ||
| return tuple(result) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This new function get_available_header_files has very similar logic to the existing get_available_cubin_files function. To improve maintainability and reduce code duplication, consider refactoring them into a single, more generic function. This new function could accept the file extension (e.g., .h or .cubin) as a parameter.
| TRTLLM_GEN_BMM: str = ( | ||
| "ccae3ed120a12a2c6922b458086b460413dbf731/batched_gemm-0d275a2-9936841" | ||
| "02546c924085adc5df7dc0a211cacc7ec3d3e01c/batched_gemm-0d275a2-9936841" | ||
| ) | ||
| # TRTLLM_GEN_BMM: str = ( | ||
| # "ccae3ed120a12a2c6922b458086b460413dbf731/batched_gemm-0d275a2-9936841" | ||
| # ) | ||
| TRTLLM_GEN_GEMM: str = ( | ||
| "1fddc48b7b48af33914d040051b3e2ee9ba4701e/gemm-145d1b1-9b113e3" | ||
| "02546c924085adc5df7dc0a211cacc7ec3d3e01c/gemm-0d275a2-30f1102" | ||
| ) | ||
| # TRTLLM_GEN_GEMM: str = ( | ||
| # "1fddc48b7b48af33914d040051b3e2ee9ba4701e/gemm-145d1b1-9b113e3" | ||
| # ) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
| TRTLLM_GEN_BMM: str = ( | ||
| "b7689d3046493806251351c2744c6d7faed6af25518647a955b35c4919b014fc" | ||
| "680167f34b532d493d3ed71da0a1640054cf1cb0a80cfca20e7d797dbd093a90" | ||
| ) | ||
| # TRTLLM_GEN_BMM: str = ( | ||
| # "b7689d3046493806251351c2744c6d7faed6af25518647a955b35c4919b014fc" | ||
| # ) | ||
| DEEPGEMM: str = "1a2a166839042dbd2a57f48051c82cd1ad032815927c753db269a4ed10d0ffbf" | ||
| TRTLLM_GEN_GEMM: str = ( | ||
| "15cb8c85dfb5eddd4f121d64cb5a718321fb55b85aa19df10ddc1329d4a726b9" | ||
| "014473f273a4dd248b5608e813f0fe468f05c686093577abd23f7a64afd77a60" | ||
| ) | ||
| # TRTLLM_GEN_GEMM: str = ( | ||
| # "15cb8c85dfb5eddd4f121d64cb5a718321fb55b85aa19df10ddc1329d4a726b9" | ||
| # ) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Please remove the commented-out lines containing old hashes to keep the code clean.
TRTLLM_GEN_BMM: str = (
"680167f34b532d493d3ed71da0a1640054cf1cb0a80cfca20e7d797dbd093a90"
)
DEEPGEMM: str = "1a2a166839042dbd2a57f48051c82cd1ad032815927c753db269a4ed10d0ffbf"
TRTLLM_GEN_GEMM: str = (
"014473f273a4dd248b5608e813f0fe468f05c686093577abd23f7a64afd77a60"
)| Otherwise, download the file from {uri_path} and write to {file_path}. | ||
| """ | ||
|
|
||
| file = load_cubin(file_path, sha256) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
| header_files = [ | ||
| "BatchedGemmEnums.h", | ||
| "BatchedGemmInterface.h", | ||
| "BatchedGemmOptions.h", | ||
| "Enums.h", | ||
| "GemmGatedActOptions.h", | ||
| "GemmOptions.h", | ||
| "KernelParams.h", | ||
| "KernelParamsDecl.h", | ||
| "KernelTraits.h", | ||
| "TmaDescriptor.h", | ||
| "trtllm/gen/CommonUtils.h", | ||
| "trtllm/gen/CudaArchDecl.h", | ||
| "trtllm/gen/CudaKernelLauncher.h", | ||
| "trtllm/gen/DtypeDecl.h", | ||
| "trtllm/gen/MmaDecl.h", | ||
| "trtllm/gen/SfLayoutDecl.h", | ||
| ] |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The list of header_files is hardcoded here and in other similar functions. This makes it difficult to maintain when the upstream trt-llm dependency adds or removes headers.
Consider making this more dynamic. You could, for example, use the new get_available_header_files function to fetch the list of headers from the artifactory directly, rather than hardcoding them.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Actionable comments posted: 6
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (1)
flashinfer/jit/gemm/core.py (1)
565-607: Consider extracting shared header download logic and verify include path consistency.The header file download logic (lines 566-592) is duplicated between
gen_trtllm_gen_gemm_moduleandgen_trtllm_low_latency_gemm_module. Additionally, there's an inconsistency inextra_include_paths:gen_trtllm_gen_gemm_moduleincludes bothFLASHINFER_CUBIN_DIRandFLASHINFER_CUBIN_DIR / include_path(lines 424-427), whilegen_trtllm_low_latency_gemm_moduleonly includesFLASHINFER_CUBIN_DIR / include_path(line 607).
- Extract the header download logic into a shared helper function
- Verify whether the include path difference is intentional or an oversight - if the symlink at
flashinfer/trtllm/gemm/trtllmGen_gemm_exportis needed, thenFLASHINFER_CUBIN_DIRshould be included here as wellExample refactor:
def download_trtllm_gemm_headers(artifact_path: str, checksum_hash: str): """Download and cache TRTLLM GEMM export headers.""" include_path = f"{artifact_path}/include" checksum_path = f"{artifact_path}/checksums.txt" checksum = get_cubin(checksum_path, checksum_hash) assert checksum, f"Failed to get checksums.txt from {checksum_path}" meta_hash = get_meta_hash(checksum) header_name = "flashinferMetaInfo" metainfo = get_cubin(f"{include_path}/{header_name}.h", meta_hash) assert metainfo, f"{header_name}.h not found" header_files = [ "GemmInterface.h", "GemmOptions.h", # ... rest of the list ] header_path = f"{include_path}/trtllmGen_gemm_export" for file in header_files: uri_path = f"{header_path}/{file}" file_hash = get_meta_hash(checksum, file) file_path = jit_env.FLASHINFER_CUBIN_DIR / "trtllmGen_gemm_export" / file get_file(uri_path, file_hash, str(file_path)) symlink_parent = str(jit_env.FLASHINFER_CUBIN_DIR / "flashinfer/trtllm/gemm") make_symlink("../../../trtllmGen_gemm_export", symlink_parent, "trtllmGen_gemm_export") return include_path
🧹 Nitpick comments (2)
flashinfer/jit/cubin_loader.py (1)
139-151: Clarify the inline comment.The comment mentions "case-insensitive for the 'I' in Infer" but the implementation performs case-insensitive matching on the entire filename using
.lower().endswith(target_file.lower()). This may confuse future maintainers.Consider updating the comment to:
- # Match specifically flashinferMetaInfo.h (case-insensitive for the 'I' in Infer) + # Match filename ending with target_file (case-insensitive)flashinfer/artifacts.py (1)
82-127: Consider reducing log verbosity.Line 126 logs the entire list of header files at INFO level, which could be verbose for directories with many headers. Consider logging the count instead or moving this to DEBUG level.
- logger.info(f"result: {result}") + logger.info(f"Found {len(result)} header files") + logger.debug(f"Header files: {result}")
📜 Review details
Configuration used: defaults
Review profile: CHILL
Plan: Pro
⛔ Files ignored due to path filters (11)
include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/trtllm/gen/CommonUtils.his excluded by!**/gen/**include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/trtllm/gen/CudaArchDecl.his excluded by!**/gen/**include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/trtllm/gen/CudaKernelLauncher.his excluded by!**/gen/**include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/trtllm/gen/DtypeDecl.his excluded by!**/gen/**include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/trtllm/gen/MmaDecl.his excluded by!**/gen/**include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/trtllm/gen/SfLayoutDecl.his excluded by!**/gen/**include/flashinfer/trtllm/gemm/trtllmGen_gemm_export/trtllm/gen/CommonUtils.his excluded by!**/gen/**include/flashinfer/trtllm/gemm/trtllmGen_gemm_export/trtllm/gen/CudaKernelLauncher.his excluded by!**/gen/**include/flashinfer/trtllm/gemm/trtllmGen_gemm_export/trtllm/gen/DtypeDecl.his excluded by!**/gen/**include/flashinfer/trtllm/gemm/trtllmGen_gemm_export/trtllm/gen/MmaDecl.his excluded by!**/gen/**include/flashinfer/trtllm/gemm/trtllmGen_gemm_export/trtllm/gen/SfLayoutDecl.his excluded by!**/gen/**
📒 Files selected for processing (24)
csrc/trtllm_batched_gemm_runner.cu(2 hunks)csrc/trtllm_gemm_runner.cu(3 hunks)csrc/trtllm_low_latency_gemm_runner.cu(1 hunks)flashinfer/artifacts.py(5 hunks)flashinfer/jit/cubin_loader.py(2 hunks)flashinfer/jit/fused_moe.py(3 hunks)flashinfer/jit/gemm/core.py(4 hunks)include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/BatchedGemmEnums.h(0 hunks)include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/BatchedGemmInterface.h(0 hunks)include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/BatchedGemmOptions.h(0 hunks)include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/Enums.h(0 hunks)include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/GemmGatedActOptions.h(0 hunks)include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/GemmOptions.h(0 hunks)include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/KernelParams.h(0 hunks)include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/KernelParamsDecl.h(0 hunks)include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/KernelTraits.h(0 hunks)include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/TmaDescriptor.h(0 hunks)include/flashinfer/trtllm/gemm/trtllmGen_gemm_export/Enums.h(0 hunks)include/flashinfer/trtllm/gemm/trtllmGen_gemm_export/GemmInterface.h(0 hunks)include/flashinfer/trtllm/gemm/trtllmGen_gemm_export/GemmOptions.h(0 hunks)include/flashinfer/trtllm/gemm/trtllmGen_gemm_export/KernelParams.h(0 hunks)include/flashinfer/trtllm/gemm/trtllmGen_gemm_export/KernelParamsDecl.h(0 hunks)include/flashinfer/trtllm/gemm/trtllmGen_gemm_export/KernelTraits.h(0 hunks)include/flashinfer/trtllm/gemm/trtllmGen_gemm_export/TmaDescriptor.h(0 hunks)
💤 Files with no reviewable changes (17)
- include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/TmaDescriptor.h
- include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/BatchedGemmEnums.h
- include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/GemmOptions.h
- include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/Enums.h
- include/flashinfer/trtllm/gemm/trtllmGen_gemm_export/KernelParamsDecl.h
- include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/BatchedGemmOptions.h
- include/flashinfer/trtllm/gemm/trtllmGen_gemm_export/KernelTraits.h
- include/flashinfer/trtllm/gemm/trtllmGen_gemm_export/GemmOptions.h
- include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/GemmGatedActOptions.h
- include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/KernelParamsDecl.h
- include/flashinfer/trtllm/gemm/trtllmGen_gemm_export/Enums.h
- include/flashinfer/trtllm/gemm/trtllmGen_gemm_export/KernelParams.h
- include/flashinfer/trtllm/gemm/trtllmGen_gemm_export/TmaDescriptor.h
- include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/KernelParams.h
- include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/KernelTraits.h
- include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/BatchedGemmInterface.h
- include/flashinfer/trtllm/gemm/trtllmGen_gemm_export/GemmInterface.h
🧰 Additional context used
🧠 Learnings (1)
📚 Learning: 2025-11-12T03:35:17.583Z
Learnt from: raayandhar
Repo: flashinfer-ai/flashinfer PR: 2070
File: include/flashinfer/gemm/bf16_gemm_cutlass_template.h:145-160
Timestamp: 2025-11-12T03:35:17.583Z
Learning: In flashinfer GEMM implementations (e.g., include/flashinfer/gemm/bf16_gemm_cutlass_template.h, fp8_gemm_cutlass_template.h), it is acceptable to catch and silently ignore std::runtime_error exceptions in getWorkspaceSizeImpl when probing multiple GEMM configurations, as some configurations may legitimately fail due to SMEM constraints. This pattern should include a comment like "// Swallow errors when SMEM exceeds maximum allowed" to document the rationale.
Applied to files:
csrc/trtllm_gemm_runner.cucsrc/trtllm_batched_gemm_runner.cu
🧬 Code graph analysis (3)
flashinfer/jit/fused_moe.py (1)
flashinfer/jit/cubin_loader.py (4)
get_cubin(227-246)get_meta_hash(139-151)make_symlink(195-203)get_file(206-224)
flashinfer/jit/gemm/core.py (1)
flashinfer/jit/cubin_loader.py (3)
get_meta_hash(139-151)make_symlink(195-203)get_file(206-224)
flashinfer/artifacts.py (1)
flashinfer/jit/cubin_loader.py (1)
safe_urljoin(38-42)
🪛 Ruff (0.14.8)
flashinfer/jit/cubin_loader.py
209-209: PEP 484 prohibits implicit Optional
Convert to T | None
(RUF013)
flashinfer/artifacts.py
112-112: Consider moving this statement to an else block
(TRY300)
🔇 Additional comments (15)
csrc/trtllm_low_latency_gemm_runner.cu (1)
52-54: LGTM!The initialization of the "valid" dimension fields mirrors the primary dimension fields, which aligns with the PR's objective to update TMA descriptor shape initialization.
csrc/trtllm_gemm_runner.cu (3)
122-124: LGTM!The valid dimension fields are correctly initialized after the primary dimensions, maintaining consistency with the transpose logic already applied to mM and mN.
145-147: LGTM!The initialization pattern is consistent with the getWorkspaceSizeInBytes method.
196-198: LGTM!The initialization pattern is uniformly applied across all three methods (getWorkspaceSizeInBytes, run, and getValidTactics).
csrc/trtllm_batched_gemm_runner.cu (3)
149-151: LGTM!The valid dimension fields are correctly initialized in the workspace sizing method.
343-345: LGTM!The valid dimension fields are correctly initialized in the config validation method.
448-450: LGTM!The valid dimension fields are correctly initialized in the config validation method.
flashinfer/artifacts.py (4)
139-150: LGTM!The artifact paths have been updated with new hashes, and the old values are preserved as comments for reference or rollback. This aligns with the PR objective to pull trtllm-gen headers from Artifactory.
166-178: LGTM!The checksums have been updated to match the new artifact paths. Good practice to preserve the old values as comments.
245-247: LGTM!The header files are now included in the artifact listing, mirroring the pattern used for cubin files. This enables header files to be downloaded and verified alongside cubin binaries.
256-256: LGTM!Adding explicit type annotation improves code clarity.
flashinfer/jit/gemm/core.py (2)
33-38: LGTM!The new imports support the header file management functionality added in this PR.
424-427: LGTM!The include paths are extended to support both the root cubin directory (for the symlink at
flashinfer/trtllm/gemm/trtllmGen_gemm_export) and the artifact-specific include directory. This enables the compilation to locate the downloaded headers.flashinfer/jit/fused_moe.py (2)
29-29: LGTM!The new imports support the header file management functionality added in this PR.
298-303: LGTM!The include paths are extended to support both the root cubin directory and the artifact-specific include directory, consistent with the pattern in
gen_trtllm_gen_gemm_module.
| gemmData.mProblemDimensions.mValidM = gemmData.mProblemDimensions.mM; | ||
| gemmData.mProblemDimensions.mValidN = gemmData.mProblemDimensions.mN; | ||
| gemmData.mProblemDimensions.mValidK = gemmData.mProblemDimensions.mK; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Remove duplicate initialization.
The valid dimension fields (mValidM, mValidN, mValidK) are initialized twice in this method with identical values. Lines 212-214 correctly initialize these fields immediately after mM, mN, mK are set. The second initialization at lines 251-253 is redundant since the dimension values haven't changed between these two points.
Apply this diff to remove the duplicate initialization:
gemmData.mOutputBuffers.mPtrC = c;
gemmData.mOutputBuffers.mPtrSfC = outSfC;
int32_t multiProcessorCount;
cudaDeviceGetAttribute(&multiProcessorCount, cudaDevAttrMultiProcessorCount, device);
- gemmData.mProblemDimensions.mValidM = gemmData.mProblemDimensions.mM;
- gemmData.mProblemDimensions.mValidN = gemmData.mProblemDimensions.mN;
- gemmData.mProblemDimensions.mValidK = gemmData.mProblemDimensions.mK;
-
// FIXME once we start using all-reduce in the epilogue of the bmm this can be moved elsewhere
bmm.runInitBeforeWorldSync(config, gemmData, static_cast<void*>(stream));Also applies to: 251-253
🤖 Prompt for AI Agents
In csrc/trtllm_batched_gemm_runner.cu around lines 212-214 and again at 251-253,
the mProblemDimensions.mValidM/mValidN/mValidK fields are initialized twice with
the same values; remove the redundant initialization at lines 251-253 so the
valid dimension fields are only set once immediately after mM/mN/mK are
assigned.
| def make_symlink(src: str, symlink_parent: str, symlink_name: str) -> None: | ||
| """ | ||
| Create a symlink from {src} to {symlink_parent}/{symlink_name}. | ||
| If the symlink does not exist, create the parent directory and create the symlink. | ||
| """ | ||
| symlink_path = safe_urljoin(symlink_parent, symlink_name) | ||
| if not os.path.exists(symlink_path): | ||
| os.makedirs(symlink_parent, exist_ok=True) | ||
| os.symlink(src, symlink_path) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Handle broken symlinks correctly.
os.path.exists() returns False for broken symlinks, which would cause the function to attempt creating a symlink that already exists (resulting in an FileExistsError). Use os.path.lexists() or check with os.path.islink() to detect existing symlinks regardless of whether their target exists.
Apply this diff:
- if not os.path.exists(symlink_path):
+ if not os.path.lexists(symlink_path):
os.makedirs(symlink_parent, exist_ok=True)
os.symlink(src, symlink_path)🤖 Prompt for AI Agents
In flashinfer/jit/cubin_loader.py around lines 195 to 203, the current check
uses os.path.exists(symlink_path) which returns False for broken symlinks and
can cause attempts to recreate an existing symlink; change the existence check
to use os.path.lexists(symlink_path) (or os.path.islink(symlink_path)) so
existing symlinks are detected regardless of target validity, and only call
os.makedirs and os.symlink when the lexists/islink check indicates the path
truly does not exist to avoid FileExistsError.
Use filesystem path operations instead of URL joining.
safe_urljoin is designed for URLs and may not handle filesystem paths correctly across platforms. For filesystem operations, prefer os.path.join or pathlib.Path.
Apply this diff:
- symlink_path = safe_urljoin(symlink_parent, symlink_name)
+ symlink_path = os.path.join(symlink_parent, symlink_name)🤖 Prompt for AI Agents
In flashinfer/jit/cubin_loader.py around lines 195 to 203, the code uses
safe_urljoin to build a filesystem symlink path which is for URLs and can
mis-handle platform paths; replace safe_urljoin with a filesystem-aware join
(prefer pathlib.Path or os.path.join) to construct symlink_path (e.g.,
Path(symlink_parent) / symlink_name), ensure the parent directory exists
(Path(symlink_parent).mkdir(parents=True, exist_ok=True) or os.makedirs), and
then create the symlink with os.symlink(src, str(symlink_path)); keep the
existing existence check (Path(symlink_path).exists() or os.path.exists) to
avoid overwriting.
| def get_file( | ||
| uri_path: str, | ||
| sha256: str, | ||
| file_path: str = None, | ||
| session=None, | ||
| ) -> bytes: | ||
| """ | ||
| Load a file from local cache directory {file_path}, ensure that the sha256 signature matches. | ||
| Otherwise, download the file from {uri_path} and write to {file_path}. | ||
| """ | ||
|
|
||
| file = load_cubin(file_path, sha256) | ||
| if file: | ||
| return file | ||
| os.makedirs(os.path.dirname(file_path), exist_ok=True) | ||
| uri = safe_urljoin(FLASHINFER_CUBINS_REPOSITORY, uri_path) | ||
| logger.info(f"Fetching file from {uri}") | ||
| download_file(uri, file_path, session=session) | ||
| return load_cubin(file_path, sha256) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Fix type annotation and add validation for required parameter.
The file_path parameter is annotated as str = None but is used without null checks (lines 217, 220), which will cause an AttributeError if None is passed. Based on the function's usage in flashinfer/jit/gemm/core.py and flashinfer/jit/fused_moe.py, this parameter appears to be required.
Apply this diff:
def get_file(
uri_path: str,
sha256: str,
- file_path: str = None,
+ file_path: str,
session=None,
) -> bytes:🧰 Tools
🪛 Ruff (0.14.8)
209-209: PEP 484 prohibits implicit Optional
Convert to T | None
(RUF013)
🤖 Prompt for AI Agents
In flashinfer/jit/cubin_loader.py around lines 206 to 224, the file_path
parameter is annotated as optional (str = None) but used unguarded
(os.path.dirname, load_cubin), which will raise AttributeError if None; change
the signature to require file_path: str (remove the = None default) and add an
explicit validation at the start of the function that raises a ValueError with a
clear message if file_path is falsy (None or empty). Ensure callers still pass a
valid path (update call sites if necessary) so the function never receives None.
| header_files = [ | ||
| "BatchedGemmEnums.h", | ||
| "BatchedGemmInterface.h", | ||
| "BatchedGemmOptions.h", | ||
| "Enums.h", | ||
| "GemmGatedActOptions.h", | ||
| "GemmOptions.h", | ||
| "KernelParams.h", | ||
| "KernelParamsDecl.h", | ||
| "KernelTraits.h", | ||
| "TmaDescriptor.h", | ||
| "trtllm/gen/CommonUtils.h", | ||
| "trtllm/gen/CudaArchDecl.h", | ||
| "trtllm/gen/CudaKernelLauncher.h", | ||
| "trtllm/gen/DtypeDecl.h", | ||
| "trtllm/gen/MmaDecl.h", | ||
| "trtllm/gen/SfLayoutDecl.h", | ||
| ] | ||
|
|
||
| header_path = f"{include_path}/trtllmGen_bmm_export" | ||
| for file in header_files: | ||
| uri_path = f"{header_path}/{file}" | ||
| file_hash = get_meta_hash(checksum, file) | ||
| file_path = jit_env.FLASHINFER_CUBIN_DIR / "trtllmGen_bmm_export" / file | ||
| get_file(uri_path, file_hash, file_path) | ||
| # Create directory flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export pointing to trtllmGen_bmm_export | ||
|
|
||
| symlink_parent = str( | ||
| jit_env.FLASHINFER_CUBIN_DIR / "flashinfer/trtllm/batched_gemm" | ||
| ) | ||
| make_symlink( | ||
| "../../../trtllmGen_bmm_export", symlink_parent, "trtllmGen_bmm_export" | ||
| ) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Add error handling for header file downloads.
Similar to the issue in flashinfer/jit/gemm/core.py, the get_file function returns empty bytes on failure, but there's no verification that downloads succeeded. This could lead to compilation failures with unclear error messages.
Consider adding validation:
for file in header_files:
uri_path = f"{header_path}/{file}"
file_hash = get_meta_hash(checksum, file)
file_path = jit_env.FLASHINFER_CUBIN_DIR / "trtllmGen_bmm_export" / file
result = get_file(uri_path, file_hash, str(file_path))
if not result:
raise RuntimeError(f"Failed to download header file: {file}")🤖 Prompt for AI Agents
In flashinfer/jit/fused_moe.py around lines 233 to 265, the code calls get_file
for each header but does not validate the returned bytes, so failed downloads
(empty bytes) will go unnoticed; update the loop to capture get_file's return
value, verify it's non-empty, and raise a RuntimeError (including the filename
and preferably uri_path and file_hash) if the download failed; ensure get_file
is passed the correct file_path type (str if required by get_file) and stop
processing further files on the first failure.
|
|
||
| header_files = [ | ||
| "GemmInterface.h", | ||
| "GemmOptions.h", | ||
| "Enums.h", | ||
| "KernelTraits.h", | ||
| "KernelParams.h", | ||
| "KernelParamsDecl.h", | ||
| "TmaDescriptor.h", | ||
| "trtllm/gen/CommonUtils.h", | ||
| "trtllm/gen/CudaKernelLauncher.h", | ||
| "trtllm/gen/DtypeDecl.h", | ||
| "trtllm/gen/MmaDecl.h", | ||
| "trtllm/gen/SfLayoutDecl.h", | ||
| "trtllm/gen/CudaArchDecl.h", | ||
| ] | ||
|
|
||
| header_path = f"{include_path}/trtllmGen_gemm_export" | ||
| for file in header_files: | ||
| uri_path = f"{header_path}/{file}" | ||
| file_hash = get_meta_hash(checksum, file) | ||
| file_path = jit_env.FLASHINFER_CUBIN_DIR / "trtllmGen_gemm_export" / file | ||
| get_file(uri_path, file_hash, file_path) | ||
| # Create directory flashinfer/trtllm/gemm/trtllmGen_gemm_export pointing to trtllmGen_gemm_export | ||
| symlink_parent = str(jit_env.FLASHINFER_CUBIN_DIR / "flashinfer/trtllm/gemm") | ||
| make_symlink( | ||
| "../../../trtllmGen_gemm_export", symlink_parent, "trtllmGen_gemm_export" | ||
| ) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Add error handling for header file downloads.
The get_file function returns empty bytes on failure (from load_cubin at line 224 of cubin_loader.py), but there's no verification that the download succeeded. This could lead to compilation failures with cryptic error messages.
Consider adding validation after the download loop:
header_path = f"{include_path}/trtllmGen_gemm_export"
downloaded_files = []
for file in header_files:
uri_path = f"{header_path}/{file}"
file_hash = get_meta_hash(checksum, file)
file_path = jit_env.FLASHINFER_CUBIN_DIR / "trtllmGen_gemm_export" / file
result = get_file(uri_path, file_hash, str(file_path))
if not result:
raise RuntimeError(f"Failed to download header file: {file}")
downloaded_files.append(file)🤖 Prompt for AI Agents
In flashinfer/jit/gemm/core.py around lines 382 to 409, the loop that calls
get_file for each header does not verify the download succeeded (get_file can
return empty bytes), so add immediate validation after each get_file call:
capture the return value, if it is falsy/empty raise a RuntimeError naming the
missing header (e.g. f"Failed to download header file: {file}"), and optionally
collect successful filenames into a list; ensure you pass the correct path/type
to get_file as required by its signature and fail fast to avoid later cryptic
compilation errors.
|
@aleozlx can you help reviewing this PR? |
📌 Description
[todo] will update artifactory hash once pipeline completes
🔍 Related Issues
🚀 Pull Request Checklist
Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete.
✅ Pre-commit Checks
pre-commitby runningpip install pre-commit(or used your preferred method).pre-commit install.pre-commit run --all-filesand fixed any reported issues.🧪 Tests
unittest, etc.).Reviewer Notes
Summary by CodeRabbit
Bug Fixes
Refactor
✏️ Tip: You can customize this high-level summary in your review settings.