Skip to content

Conversation

@jimmyzho
Copy link
Contributor

@jimmyzho jimmyzho commented Dec 17, 2025

📌 Description

[todo] will update artifactory hash once pipeline completes

🔍 Related Issues

🚀 Pull Request Checklist

Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete.

✅ Pre-commit Checks

  • I have installed pre-commit by running pip install pre-commit (or used your preferred method).
  • I have installed the hooks with pre-commit install.
  • I have run the hooks manually with pre-commit run --all-files and fixed any reported issues.

If you are unsure about how to set up pre-commit, see the pre-commit documentation.

🧪 Tests

  • Tests have been added or updated as needed.
  • All tests are passing (unittest, etc.).

Reviewer Notes

Summary by CodeRabbit

  • Bug Fixes

    • Improved kernel dimension validation with added valid dimension tracking for batched GEMM and GEMM kernels.
  • Refactor

    • Enhanced artifact management system with improved header file discovery, downloading, and caching.
    • Reorganized internal build system components for better modularity and maintainability.

✏️ Tip: You can customize this high-level summary in your review settings.

@coderabbitai
Copy link
Contributor

coderabbitai bot commented Dec 17, 2025

Note

Other AI code review bot(s) detected

CodeRabbit has detected other AI code review bot(s) in this pull request and will avoid duplicating their findings in the review comments. This may lead to a less comprehensive review.

Walkthrough

This PR updates TensorRT-LLM GEMM integration by introducing valid dimension fields to problem dimensions, migrating header files from repository inclusion to dynamic download, updating artifact paths and checksums, and extending the CUBIN loader with header file management capabilities.

Changes

