Refactor CCL APIs to align with torch.distributed conventions #326
Conversation
This refactor reorders parameters and adds support for process groups.

API Changes:
- all_reduce: (out, in, op=SUM, group=None, async_op=False, config=None, workspace=None)
- reduce_scatter: (out, in, op=SUM, group=None, async_op=False, config=None)
- all_gather: (out, in, group=None, async_op=False, config=None)
- all_to_all: (out, in, group=None, async_op=False, config=None)

New Features:
- Add a ReduceOp enum (SUM, PRODUCT, MIN, MAX, etc.) matching torch.distributed
- Add an extract_group_info() helper to extract rank_start/rank_stride from a ProcessGroup
- Support strided process groups (e.g., TP groups [0, 1, 2, 3] or DP groups [0, 4, 8, 12])
- The op parameter validates that only SUM is used (other ops to be added later)

Kernel Changes:
- All CCL kernels now accept rank_start and rank_stride constexpr parameters
- Kernel loops updated to iterate using group-aware rank calculation
- Ring all-reduce computes next_rank on the host side for group support

Backward Compatibility:
- Existing code using keyword arguments (config=...) continues to work
- torch.distributed-compatible parameter ordering (group before config)
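To make the new parameter ordering concrete, here is a minimal usage sketch. The iris context constructor and the zeros() tensor helper are assumptions for illustration; only the collective signatures and iris.ccl.ReduceOp follow the API described above.

```python
import torch
import torch.distributed as dist

import iris

# Illustrative setup; the actual iris initialization and tensor helpers may differ.
dist.init_process_group(backend="nccl")
ctx = iris.iris()                           # assumed shared-heap context constructor
tp_group = dist.new_group([0, 1, 2, 3])     # TP-style subgroup

inp = ctx.zeros(1024, dtype=torch.float32)  # assumed heap-allocating helper
out = ctx.zeros(1024, dtype=torch.float32)

# torch.distributed-style ordering: (out, in, op, group, async_op, config)
ctx.all_reduce(out, inp, op=iris.ccl.ReduceOp.SUM, group=tp_group)

# Keyword-argument callers keep working (backward compatibility):
ctx.all_reduce(out, inp, config=None)
```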
Pull request overview
This pull request refactors the CCL (Collective Communication Library) APIs to align with torch.distributed conventions by reordering parameters and adding support for process groups. However, the implementation contains several critical bugs that prevent process groups from working correctly.
Changes:
- Adds ReduceOp enum matching torch.distributed semantics
- Reorders API parameters to match torch.distributed: (out, in, op, group, async_op, config)
- Adds extract_group_info() helper to extract rank/stride information from ProcessGroup (see the sketch after this list)
- Updates all CCL kernels to accept rank_start and rank_stride parameters for group-aware rank calculation
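The extract_group_info() helper is not shown in this summary; based on the description, it plausibly reduces to the sketch below. The exact signature (passing world_size explicitly), the returned group size, and the strided-group validation are assumptions.

```python
import torch.distributed as dist


def extract_group_info(group, world_size):
    """Return (rank_start, rank_stride, group_size) for a process group (assumed shape).

    Assumes the group's global ranks form an arithmetic sequence, e.g. a TP
    group [0, 1, 2, 3] (stride 1) or a DP group [0, 4, 8, 12] (stride 4).
    """
    if group is None:
        return 0, 1, world_size  # default: all ranks, contiguous
    ranks = dist.get_process_group_ranks(group)
    rank_start = ranks[0]
    rank_stride = ranks[1] - ranks[0] if len(ranks) > 1 else 1
    # Fail loudly on irregular groups that strided kernels cannot express.
    expected = [rank_start + i * rank_stride for i in range(len(ranks))]
    if ranks != expected:
        raise ValueError(f"Process group ranks {ranks} are not evenly strided")
    return rank_start, rank_stride, len(ranks)
```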
Reviewed changes
Copilot reviewed 8 out of 8 changed files in this pull request and generated 14 comments.
| File | Description |
|---|---|
| iris/iris.py | Updated all CCL method signatures to add op and group parameters with reordered arguments |
| iris/experimental/iris_gluon.py | Updated CCL method signatures (missing group parameter in all_gather) |
| iris/ccl/__init__.py | Added ReduceOp to exports |
| iris/ccl/utils.py | Added ReduceOp enum and extract_group_info() helper function |
| iris/ccl/all_reduce.py | Updated kernels and function to support group parameters with rank_start/rank_stride |
| iris/ccl/reduce_scatter.py | Updated kernel and function to support group parameters |
| iris/ccl/all_gather.py | Updated kernel and function to support group parameters |
| iris/ccl/all_to_all.py | Updated Triton and Gluon kernels and function to support group parameters |
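For reference, a plausible shape of the ReduceOp enum and the SUM-only validation listed for iris/ccl/utils.py, sketched by analogy with torch.distributed.ReduceOp; the member values and the _validate_op helper name are assumptions.

```python
from enum import Enum


class ReduceOp(Enum):
    SUM = 0
    PRODUCT = 1
    MIN = 2
    MAX = 3


def _validate_op(op: ReduceOp) -> None:
    # Only SUM is implemented by the current kernels; other ops are planned.
    if op != ReduceOp.SUM:
        raise NotImplementedError(f"ReduceOp {op} is not supported yet; only SUM is implemented")
```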
rank_start: tl.constexpr,
rank_stride: tl.constexpr,
Copilot AI · Jan 23, 2026
Critical bug: the cur_rank parameter's semantics are inconsistent with process groups. When using process groups, cur_rank (passed from lines 787, 816, 858, 885, and 911) is the rank_in_group, but the iris.atomic_add, iris.load, iris.store, and iris.atomic_cas operations throughout the kernels expect global ranks for heap_bases indexing and pointer translation. Consider adding cur_rank_global as a separate constexpr parameter.
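In plain Python, the distinction the comment draws looks like this: the group-local index must be translated to a global rank before it can index heap_bases. This is a hypothetical illustration, not code from the PR.

```python
def global_rank_of(rank_in_group: int, rank_start: int, rank_stride: int) -> int:
    # Group-aware rank calculation used conceptually by the refactored kernels:
    # heap_bases and pointer translation are indexed by *global* rank, so a
    # group-local index must be mapped through rank_start/rank_stride first.
    return rank_start + rank_in_group * rank_stride


# Example: DP group [0, 4, 8, 12] -> rank_start=0, rank_stride=4
assert global_rank_of(2, rank_start=0, rank_stride=4) == 8
```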
rank_start: tl.constexpr,
rank_stride: tl.constexpr,
Copilot AI · Jan 23, 2026
Critical bug: the cur_rank parameter's semantics are inconsistent with process groups. When using process groups, cur_rank (passed from line 235) is the rank_in_group, but the iris.load operations (lines 108, 112, 126, and 130) expect global ranks for heap_bases indexing and pointer translation. Consider adding cur_rank_global as a separate constexpr parameter.
Pull request overview
Copilot reviewed 9 out of 9 changed files in this pull request and generated 3 comments.
assert my_group is not None, f"Rank {rank} not in any group"

group_ranks = dist.get_process_group_ranks(my_group)
group_size = len(group_ranks)
Copilot AI · Jan 24, 2026
Variable group_size is not used.
group_size = len(group_ranks)
assert my_group is not None

group_ranks = dist.get_process_group_ranks(my_group)
group_size = len(group_ranks)
Copilot AI · Jan 24, 2026
Variable group_size is not used.
group_size = len(group_ranks)