3 changes: 2 additions & 1 deletion iris/ccl/__init__.py
@@ -11,5 +11,6 @@
 """

 from .config import Config
+from .utils import ReduceOp

-__all__ = ["Config"]
+__all__ = ["Config", "ReduceOp"]
48 changes: 35 additions & 13 deletions iris/ccl/all_gather.py
@@ -10,6 +10,7 @@
 import triton.language as tl
 import iris
 from .config import Config
+from .utils import extract_group_info


 @triton.jit()
@@ -24,7 +25,10 @@ def persistent_all_gather(
     stride_out_n,
     heap_bases: tl.tensor,
     cur_rank: tl.constexpr,
+    cur_rank_global: tl.constexpr,
     world_size: tl.constexpr,
+    rank_start: tl.constexpr,
+    rank_stride: tl.constexpr,
Copilot AI commented on lines +30 to +31 (Jan 23, 2026):

Critical bug: the semantics of the cur_rank parameter are inconsistent with process groups. When using process groups, cur_rank (passed from line 209) is rank_in_group, but iris.store operations (line 134) expect global ranks for heap_bases indexing and pointer translation. Consider adding cur_rank_global as a separate constexpr parameter.
     BLOCK_SIZE_M: tl.constexpr,
     BLOCK_SIZE_N: tl.constexpr,
     GROUP_SIZE_M: tl.constexpr,
@@ -47,8 +51,9 @@ def persistent_all_gather(
         stride_in_m, stride_in_n: Strides for input tensor
         stride_out_m, stride_out_n: Strides for output tensor
         heap_bases: Heap base pointers for all ranks
-        cur_rank: Current rank
-        world_size: Total number of ranks
+        cur_rank: Current rank within the group (for comparisons)
+        cur_rank_global: Global rank (for iris IPC operations)
+        world_size: Total number of ranks in the group
         BLOCK_SIZE_M, BLOCK_SIZE_N: Block sizes for tiling
         GROUP_SIZE_M: Group size for M dimension tiling
         COMM_SMS: Number of SMs for communication
@@ -100,7 +105,9 @@ def persistent_all_gather(

     # Send local shard data to all destination ranks
     # Each rank's input goes to output[cur_rank * M : (cur_rank + 1) * M, :] on all ranks
-    for rank in tl.static_range(world_size):
+    for i in tl.static_range(world_size):
+        target_rank = rank_start + i * rank_stride
+
         # Compute global output row indices: offset by cur_rank * M
         rm_output = rm_input + cur_rank * M
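The target_rank arithmetic assumes the group's global ranks form an evenly strided sequence; a hypothetical mapping, just to show the intent:

# Hypothetical strided group of global ranks [1, 3, 5, 7], so rank_start = 1 and rank_stride = 2.
rank_start, rank_stride, world_size = 1, 2, 4
targets = [rank_start + i * rank_stride for i in range(world_size)]
assert targets == [1, 3, 5, 7]  # each group index i maps back to a global rank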

@@ -117,22 +124,30 @@ def persistent_all_gather(
         output_ptr_target = output_ptr + output_offset
         output_ptr_target = tl.multiple_of(output_ptr_target, (BLOCK_SIZE_M, BLOCK_SIZE_N))

-        if rank == cur_rank:
-            # Local destination: use direct store
+        if i == cur_rank:
+            # Local destination (i == rank_in_group): use direct store
             tl.store(output_ptr_target, data, mask=combined_mask, cache_modifier=".wt")
         else:
             # Remote destination: use iris.store to send data to remote destination
+            # Use cur_rank_global for iris IPC operations
             iris.store(
                 output_ptr_target,
                 data,
-                cur_rank,
-                rank,
+                cur_rank_global,
+                target_rank,
                 heap_bases,
                 mask=combined_mask,
             )


-def all_gather(output_tensor, input_tensor, shmem, config=None, async_op=False):
+def all_gather(
+    output_tensor,
+    input_tensor,
+    shmem,
+    group=None,
+    async_op=False,
+    config=None,
+):
     """
     Internal all-gather collective operation implementation.

@@ -148,10 +163,12 @@ def all_gather(output_tensor, input_tensor, shmem, config=None, async_op=False):
         output_tensor: Output tensor of shape (world_size * M, N) - will contain concatenated inputs
         input_tensor: Input tensor of shape (M, N) - local rank's data to send
         shmem: Iris shmem context
-        config: Config instance with kernel parameters (default: None).
-            If None, uses default Config values.
+        group: ProcessGroup or None. If None, uses all ranks in shmem context.
+            Default: None.
         async_op: If False, performs a barrier at the end. If True, returns immediately.
             Default: False.
+        config: Config instance with kernel parameters (default: None).
+            If None, uses default Config values.
     """
     # Use provided config or create default one
     if config is None:
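A call-shape sketch for the new keyword order follows; the shmem context and the heap-resident tensors are assumed to have been set up elsewhere (that setup is not part of this diff), so they appear only as placeholder names here:

# Sketch only: shmem, inp (M, N) and out (world_size * M, N) are assumed to exist
# on the Iris symmetric heap; their allocation is outside this diff.
all_gather(out, inp, shmem, group=None, async_op=False)      # old behavior: gather across all ranks
all_gather(out, inp, shmem, group=my_group, async_op=True)   # restrict to my_group, skip the final barrier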
@@ -165,8 +182,10 @@ def all_gather(output_tensor, input_tensor, shmem, config=None, async_op=False):
             "Use default config (use_gluon=False)."
         )

-    rank = shmem.get_rank()
-    world_size = shmem.get_num_ranks()
+    # Extract group information
+    # rank_in_group: position within the group (0, 1, 2, ...) - used for comparisons
+    # rank_global: global rank across all processes - used for iris IPC operations
+    rank_in_group, rank_global, world_size, rank_start, rank_stride = extract_group_info(group, shmem)

     M, N = input_tensor.shape[:2]
     expected_output_shape = (world_size * M, N)
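extract_group_info itself is not shown in this diff; based on how its return values are used at this call site and in the kernel, a plausible sketch looks like the following (the group.ranks attribute and the evenly-strided-group assumption are mine, not the PR's):

def extract_group_info(group, shmem):
    # Hypothetical sketch: resolve (rank_in_group, rank_global, world_size, rank_start, rank_stride).
    rank_global = shmem.get_rank()
    if group is None:
        # Default group: all ranks, contiguous numbering starting at 0.
        return rank_global, rank_global, shmem.get_num_ranks(), 0, 1
    # Assumed: the group exposes its members as an evenly strided list of global ranks.
    ranks = list(group.ranks)
    rank_start = ranks[0]
    rank_stride = ranks[1] - ranks[0] if len(ranks) > 1 else 1
    return ranks.index(rank_global), rank_global, len(ranks), rank_start, rank_stride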
@@ -192,8 +211,11 @@ def all_gather(output_tensor, input_tensor, shmem, config=None, async_op=False):
         stride_out_m,
         stride_out_n,
         heap_bases,
-        rank,
+        rank_in_group,
+        rank_global,
         world_size,
+        rank_start,
+        rank_stride,
         config.block_size_m,
         config.block_size_n,
         config.swizzle_size,