Cohort / File(s) Summary
CUDA Runner Dimension Fields
csrc/trtllm_batched_gemm_runner.cu, csrc/trtllm_gemm_runner.cu, csrc/trtllm_low_latency_gemm_runner.cu
Added mValidM, mValidN, mValidK field assignments to ProblemDimensions in getWorkspaceSizeInBytes, run, and validation paths, mirroring primary M/N/K dimension propagation.
Artifact Management
flashinfer/artifacts.py
Updated TRTLLM_GEN_BMM and TRTLLM_GEN_GEMM artifact paths and checksums; added get_available_header_files() function with retry logic for recursive .h file collection; enhanced get_subdir_file_list() to yield header files; updated download_artifacts() type annotations.
CUBIN Loader Extensions
flashinfer/jit/cubin_loader.py
Expanded get_meta_hash() signature with optional target_file parameter; added make_symlink() for creating symlinks with parent directory creation; added get_file() for cached file loading with SHA256 verification and download fallback.
JIT Module Header Integration
flashinfer/jit/fused_moe.py, flashinfer/jit/gemm/core.py
Extended imports to include make_symlink and get_file; added logic to download trtllmGen_bmm_export and trtllmGen_gemm_export header files; created symlinks to export directories; expanded include paths to reference FLASHINFER_CUBIN_DIR.
Deleted Header Files
include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/*, include/flashinfer/trtllm/gemm/trtllmGen_gemm_export/*
Removed BatchedGemmEnums.h, BatchedGemmInterface.h, BatchedGemmOptions.h, Enums.h, GemmGatedActOptions.h, GemmOptions.h, GemmInterface.h, KernelParams.h, KernelParamsDecl.h, KernelTraits.h, TmaDescriptor.h (public interface definitions, configuration structs, validation helpers, TMA descriptor builders, and memory trait utilities).

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~60 minutes

Areas requiring extra attention:

  • Heterogeneous changes across multiple domains: Updates span CUDA kernel runners, Python artifact management, JIT module generation, and header file deletions, each requiring separate reasoning.
  • Header file migration strategy: Verify that all deleted include-tree headers are properly downloaded and linked in the new artifact/JIT flow, especially the interdependencies between batched_gemm and gemm export files.
  • Artifact path and checksum updates: Confirm TRTLLM_GEN_BMM and TRTLLM_GEN_GEMM paths/checksums are correct and match the new download sources.
  • Valid dimension field initialization: Ensure mValidM/mValidN/mValidK propagation is consistent across all code paths (getWorkspaceSizeInBytes, run, validation) and doesn't introduce off-by-one or uninitialized-value issues.
  • Symlink and include path resolution: Validate that symlink creation in JIT modules correctly points to exported directories and that include paths for compilation resolve properly (both FLASHINFER_CUBIN_DIR and include_path variants).

Possibly related PRs

Suggested reviewers

  • aleozlx
  • djmmoss
  • yongwww
  • nvmbreughe
  • bkryu
  • cyx-6

Poem

🐰 Headers now travel from clouds up high,
No longer stored in repos nearby,
Valid dimensions keep bounds in place,
Symlinks point to the right export space,
A rabbit's refactor, efficient and bright! ✨

Pre-merge checks and finishing touches

❌ Failed checks (2 warnings)
Check name Status Explanation Resolution
Description check ⚠️ Warning The PR description is largely incomplete and vague. It contains only '[todo] will update artifactory hash once pipeline completes' as meaningful content, with unchecked checklists and no detailed explanation of changes. Add a comprehensive description explaining what headers are being pulled, why they're needed, and clarify the TMA descriptor shape initialization changes.
Docstring Coverage ⚠️ Warning Docstring coverage is 35.00% which is insufficient. The required threshold is 80.00%. You can run @coderabbitai generate docstrings to improve docstring coverage.
✅ Passed checks (1 passed)
Check name Status Explanation
Title check ✅ Passed The title accurately describes the main changes: pulling trtllm-gen batch-gemm/gemm headers from artifactory and updating TMA descriptor shape initialization, matching the file modifications shown in the raw summary.
✨ Finishing touches
  • 📝 Generate docstrings
🧪 Generate unit tests (beta)
  • Create PR with unit tests
  • Post copyable unit tests in a comment

Thanks for using CodeRabbit! It's free for OSS, and your support helps us grow. If you like it, consider giving us a shout-out.

❤️ Share

Comment @coderabbitai help to get the list of available commands and usage tips.

@jimmyzho jimmyzho changed the title chore: pull trtllm-gen headers from artifactory; update tma descriptor shape init chore: pull trtllm-gen batch-gemm/gemm headers from artifactory; update tma descriptor shape init Dec 17, 2025
@gemini-code-assist
Copy link
Contributor

Summary of Changes

Hello @jimmyzho, I'm Gemini Code Assist1! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request streamlines the management of TRTLLM-Gen headers by transitioning from bundled files to fetching them dynamically from Artifactory. It also includes crucial updates to how TMA problem dimensions are initialized in the CUDA GEMM runners, ensuring compatibility and correctness with the updated header fetching mechanism. The changes enhance the flexibility and maintainability of the artifact management system.

Highlights

  • Artifactory Integration for TRTLLM-Gen Headers: The build system now pulls TRTLLM-Gen header files directly from Artifactory, eliminating the need to bundle them within the repository. This is managed by new Python functions that fetch and cache these headers.
  • Updated TMA Descriptor Initialization: The trtllm_batched_gemm_runner.cu, trtllm_gemm_runner.cu, and trtllm_low_latency_gemm_runner.cu files have been updated to explicitly initialize mValidM, mValidN, and mValidK fields within the gemmData.mProblemDimensions structure, ensuring correct TMA descriptor setup.
  • Refined Artifact Management Logic: New utility functions (make_symlink, get_file) have been added to cubin_loader.py to improve local caching and symbolic linking of downloaded artifacts. The get_meta_hash function was also made more specific to flashinferMetaInfo.h.
  • Updated TRTLLM-Gen Artifact Hashes: The Artifactory hashes for TRTLLM_GEN_BMM and TRTLLM_GEN_GEMM artifacts have been updated in flashinfer/artifacts.py to reflect the latest versions.
Using Gemini Code Assist

The full guide for Gemini Code Assist can be found on our documentation page, here are some quick tips.

Invoking Gemini

You can request assistance from Gemini at any point by creating a comment using either /gemini <command> or @gemini-code-assist <command>. Below is a summary of the supported commands on the current page.

Feature Command Description
Code Review /gemini review Performs a code review for the current pull request in its current state.
Pull Request Summary /gemini summary Provides a summary of the current pull request in its current state.
Comment @gemini-code-assist Responds in comments when explicitly tagged, both in pull request comments and review comments.
Help /gemini help Displays a list of available commands.

Customization

To customize Gemini Code Assist for GitHub experience, repository maintainers can create a configuration file and/or provide a custom code review style guide (such as PEP-8 for Python) by creating and adding files to a .gemini/ folder in the base of the repository. Detailed instructions can be found here.

Limitations & Feedback

Gemini Code Assist may make mistakes. Please leave feedback on any instances where its feedback is incorrect or counter productive. You can react with 👍 and 👎 on @gemini-code-assist comments. If you're interested in giving your feedback about your experience with Gemini Code Assist for Github and other Google products, sign up here.

You can also get AI-powered code generation, chat, as well as code reviews directly in the IDE at no cost with the Gemini Code Assist IDE Extension.

Footnotes

  1. Review the Privacy Notices, Generative AI Prohibited Use Policy, Terms of Service, and learn how to configure Gemini Code Assist in GitHub here. Gemini can make mistakes, so double check it and use code with caution.

Copy link
Contributor

@gemini-code-assist gemini-code-assist bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code Review

This pull request refactors the build process to download trt-llm headers from an artifactory during JIT compilation, which is a commendable improvement for dependency management. The changes also update TMA descriptor initializations, likely to align with a newer API.

My main feedback revolves around improving maintainability. There's significant code duplication for downloading headers and creating symlinks across flashinfer/jit/fused_moe.py and flashinfer/jit/gemm/core.py, which should be refactored into a shared helper function. Additionally, the lists of header files are hardcoded, creating a maintenance burden; a more dynamic approach would be better. There are also opportunities to reduce duplication between get_available_header_files and get_available_cubin_files and to clean up commented-out code.

Comment on lines +233 to +265
header_files = [
"BatchedGemmEnums.h",
"BatchedGemmInterface.h",
"BatchedGemmOptions.h",
"Enums.h",
"GemmGatedActOptions.h",
"GemmOptions.h",
"KernelParams.h",
"KernelParamsDecl.h",
"KernelTraits.h",
"TmaDescriptor.h",
"trtllm/gen/CommonUtils.h",
"trtllm/gen/CudaArchDecl.h",
"trtllm/gen/CudaKernelLauncher.h",
"trtllm/gen/DtypeDecl.h",
"trtllm/gen/MmaDecl.h",
"trtllm/gen/SfLayoutDecl.h",
]

header_path = f"{include_path}/trtllmGen_bmm_export"
for file in header_files:
uri_path = f"{header_path}/{file}"
file_hash = get_meta_hash(checksum, file)
file_path = jit_env.FLASHINFER_CUBIN_DIR / "trtllmGen_bmm_export" / file
get_file(uri_path, file_hash, file_path)
# Create directory flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export pointing to trtllmGen_bmm_export

symlink_parent = str(
jit_env.FLASHINFER_CUBIN_DIR / "flashinfer/trtllm/batched_gemm"
)
make_symlink(
"../../../trtllmGen_bmm_export", symlink_parent, "trtllmGen_bmm_export"
)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

This block of code for downloading headers and creating symlinks is duplicated in flashinfer/jit/gemm/core.py in gen_trtllm_gen_gemm_module and gen_trtllm_low_latency_gemm_module. This introduces a significant maintenance burden.

To improve the code quality, please refactor this logic into a single, reusable helper function. This function could be placed in flashinfer/jit/cubin_loader.py and accept parameters such as the list of headers, paths, and checksums.

Comment on lines +82 to +127
def get_available_header_files(
source: str, retries: int = 3, delay: int = 5, timeout: int = 10
) -> tuple[str, ...]:
"""
Recursively navigates through child directories (e.g., include/) and finds
all *.h header files, returning them as a tuple of relative paths.
"""
result: list[str] = []

def fetch_directory(url: str, prefix: str = "") -> None:
for attempt in range(1, retries + 1):
try:
response = requests.get(url, timeout=timeout)
response.raise_for_status()

# Find all .h header files in this directory
header_hrefs = re.findall(r'<a href="([^"]+\.h)">', response.text)
for h in header_hrefs:
result.append(prefix + h if prefix else h)

# Find all subdirectories (links ending with /)
dir_hrefs = re.findall(r'<a href="([^"]+/)">', response.text)
for d in dir_hrefs:
# Skip parent directory links
if d == "../" or d.startswith(".."):
continue
subdir_url = safe_urljoin(url, d)
subdir_prefix = prefix + d if prefix else d
fetch_directory(subdir_url, subdir_prefix)

return # Success, exit retry loop

except requests.exceptions.RequestException as e:
logger.warning(
f"Fetching available header files {url}: attempt {attempt} failed: {e}"
)

if attempt < retries:
logger.info(f"Retrying in {delay} seconds...")
time.sleep(delay)

logger.error(f"Max retries reached for {url}. Fetch failed.")

fetch_directory(source)
logger.info(f"result: {result}")
return tuple(result)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

This new function get_available_header_files has very similar logic to the existing get_available_cubin_files function. To improve maintainability and reduce code duplication, consider refactoring them into a single, more generic function. This new function could accept the file extension (e.g., .h or .cubin) as a parameter.

Comment on lines 139 to +150
TRTLLM_GEN_BMM: str = (
"ccae3ed120a12a2c6922b458086b460413dbf731/batched_gemm-0d275a2-9936841"
"02546c924085adc5df7dc0a211cacc7ec3d3e01c/batched_gemm-0d275a2-9936841"
)
# TRTLLM_GEN_BMM: str = (
# "ccae3ed120a12a2c6922b458086b460413dbf731/batched_gemm-0d275a2-9936841"
# )
TRTLLM_GEN_GEMM: str = (
"1fddc48b7b48af33914d040051b3e2ee9ba4701e/gemm-145d1b1-9b113e3"
"02546c924085adc5df7dc0a211cacc7ec3d3e01c/gemm-0d275a2-30f1102"
)
# TRTLLM_GEN_GEMM: str = (
# "1fddc48b7b48af33914d040051b3e2ee9ba4701e/gemm-145d1b1-9b113e3"
# )
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

Please remove the commented-out lines containing old hashes to keep the code clean.

    TRTLLM_GEN_BMM: str = (
        "02546c924085adc5df7dc0a211cacc7ec3d3e01c/batched_gemm-0d275a2-9936841"
    )
    TRTLLM_GEN_GEMM: str = (
        "02546c924085adc5df7dc0a211cacc7ec3d3e01c/gemm-0d275a2-30f1102"
    )

Comment on lines 166 to +178
TRTLLM_GEN_BMM: str = (
"b7689d3046493806251351c2744c6d7faed6af25518647a955b35c4919b014fc"
"680167f34b532d493d3ed71da0a1640054cf1cb0a80cfca20e7d797dbd093a90"
)
# TRTLLM_GEN_BMM: str = (
# "b7689d3046493806251351c2744c6d7faed6af25518647a955b35c4919b014fc"
# )
DEEPGEMM: str = "1a2a166839042dbd2a57f48051c82cd1ad032815927c753db269a4ed10d0ffbf"
TRTLLM_GEN_GEMM: str = (
"15cb8c85dfb5eddd4f121d64cb5a718321fb55b85aa19df10ddc1329d4a726b9"
"014473f273a4dd248b5608e813f0fe468f05c686093577abd23f7a64afd77a60"
)
# TRTLLM_GEN_GEMM: str = (
# "15cb8c85dfb5eddd4f121d64cb5a718321fb55b85aa19df10ddc1329d4a726b9"
# )
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

Please remove the commented-out lines containing old hashes to keep the code clean.

    TRTLLM_GEN_BMM: str = (
        "680167f34b532d493d3ed71da0a1640054cf1cb0a80cfca20e7d797dbd093a90"
    )
    DEEPGEMM: str = "1a2a166839042dbd2a57f48051c82cd1ad032815927c753db269a4ed10d0ffbf"
    TRTLLM_GEN_GEMM: str = (
        "014473f273a4dd248b5608e813f0fe468f05c686093577abd23f7a64afd77a60"
    )

Otherwise, download the file from {uri_path} and write to {file_path}.
"""

file = load_cubin(file_path, sha256)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The function load_cubin is now being used to load generic files, not just cubins. This makes the function name misleading. While load_cubin is not modified in this PR, it would be good to rename it to something more generic like _load_file_with_checksum in a follow-up PR to improve code clarity.

Comment on lines +233 to +250
header_files = [
"BatchedGemmEnums.h",
"BatchedGemmInterface.h",
"BatchedGemmOptions.h",
"Enums.h",
"GemmGatedActOptions.h",
"GemmOptions.h",
"KernelParams.h",
"KernelParamsDecl.h",
"KernelTraits.h",
"TmaDescriptor.h",
"trtllm/gen/CommonUtils.h",
"trtllm/gen/CudaArchDecl.h",
"trtllm/gen/CudaKernelLauncher.h",
"trtllm/gen/DtypeDecl.h",
"trtllm/gen/MmaDecl.h",
"trtllm/gen/SfLayoutDecl.h",
]
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The list of header_files is hardcoded here and in other similar functions. This makes it difficult to maintain when the upstream trt-llm dependency adds or removes headers.

Consider making this more dynamic. You could, for example, use the new get_available_header_files function to fetch the list of headers from the artifactory directly, rather than hardcoding them.

Copy link
Contributor

@coderabbitai coderabbitai bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actionable comments posted: 6

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (1)
flashinfer/jit/gemm/core.py (1)

565-607: Consider extracting shared header download logic and verify include path consistency.

The header file download logic (lines 566-592) is duplicated between gen_trtllm_gen_gemm_module and gen_trtllm_low_latency_gemm_module. Additionally, there's an inconsistency in extra_include_paths: gen_trtllm_gen_gemm_module includes both FLASHINFER_CUBIN_DIR and FLASHINFER_CUBIN_DIR / include_path (lines 424-427), while gen_trtllm_low_latency_gemm_module only includes FLASHINFER_CUBIN_DIR / include_path (line 607).

  1. Extract the header download logic into a shared helper function
  2. Verify whether the include path difference is intentional or an oversight - if the symlink at flashinfer/trtllm/gemm/trtllmGen_gemm_export is needed, then FLASHINFER_CUBIN_DIR should be included here as well

Example refactor:

def download_trtllm_gemm_headers(artifact_path: str, checksum_hash: str):
    """Download and cache TRTLLM GEMM export headers."""
    include_path = f"{artifact_path}/include"
    checksum_path = f"{artifact_path}/checksums.txt"
    checksum = get_cubin(checksum_path, checksum_hash)
    assert checksum, f"Failed to get checksums.txt from {checksum_path}"
    
    meta_hash = get_meta_hash(checksum)
    header_name = "flashinferMetaInfo"
    metainfo = get_cubin(f"{include_path}/{header_name}.h", meta_hash)
    assert metainfo, f"{header_name}.h not found"
    
    header_files = [
        "GemmInterface.h",
        "GemmOptions.h",
        # ... rest of the list
    ]
    
    header_path = f"{include_path}/trtllmGen_gemm_export"
    for file in header_files:
        uri_path = f"{header_path}/{file}"
        file_hash = get_meta_hash(checksum, file)
        file_path = jit_env.FLASHINFER_CUBIN_DIR / "trtllmGen_gemm_export" / file
        get_file(uri_path, file_hash, str(file_path))
    
    symlink_parent = str(jit_env.FLASHINFER_CUBIN_DIR / "flashinfer/trtllm/gemm")
    make_symlink("../../../trtllmGen_gemm_export", symlink_parent, "trtllmGen_gemm_export")
    
    return include_path
🧹 Nitpick comments (2)
flashinfer/jit/cubin_loader.py (1)

139-151: Clarify the inline comment.

The comment mentions "case-insensitive for the 'I' in Infer" but the implementation performs case-insensitive matching on the entire filename using .lower().endswith(target_file.lower()). This may confuse future maintainers.

Consider updating the comment to:

-        # Match specifically flashinferMetaInfo.h (case-insensitive for the 'I' in Infer)
+        # Match filename ending with target_file (case-insensitive)
flashinfer/artifacts.py (1)

82-127: Consider reducing log verbosity.

Line 126 logs the entire list of header files at INFO level, which could be verbose for directories with many headers. Consider logging the count instead or moving this to DEBUG level.

-    logger.info(f"result: {result}")
+    logger.info(f"Found {len(result)} header files")
+    logger.debug(f"Header files: {result}")
📜 Review details

Configuration used: defaults

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between db2aacb and 4b87c01.

⛔ Files ignored due to path filters (11)
  • include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/trtllm/gen/CommonUtils.h is excluded by !**/gen/**
  • include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/trtllm/gen/CudaArchDecl.h is excluded by !**/gen/**
  • include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/trtllm/gen/CudaKernelLauncher.h is excluded by !**/gen/**
  • include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/trtllm/gen/DtypeDecl.h is excluded by !**/gen/**
  • include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/trtllm/gen/MmaDecl.h is excluded by !**/gen/**
  • include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/trtllm/gen/SfLayoutDecl.h is excluded by !**/gen/**
  • include/flashinfer/trtllm/gemm/trtllmGen_gemm_export/trtllm/gen/CommonUtils.h is excluded by !**/gen/**
  • include/flashinfer/trtllm/gemm/trtllmGen_gemm_export/trtllm/gen/CudaKernelLauncher.h is excluded by !**/gen/**
  • include/flashinfer/trtllm/gemm/trtllmGen_gemm_export/trtllm/gen/DtypeDecl.h is excluded by !**/gen/**
  • include/flashinfer/trtllm/gemm/trtllmGen_gemm_export/trtllm/gen/MmaDecl.h is excluded by !**/gen/**
  • include/flashinfer/trtllm/gemm/trtllmGen_gemm_export/trtllm/gen/SfLayoutDecl.h is excluded by !**/gen/**
📒 Files selected for processing (24)
  • csrc/trtllm_batched_gemm_runner.cu (2 hunks)
  • csrc/trtllm_gemm_runner.cu (3 hunks)
  • csrc/trtllm_low_latency_gemm_runner.cu (1 hunks)
  • flashinfer/artifacts.py (5 hunks)
  • flashinfer/jit/cubin_loader.py (2 hunks)
  • flashinfer/jit/fused_moe.py (3 hunks)
  • flashinfer/jit/gemm/core.py (4 hunks)
  • include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/BatchedGemmEnums.h (0 hunks)
  • include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/BatchedGemmInterface.h (0 hunks)
  • include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/BatchedGemmOptions.h (0 hunks)
  • include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/Enums.h (0 hunks)
  • include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/GemmGatedActOptions.h (0 hunks)
  • include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/GemmOptions.h (0 hunks)
  • include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/KernelParams.h (0 hunks)
  • include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/KernelParamsDecl.h (0 hunks)
  • include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/KernelTraits.h (0 hunks)
  • include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/TmaDescriptor.h (0 hunks)
  • include/flashinfer/trtllm/gemm/trtllmGen_gemm_export/Enums.h (0 hunks)
  • include/flashinfer/trtllm/gemm/trtllmGen_gemm_export/GemmInterface.h (0 hunks)
  • include/flashinfer/trtllm/gemm/trtllmGen_gemm_export/GemmOptions.h (0 hunks)
  • include/flashinfer/trtllm/gemm/trtllmGen_gemm_export/KernelParams.h (0 hunks)
  • include/flashinfer/trtllm/gemm/trtllmGen_gemm_export/KernelParamsDecl.h (0 hunks)
  • include/flashinfer/trtllm/gemm/trtllmGen_gemm_export/KernelTraits.h (0 hunks)
  • include/flashinfer/trtllm/gemm/trtllmGen_gemm_export/TmaDescriptor.h (0 hunks)
💤 Files with no reviewable changes (17)
  • include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/TmaDescriptor.h
  • include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/BatchedGemmEnums.h
  • include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/GemmOptions.h
  • include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/Enums.h
  • include/flashinfer/trtllm/gemm/trtllmGen_gemm_export/KernelParamsDecl.h
  • include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/BatchedGemmOptions.h
  • include/flashinfer/trtllm/gemm/trtllmGen_gemm_export/KernelTraits.h
  • include/flashinfer/trtllm/gemm/trtllmGen_gemm_export/GemmOptions.h
  • include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/GemmGatedActOptions.h
  • include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/KernelParamsDecl.h
  • include/flashinfer/trtllm/gemm/trtllmGen_gemm_export/Enums.h
  • include/flashinfer/trtllm/gemm/trtllmGen_gemm_export/KernelParams.h
  • include/flashinfer/trtllm/gemm/trtllmGen_gemm_export/TmaDescriptor.h
  • include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/KernelParams.h
  • include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/KernelTraits.h
  • include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/BatchedGemmInterface.h
  • include/flashinfer/trtllm/gemm/trtllmGen_gemm_export/GemmInterface.h
🧰 Additional context used
🧠 Learnings (1)
📚 Learning: 2025-11-12T03:35:17.583Z
Learnt from: raayandhar
Repo: flashinfer-ai/flashinfer PR: 2070
File: include/flashinfer/gemm/bf16_gemm_cutlass_template.h:145-160
Timestamp: 2025-11-12T03:35:17.583Z
Learning: In flashinfer GEMM implementations (e.g., include/flashinfer/gemm/bf16_gemm_cutlass_template.h, fp8_gemm_cutlass_template.h), it is acceptable to catch and silently ignore std::runtime_error exceptions in getWorkspaceSizeImpl when probing multiple GEMM configurations, as some configurations may legitimately fail due to SMEM constraints. This pattern should include a comment like "// Swallow errors when SMEM exceeds maximum allowed" to document the rationale.

Applied to files:

  • csrc/trtllm_gemm_runner.cu
  • csrc/trtllm_batched_gemm_runner.cu
🧬 Code graph analysis (3)
flashinfer/jit/fused_moe.py (1)
flashinfer/jit/cubin_loader.py (4)
  • get_cubin (227-246)
  • get_meta_hash (139-151)
  • make_symlink (195-203)
  • get_file (206-224)
flashinfer/jit/gemm/core.py (1)
flashinfer/jit/cubin_loader.py (3)
  • get_meta_hash (139-151)
  • make_symlink (195-203)
  • get_file (206-224)
flashinfer/artifacts.py (1)
flashinfer/jit/cubin_loader.py (1)
  • safe_urljoin (38-42)
🪛 Ruff (0.14.8)
flashinfer/jit/cubin_loader.py

209-209: PEP 484 prohibits implicit Optional

Convert to T | None

(RUF013)

flashinfer/artifacts.py

112-112: Consider moving this statement to an else block

(TRY300)

🔇 Additional comments (15)
csrc/trtllm_low_latency_gemm_runner.cu (1)

52-54: LGTM!

The initialization of the "valid" dimension fields mirrors the primary dimension fields, which aligns with the PR's objective to update TMA descriptor shape initialization.

csrc/trtllm_gemm_runner.cu (3)

122-124: LGTM!

The valid dimension fields are correctly initialized after the primary dimensions, maintaining consistency with the transpose logic already applied to mM and mN.


145-147: LGTM!

The initialization pattern is consistent with the getWorkspaceSizeInBytes method.


196-198: LGTM!

The initialization pattern is uniformly applied across all three methods (getWorkspaceSizeInBytes, run, and getValidTactics).

csrc/trtllm_batched_gemm_runner.cu (3)

149-151: LGTM!

The valid dimension fields are correctly initialized in the workspace sizing method.


343-345: LGTM!

The valid dimension fields are correctly initialized in the config validation method.


448-450: LGTM!

The valid dimension fields are correctly initialized in the config validation method.

flashinfer/artifacts.py (4)

139-150: LGTM!

The artifact paths have been updated with new hashes, and the old values are preserved as comments for reference or rollback. This aligns with the PR objective to pull trtllm-gen headers from Artifactory.


166-178: LGTM!

The checksums have been updated to match the new artifact paths. Good practice to preserve the old values as comments.


245-247: LGTM!

The header files are now included in the artifact listing, mirroring the pattern used for cubin files. This enables header files to be downloaded and verified alongside cubin binaries.


256-256: LGTM!

Adding explicit type annotation improves code clarity.

flashinfer/jit/gemm/core.py (2)

33-38: LGTM!

The new imports support the header file management functionality added in this PR.


424-427: LGTM!

The include paths are extended to support both the root cubin directory (for the symlink at flashinfer/trtllm/gemm/trtllmGen_gemm_export) and the artifact-specific include directory. This enables the compilation to locate the downloaded headers.

flashinfer/jit/fused_moe.py (2)

29-29: LGTM!

The new imports support the header file management functionality added in this PR.


298-303: LGTM!

The include paths are extended to support both the root cubin directory and the artifact-specific include directory, consistent with the pattern in gen_trtllm_gen_gemm_module.

Comment on lines +212 to +214
gemmData.mProblemDimensions.mValidM = gemmData.mProblemDimensions.mM;
gemmData.mProblemDimensions.mValidN = gemmData.mProblemDimensions.mN;
gemmData.mProblemDimensions.mValidK = gemmData.mProblemDimensions.mK;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟡 Minor

Remove duplicate initialization.

The valid dimension fields (mValidM, mValidN, mValidK) are initialized twice in this method with identical values. Lines 212-214 correctly initialize these fields immediately after mM, mN, mK are set. The second initialization at lines 251-253 is redundant since the dimension values haven't changed between these two points.

Apply this diff to remove the duplicate initialization:

  gemmData.mOutputBuffers.mPtrC = c;
  gemmData.mOutputBuffers.mPtrSfC = outSfC;

  int32_t multiProcessorCount;
  cudaDeviceGetAttribute(&multiProcessorCount, cudaDevAttrMultiProcessorCount, device);

-  gemmData.mProblemDimensions.mValidM = gemmData.mProblemDimensions.mM;
-  gemmData.mProblemDimensions.mValidN = gemmData.mProblemDimensions.mN;
-  gemmData.mProblemDimensions.mValidK = gemmData.mProblemDimensions.mK;
-
  // FIXME once we start using all-reduce in the epilogue of the bmm this can be moved elsewhere
  bmm.runInitBeforeWorldSync(config, gemmData, static_cast<void*>(stream));

Also applies to: 251-253

🤖 Prompt for AI Agents
In csrc/trtllm_batched_gemm_runner.cu around lines 212-214 and again at 251-253,
the mProblemDimensions.mValidM/mValidN/mValidK fields are initialized twice with
the same values; remove the redundant initialization at lines 251-253 so the
valid dimension fields are only set once immediately after mM/mN/mK are
assigned.

Comment on lines +195 to +203
def make_symlink(src: str, symlink_parent: str, symlink_name: str) -> None:
"""
Create a symlink from {src} to {symlink_parent}/{symlink_name}.
If the symlink does not exist, create the parent directory and create the symlink.
"""
symlink_path = safe_urljoin(symlink_parent, symlink_name)
if not os.path.exists(symlink_path):
os.makedirs(symlink_parent, exist_ok=True)
os.symlink(src, symlink_path)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟡 Minor

Handle broken symlinks correctly.

os.path.exists() returns False for broken symlinks, which would cause the function to attempt creating a symlink that already exists (resulting in an FileExistsError). Use os.path.lexists() or check with os.path.islink() to detect existing symlinks regardless of whether their target exists.

Apply this diff:

-    if not os.path.exists(symlink_path):
+    if not os.path.lexists(symlink_path):
         os.makedirs(symlink_parent, exist_ok=True)
         os.symlink(src, symlink_path)
🤖 Prompt for AI Agents
In flashinfer/jit/cubin_loader.py around lines 195 to 203, the current check
uses os.path.exists(symlink_path) which returns False for broken symlinks and
can cause attempts to recreate an existing symlink; change the existence check
to use os.path.lexists(symlink_path) (or os.path.islink(symlink_path)) so
existing symlinks are detected regardless of target validity, and only call
os.makedirs and os.symlink when the lexists/islink check indicates the path
truly does not exist to avoid FileExistsError.

⚠️ Potential issue | 🟠 Major

Use filesystem path operations instead of URL joining.

safe_urljoin is designed for URLs and may not handle filesystem paths correctly across platforms. For filesystem operations, prefer os.path.join or pathlib.Path.

Apply this diff:

-    symlink_path = safe_urljoin(symlink_parent, symlink_name)
+    symlink_path = os.path.join(symlink_parent, symlink_name)
🤖 Prompt for AI Agents
In flashinfer/jit/cubin_loader.py around lines 195 to 203, the code uses
safe_urljoin to build a filesystem symlink path which is for URLs and can
mis-handle platform paths; replace safe_urljoin with a filesystem-aware join
(prefer pathlib.Path or os.path.join) to construct symlink_path (e.g.,
Path(symlink_parent) / symlink_name), ensure the parent directory exists
(Path(symlink_parent).mkdir(parents=True, exist_ok=True) or os.makedirs), and
then create the symlink with os.symlink(src, str(symlink_path)); keep the
existing existence check (Path(symlink_path).exists() or os.path.exists) to
avoid overwriting.

Comment on lines +206 to +224
def get_file(
uri_path: str,
sha256: str,
file_path: str = None,
session=None,
) -> bytes:
"""
Load a file from local cache directory {file_path}, ensure that the sha256 signature matches.
Otherwise, download the file from {uri_path} and write to {file_path}.
"""

file = load_cubin(file_path, sha256)
if file:
return file
os.makedirs(os.path.dirname(file_path), exist_ok=True)
uri = safe_urljoin(FLASHINFER_CUBINS_REPOSITORY, uri_path)
logger.info(f"Fetching file from {uri}")
download_file(uri, file_path, session=session)
return load_cubin(file_path, sha256)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🔴 Critical

Fix type annotation and add validation for required parameter.

The file_path parameter is annotated as str = None but is used without null checks (lines 217, 220), which will cause an AttributeError if None is passed. Based on the function's usage in flashinfer/jit/gemm/core.py and flashinfer/jit/fused_moe.py, this parameter appears to be required.

Apply this diff:

 def get_file(
     uri_path: str,
     sha256: str,
-    file_path: str = None,
+    file_path: str,
     session=None,
 ) -> bytes:
🧰 Tools
🪛 Ruff (0.14.8)

209-209: PEP 484 prohibits implicit Optional

Convert to T | None

(RUF013)

🤖 Prompt for AI Agents
In flashinfer/jit/cubin_loader.py around lines 206 to 224, the file_path
parameter is annotated as optional (str = None) but used unguarded
(os.path.dirname, load_cubin), which will raise AttributeError if None; change
the signature to require file_path: str (remove the = None default) and add an
explicit validation at the start of the function that raises a ValueError with a
clear message if file_path is falsy (None or empty). Ensure callers still pass a
valid path (update call sites if necessary) so the function never receives None.

Comment on lines +233 to +265
header_files = [
"BatchedGemmEnums.h",
"BatchedGemmInterface.h",
"BatchedGemmOptions.h",
"Enums.h",
"GemmGatedActOptions.h",
"GemmOptions.h",
"KernelParams.h",
"KernelParamsDecl.h",
"KernelTraits.h",
"TmaDescriptor.h",
"trtllm/gen/CommonUtils.h",
"trtllm/gen/CudaArchDecl.h",
"trtllm/gen/CudaKernelLauncher.h",
"trtllm/gen/DtypeDecl.h",
"trtllm/gen/MmaDecl.h",
"trtllm/gen/SfLayoutDecl.h",
]

header_path = f"{include_path}/trtllmGen_bmm_export"
for file in header_files:
uri_path = f"{header_path}/{file}"
file_hash = get_meta_hash(checksum, file)
file_path = jit_env.FLASHINFER_CUBIN_DIR / "trtllmGen_bmm_export" / file
get_file(uri_path, file_hash, file_path)
# Create directory flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export pointing to trtllmGen_bmm_export

symlink_parent = str(
jit_env.FLASHINFER_CUBIN_DIR / "flashinfer/trtllm/batched_gemm"
)
make_symlink(
"../../../trtllmGen_bmm_export", symlink_parent, "trtllmGen_bmm_export"
)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟠 Major

Add error handling for header file downloads.

Similar to the issue in flashinfer/jit/gemm/core.py, the get_file function returns empty bytes on failure, but there's no verification that downloads succeeded. This could lead to compilation failures with unclear error messages.

Consider adding validation:

for file in header_files:
    uri_path = f"{header_path}/{file}"
    file_hash = get_meta_hash(checksum, file)
    file_path = jit_env.FLASHINFER_CUBIN_DIR / "trtllmGen_bmm_export" / file
    result = get_file(uri_path, file_hash, str(file_path))
    if not result:
        raise RuntimeError(f"Failed to download header file: {file}")
🤖 Prompt for AI Agents
In flashinfer/jit/fused_moe.py around lines 233 to 265, the code calls get_file
for each header but does not validate the returned bytes, so failed downloads
(empty bytes) will go unnoticed; update the loop to capture get_file's return
value, verify it's non-empty, and raise a RuntimeError (including the filename
and preferably uri_path and file_hash) if the download failed; ensure get_file
is passed the correct file_path type (str if required by get_file) and stop
processing further files on the first failure.

Comment on lines +382 to +409

header_files = [
"GemmInterface.h",
"GemmOptions.h",
"Enums.h",
"KernelTraits.h",
"KernelParams.h",
"KernelParamsDecl.h",
"TmaDescriptor.h",
"trtllm/gen/CommonUtils.h",
"trtllm/gen/CudaKernelLauncher.h",
"trtllm/gen/DtypeDecl.h",
"trtllm/gen/MmaDecl.h",
"trtllm/gen/SfLayoutDecl.h",
"trtllm/gen/CudaArchDecl.h",
]

header_path = f"{include_path}/trtllmGen_gemm_export"
for file in header_files:
uri_path = f"{header_path}/{file}"
file_hash = get_meta_hash(checksum, file)
file_path = jit_env.FLASHINFER_CUBIN_DIR / "trtllmGen_gemm_export" / file
get_file(uri_path, file_hash, file_path)
# Create directory flashinfer/trtllm/gemm/trtllmGen_gemm_export pointing to trtllmGen_gemm_export
symlink_parent = str(jit_env.FLASHINFER_CUBIN_DIR / "flashinfer/trtllm/gemm")
make_symlink(
"../../../trtllmGen_gemm_export", symlink_parent, "trtllmGen_gemm_export"
)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟠 Major

Add error handling for header file downloads.

The get_file function returns empty bytes on failure (from load_cubin at line 224 of cubin_loader.py), but there's no verification that the download succeeded. This could lead to compilation failures with cryptic error messages.

Consider adding validation after the download loop:

header_path = f"{include_path}/trtllmGen_gemm_export"
downloaded_files = []
for file in header_files:
    uri_path = f"{header_path}/{file}"
    file_hash = get_meta_hash(checksum, file)
    file_path = jit_env.FLASHINFER_CUBIN_DIR / "trtllmGen_gemm_export" / file
    result = get_file(uri_path, file_hash, str(file_path))
    if not result:
        raise RuntimeError(f"Failed to download header file: {file}")
    downloaded_files.append(file)
🤖 Prompt for AI Agents
In flashinfer/jit/gemm/core.py around lines 382 to 409, the loop that calls
get_file for each header does not verify the download succeeded (get_file can
return empty bytes), so add immediate validation after each get_file call:
capture the return value, if it is falsy/empty raise a RuntimeError naming the
missing header (e.g. f"Failed to download header file: {file}"), and optionally
collect successful filenames into a list; ensure you pass the correct path/type
to get_file as required by its signature and fail fast to avoid later cryptic
compilation errors.

@jimmyzho jimmyzho changed the title chore: pull trtllm-gen batch-gemm/gemm headers from artifactory; update tma descriptor shape init refactor: pull trtllm-gen batch-gemm/gemm headers from artifactory; update tma descriptor shape init Dec 17, 2025
@yzh119
Copy link
Collaborator

yzh119 commented Dec 18, 2025

@aleozlx can you help reviewing this PR?

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants