
Support torch.distributed.scatter collective #9365


Open · wants to merge 7 commits into base: master

Conversation

bfolie (Collaborator) commented Jun 16, 2025

#9315

XLA doesn't have a distributed Scatter op, but we can put dummy tensor lists on the non-source ranks and use reduce_scatter.
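A minimal sketch of the idea, for illustration only: it uses the public torch.distributed API and a hypothetical helper name, while the actual change lives in ProcessGroupXLA.scatter and may differ in detail. Ranks other than src contribute all-zero inputs, so a summed reduce_scatter leaves each rank holding exactly the slice the source provided for it.

import torch
import torch.distributed as dist

def scatter_via_reduce_scatter(output, input_list, src, rank, world_size):
    # Non-source ranks supply dummy (all-zero) tensors shaped like the output.
    if rank != src:
        input_list = [torch.zeros_like(output) for _ in range(world_size)]
    # Only the source rank contributes non-zero data, so the summed
    # reduce_scatter writes src's i-th tensor into rank i's output.
    dist.reduce_scatter(output, input_list, op=dist.ReduceOp.SUM)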

bfolie requested review from bhavya01 and pgmoka on Jun 16, 2025 at 17:02
@@ -360,7 +360,6 @@ def test_barrier(self):
'allreduce_coalesced',
'alltoall',
'gather',
'scatter',
bfolie (Collaborator, Author):

I'm not sure if there's a reason to add a scatter test to this file. The test in test/pjrt/test_collective_ops_tpu.py is more robust in that it tests the actual result. The tests in this file just check that the IR looks correct, which can be misleading (as was the case for send/recv).

pgmoka (Collaborator):

I see other operation calls to things like 'gather' and 'alltoall'. What is the reasoning to keep them and remove scatter?

Would it perhaps be better to improve the test documentation of what it does, rather than remove 'scatter'?

bfolie (Collaborator, Author):

This test checks that the methods are unimplemented for XLAProcessGroup. scatter is now implemented, so it is being removed.

What I'm wondering is whether there's a reason to add a new test to this file that checks if group.scatter outputs the expected HLO. That's what other tests in this file do, but it's not clear to me what value they add beyond the existing tests in test_collective_ops_tpu.

pgmoka (Collaborator):

Which tests are calling this function in test/pjrt/test_collective_ops_tpu.py? The tests I see there are for reduce_scatter and ReduceScatter, which seem to me to be higher-level abstractions with other things happening. Perhaps I am missing something.

bfolie (Collaborator, Author):

The test added in this PR, test_collective_ops_tpu:test_scatter, calls torch.dist.scatter, which calls ProcessGroupXLA.scatter.

Abstracting things a bit, we have a function A (torch.dist.scatter) that wraps function B (ProcessGroupXLA.scatter). We have a test for A. Should we have a test for B as well? In many cases the answer is yes, especially if B is used in multiple places, if the test for B is logically self-contained and informative, if A adds significant additional logic that we want to test without having to think about B, etc. In this case I'm advocating for not testing B, because:

  • The test for B is unreliable (as we saw for send/recv, the IR might look reasonable but not work)
  • A is a fairly thin wrapper around B
  • The contents of A are in upstream PT, so we're not testing them independently of B
  • B isn't used anywhere else, nor would it be directly called by other code

What do you think?
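(For concreteness, a sketch of the A-over-B layering being discussed, based on the test code quoted further down; the receive buffer shape and the helper name are assumptions, and the real test_scatter may differ.)

import torch
import torch.distributed as dist
import torch_xla
import torch_xla.runtime as xr

def _scatter_example():
    # A: torch.distributed.scatter, the public API the test exercises ...
    dist.init_process_group("xla", init_method="xla://")
    device = torch_xla.device()
    world_size = xr.world_size()
    tensors = None
    if xr.global_ordinal() == 0:
        # Only the source rank supplies the list of tensors to scatter.
        tensors = [
            torch.tensor([i], device=device, dtype=torch.float)
            for i in range(world_size)
        ]
    output = torch.zeros(1, device=device, dtype=torch.float)
    # ... which dispatches to B: ProcessGroupXLA.scatter, because the
    # process group was initialized with the "xla" backend.
    dist.scatter(output, tensors, src=0)
    return output.cpu()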

pgmoka (Collaborator):

Ah, I see. I misunderstood a couple of things. The clarification here helps a lot. It is preferable to test the calling API of a method rather than the method itself. Given that torch.dist.scatter serves as the external API to ProcessGroupXLA.scatter, it is reasonable to test only it.

I can see this popping up as an issue on some coverage tests, so I would add an explicit comment to the tests in addition to your comment on torch.distributed.scatter.

pgmoka (Collaborator) left a comment:

Mostly questions, and a request for extra documentation.


bfolie requested review from pgmoka and ghpvnist on Jun 18, 2025 at 17:32
dist.init_process_group("xla", init_method='xla://')
device = torch_xla.device()
world_size = xr.world_size()
if xr.global_ordinal() == 0:
Collaborator:

Minor readability improvement:

tensors = None
if xr.global_ordinal() == 0:
    tensors = [
        torch.tensor([i], device=device, dtype=torch.float)
        for i in range(world_size)
    ]


pgmoka (Collaborator) left a comment:

Follow-up seems good. Let me know if you have any questions on https://github.com/pytorch/xla/pull/9365/files#r2151351304.

Otherwise, LGTM.

One minor thing: I believe the failing tests are due to flakiness. Can you confirm?
