
Main new triton #664

Open
zoranjovanovic-ns wants to merge 12 commits into main from main-new-triton

Conversation

@zoranjovanovic-ns

Motivation

Integrate new Triton version 5a9cfa6ab3e59608d2f210e41e024994626a838c.
This PR was created for testing purposes only.

Technical Details

Test Plan

Test Result

Submission Checklist

sergachev and others added 12 commits March 10, 2026 05:33
Imported from GitHub PR openxla#38517

📝 Summary of Changes
Code deduplication

🎯 Justification
Cleanup

Copybara import of the project:

--
747b981 by Ilia Sergachev <isergachev@nvidia.com>:

[GPU] Deduplicate LaunchCudaKernel.

Merging this change closes openxla#38517

COPYBARA_INTEGRATE_REVIEW=openxla#38517 from openxla:dedup 747b981
PiperOrigin-RevId: 881372433
…fferOpPtr

Imported from GitHub PR openxla#38866

📝 Summary of Changes
Use addNestedPass for TritonAMDGPUOptimizeBufferOpPtr

🎯 Justification
Crash when running triton pipeline for ROCm
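
The fix named in the summary can be sketched as below. This is an illustrative fragment, not the PR's diff: the pass-factory name `createTritonAMDGPUOptimizeBufferOpPtrPass` and the nested op type `mlir::triton::FuncOp` are assumptions based on common MLIR/Triton naming conventions.

```
// Sketch only; factory name and nested op type are assumed, not verified.
mlir::PassManager pm(&context);  // anchored on the top-level ModuleOp

// Before (crashes in the ROCm Triton pipeline): the function-level pass is
// scheduled directly on the module.
// pm.addPass(mlir::createTritonAMDGPUOptimizeBufferOpPtrPass());

// After: schedule the pass on every nested triton function instead.
pm.addNestedPass<mlir::triton::FuncOp>(
    mlir::createTritonAMDGPUOptimizeBufferOpPtrPass());
```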

🚀 Kind of Contribution
🐛 Bug Fix


🧪 Unit Tests:
triton/support_test


Copybara import of the project:

--
b5368cf by Zoran Jovanovic <zjovanov@amd.com>:

[ROCm] Use addNestedPass for TritonAMDGPUOptimizeBufferOpPtr

Merging this change closes openxla#38866

COPYBARA_INTEGRATE_REVIEW=openxla#38866 from ROCm:rocm-fix-triton-pipeline b5368cf
PiperOrigin-RevId: 881374096
…sion

Imported from GitHub PR openxla#38836

Hot-Fix for breaking changes added in 5150bfd

📝 Summary of Changes
ExecutableAbiVersion::FromDeviceDescription was implemented only for CUDA, returning UNIMPLEMENTED for all other platforms. This broke every ROCm compilation path after 5150bfd made ABI version derivation mandatory in GpuCompiler::RunBackend.

Add a CreateForRocm handler that returns a minimal ABI version with
only platform_name set to "ROCm" and no platform-specific version
info. This preserves pre-existing ROCm behavior (no compatibility
checks) until proper ABI versioning is designed for ROCm.

🎯 Justification
Without this change, JAX tests on the ROCm platform fail with:
```bash
jax.errors.JaxRuntimeError: UNIMPLEMENTED: Deriving the executable ABI version from the device description is not implemented for the target platform.
```
🚀 Kind of Contribution
🐛 Bug Fix

🧪 Unit Tests:
Added a ROCm test to //xla/stream_executor/abi:executable_abi_version_test

Copybara import of the project:

--
a24b40a by Manjunath Gaonkar <magaonka@amd.com>:

[ROCm] Add minimal ROCm support to ExecutableAbiVersion

The ExecutableAbiVersion::FromDeviceDescription was only implemented
for CUDA, returning UNIMPLEMENTED for all other platforms. This broke
every ROCm compilation path after 5150bfd introduced mandatory ABI
version derivation in GpuCompiler::RunBackend.

Add a CreateForRocm handler that returns a minimal ABI version with
only platform_name set to "ROCm" and no platform-specific version
info. This preserves pre-existing ROCm behavior (no compatibility
checks) until proper ABI versioning is designed for ROCm.

Add a unit test for the ROCm path.

Merging this change closes openxla#38836

COPYBARA_INTEGRATE_REVIEW=openxla#38836 from magaonka-amd:fix/rocm-executable-abi-version a24b40a
PiperOrigin-RevId: 881376820
…lective call

Imported from GitHub PR openxla#38932

Sometimes an NCCL call can race with a cuBLAS initialization that happens before the collective operation, which leads to deadlocks inside the CUDA driver. The problem is a known CUDA bug: lazy loading of kernels can deadlock when done concurrently with other operations. Adding a rendezvous before and after the first call to collective operations guarantees that NCCL and cuBLAS initialization do not race during the first call to an XLA program. After the first execution, all state is correctly initialized and all thunks can be executed concurrently.

```
Lock Holder — Thread 526 (LWP 223074): deep in libcuda resource management

  #0  libcuda_a39aa780397842086045a7b3bcbbdc8b990f1dc2 ()                   ← STUCK HERE
  #1  libcuda_78bd8cb6e161b9259f49c55b198c2ebb6b6ca573 ()
  #2  libcuda_415ad33a27adbd9e7ac386d1436fb472f9a186a7 ()
  #3  libcuda_6604415c014ac9883d5d8d7bd3eaec3cc6f61dd5 ()
  #4  libcuda_a8b7e1669b1e305583a2c856187aa476dd35a0b6 ()
  #5  libcuda_da6ad49bd9734e0ae200848c0cf7d7cc9853ac2e ()
  #6  libcuda_6c7099fa0c13dab048e8a079f9c9d18af0a9c817 ()
  #7  libcuda_755905f93113a0709483448e98a77986387cc1d2 ()                   ← rwlock acquired here
  #8  libcuda_41c3d2e54bbc2cab9a8c690af0d80fdf54a7bea4 ()
  #9  cuKernelSetAttribute ()
  #10 libcublasLt_4379571754809d5aed869096e4697c8475997ae4 ()
  #11 libcublasLt_461e95d2ff5949df0a29da83ada6c345d9c1ebb0 ()
  #12 libcublasLt_c7ad85ac4de152e94520edbfab34aa5b3ebef088 ()
  #13 libcublasLt_9c76814a7f6623e89ce420723203783b4100e387 ()
  #14 cublasLtTSSMatmul ()
  #15 libcublas_9cdbde38fcd4e96fade46520882865abc4593db8 ()
  #16 libcublas_dcfa9499b939b7c46fe3e9c6c04d82164096c00d ()
  #17 libcublas_cc7e69e32c54abdbb393ffdb179bca77b461347e ()
  #18 libcublas_5c324fb9d1a4eecb9e11bd58d3b1b5be99cdcfd2 ()
  #19 libcublas_c381464be92fab0a975a57336c63cd0671b27bc8 ()
  #20 cublasGemmEx ()
  #21 stream_executor::cuda::CUDABlas::DoBlasGemmWithAlgorithm(...) ()
  #22 xla::gpu::RunGemm(...) ()
  #23 xla::gpu::GemmThunk::ExecuteOnStream(Thunk::ExecuteParams const&) ()
  #24 xla::gpu::SequentialThunk::ExecuteOnStream(...) ()
  #25 xla::gpu::WhileThunk::ExecuteOnStream(...) ()
  #26 xla::gpu::SequentialThunk::ExecuteOnStream(...) ()
  #27 xla::gpu::(anonymous namespace)::ExecuteThunksImpl(...) ()
  #28 xla::gpu::GpuExecutable::ExecuteThunks(...) ()
  #29 xla::StreamExecutorGpuClient::RunAsync(...) ()
  #30 xla::PjRtStreamExecutorRawLoadedExecutable::Execute(...)::$_1::operator()() ()
  #31 xla::PjRtStreamExecutorRawLoadedExecutable::Execute(...) && ()
  #32 xla::CommonPjRtLoadedExecutable::ExecuteLaunch(...) const ()
  #33 absl::internal_any_invocable::RemoteInvoker<..., CommonPjRtLoadedExecutable::Execute(...)::$_3&>(...) ()
  #34 xla::WorkerThread::WorkLoop() ()
  #35 tsl::(anonymous namespace)::PThread::ThreadFn(void*) ()
  #36 start_thread () at pthread_create.c:447
  #37 clone3 () at clone3.S:78

  rwlock Waiters (×7 identical) — all blocked on rwlock 0x2265a900

  #0  __futex_abstimed_wait_common (futex_word=0x2265a90c, expected=3)
  #2  __pthread_rwlock_wrlock_full64 (rwlock=0x2265a900)
  #3  ___pthread_rwlock_wrlock (rwlock=0x2265a900)
  #4  libcuda_755905f93113a0709483448e98a77986387cc1d2 ()
  #5  libcuda_41c3d2e54bbc2cab9a8c690af0d80fdf54a7bea4 ()
  #6  cuKernelSetAttribute ()
  #7-#14 [cublasLt → cublasLtTSSMatmul]
  #15-#20 [cublas → cublasGemmEx]
  #21 stream_executor::cuda::CUDABlas::DoBlasGemmWithAlgorithm(...) ()
  #22 xla::gpu::RunGemm(...) ()
  #23 xla::gpu::GemmThunk::ExecuteOnStream(...) ()
  #24-#37 [SequentialThunk → WhileThunk → ExecuteThunksImpl → GpuExecutable → RunAsync → Execute → WorkLoop]
```
Copybara import of the project:

--
37a1c5d by Eugene Zhulenev <ezhulenev@openxla.org>:

[xla:gpu] Add rendezvous before AND after first collective call

Merging this change closes openxla#38932

COPYBARA_INTEGRATE_REVIEW=openxla#38932 from ezhulenev:rendezvous-around-collective-call 37a1c5d
PiperOrigin-RevId: 881376834
…rk registration

Instead of an index.

PiperOrigin-RevId: 881383226
Lifting and deduping work with non-inlined modules, as they do a pre-order walk over the module op's operations:
- FuncOp: The walk already special-cases the shardings of arguments and results.
- CallOp: The walk handles the shardings of call results like any other operation.
- ManualComputationOp: At this point, manual computations are still converted into ManualComputation ops. The walk handles their shardings as before.
- NamedComputationOp: The walk has special handling for the shardings of named computations (as they are shardable data-flow ops). After this change, the non-main functions are no longer named computations, so the calls need to be handled as FuncOps and CallOps instead of NamedComputation ops.

PiperOrigin-RevId: 881398465
…le fingerprinting logic

PiperOrigin-RevId: 881398989
…ehlo round trip.

ShardMapImport pass does not expect to have 'inlineable' manual computations.

Also: delete the export+import test. It is confusing, does not align with production, and is not strictly supported. For example, import does not expect inlineable manual computations, while export adds them.
PiperOrigin-RevId: 881412124
Remove the long timeout for all convolution tests.

PiperOrigin-RevId: 881418969
const GpuCliqueKey& clique_key,
se::Stream& stream,
Communicator& comm) override;
bool RequiresRendezvous() const override { return true; }

nit: Unnecessary rendezvous on collective-kernel path

RequiresRendezvous() unconditionally returns true, but the old RunCollective returned false when use_collective_kernel was true (lines 206-208 of all_reduce_thunk.cc). The collective kernel path doesn't use NCCL and does its own synchronization.

With this change, the first call will now perform both pre-call and post-call rendezvous even on the collective kernel path. This won't deadlock (all participants reach the rendezvous), but it adds unnecessary synchronization overhead on the first collective kernel invocation.

Consider:

```diff
-bool RequiresRendezvous() const override { return true; }
+bool RequiresRendezvous() const override;
```

…with an implementation that checks use_collective_kernel at runtime, similar to the old behavior.

Communicator& comm) override;
// No rendezvous needed when using P2P memcpy in local mode instead of NCCL.
bool RequiresRendezvous() const override { return !p2p_memcpy_enabled_; }


potential bug: Rendezvous skipped when p2p_memcpy enabled but operation is non-local

The old RunCollective returned true (rendezvous needed) when falling through to the NCCL path at line 258 of all_to_all_thunk.cc, and false only when actually using the memcpy path. The decision was made at runtime based on is_local(...) && p2p_memcpy_enabled_.

The new RequiresRendezvous() returns !p2p_memcpy_enabled_, which means when p2p_memcpy_enabled_ is true but the operation is non-local (falls through to NCCL), the first-call rendezvous is skipped. This could re-introduce the NCCL deadlock that the rendezvous was designed to prevent.

Should this check whether the operation will actually use the memcpy path (i.e., !(is_local(...) && p2p_memcpy_enabled_)) rather than just !p2p_memcpy_enabled_?

Communicator& comm) override;
// No rendezvous needed when using P2P memcpy in local mode instead of NCCL.
bool RequiresRendezvous() const override { return !p2p_memcpy_enabled_; }


potential bug: Same rendezvous concern as AllToAllStartThunk

The old RunCollective returned !use_memcpy where use_memcpy = is_local_peer && recv_ptr_map_.IsInitialized(current_id) && p2p_memcpy_enabled_ — a runtime decision based on three conditions.

The new RequiresRendezvous() only checks !p2p_memcpy_enabled_. When p2p_memcpy_enabled_ is true but either the peer is non-local or recv_ptr_map_ is not initialized for the current ID, the code falls through to the NCCL path (line 382 of collective_permute_thunk.cc) without the first-call rendezvous. This could lead to NCCL deadlocks in those cases.

label, kind(), params.collective_params->run_id.ToInt(),
num_local_participants, rank.value(), clique_key.ToString());

auto rendezvous_key = FirstCallRendezvousKey{clique_key};

nit: Pre-call and post-call rendezvous share the same key type

Both invocations at lines 417 and 440 use FirstCallRendezvousKey{clique_key}. The isolation relies solely on separate RendezvousFlag objects (pre_call_rendezvous_flag_ vs post_call_rendezvous_flag_). This works correctly, but using distinct key types (e.g., PreCallRendezvousKey / PostCallRendezvousKey) would make the intent clearer and be more robust against future refactors.

device_ordinal, label, kind(),
params.collective_params->run_id.ToInt());

const xla::DebugOptions debug_options = xla::GetDebugOptionsFromFlags();

nit: GetDebugOptionsFromFlags() called twice on first execution

This is now invoked inside the lambda, which is called twice per ExecuteOnStream (lines 417 and 440) on the first execution. The old code called it once. Consider hoisting the call above the lambda or caching the result to avoid redundant flag parsing.

@claude

claude bot commented Mar 10, 2026

Review Summary: This PR integrates Triton version 5a9cfa6ab with upstream XLA changes. Key concern: the refactored RequiresRendezvous() in AllToAllStartThunk and CollectivePermuteStartThunk may skip first-call NCCL rendezvous when p2p_memcpy_enabled_ is true but the operation falls through to NCCL (non-local topology). See inline comments for details.


8 participants