Support torch.distributed.scatter collective #9365
base: master
Conversation
@@ -360,7 +360,6 @@ def test_barrier(self):
     'allreduce_coalesced',
     'alltoall',
     'gather',
-    'scatter',
I'm not sure if there's a reason to add a scatter test to this file. The test in test/pjrt/test_collective_ops_tpu.py is more robust in that it tests the actual result. The tests in this file just check that the IR looks correct, which can be misleading (as was the case for send/recv).
I see other operation calls to things like 'gather' and 'alltoall'. What is the reasoning for keeping them while removing 'scatter'?
Would it perhaps be better to improve the test documentation on what it does rather than remove 'scatter'?
This test checks that the methods are unimplemented for XLAProcessGroup. scatter is now implemented, so it is being removed.
What I'm wondering is whether there's a reason to add a new test to this file that checks if group.scatter outputs the expected HLO. That's what other tests in this file do, but it's not clear to me what value they add beyond the existing tests in test_collective_ops_tpu.
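For context, the IR-style checks in this file boil down to dumping the pending HLO text for a lazy XLA tensor and grepping it for the expected op. A minimal sketch of that mechanism (not an actual test from the file; it assumes the torch_xla._XLAC._get_xla_tensors_hlo binding, and the op and regex are illustrative):

import re

import torch
import torch_xla
import torch_xla.core.xla_model as xm

# Build a lazy op on the XLA device and inspect the HLO it would emit.
device = xm.xla_device()
a = torch.ones(4, device=device)
b = a + a  # still pending execution

hlo = torch_xla._XLAC._get_xla_tensors_hlo([b])
# The check only asserts that the expected op name appears in the IR text,
# which is exactly why it can look right but still not work at runtime.
assert re.search(r'add', hlo)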
Which tests are calling this function in test/pjrt/test_collective_ops_tpu.py? The tests I see there are for reduce_scatter and ReduceScatter, which seem to me to be higher-level abstractions with other things happening. Perhaps I am missing something.
The test added in this PR, test_collective_ops_tpu:test_scatter, calls torch.dist.scatter, which calls ProcessGroupXLA.scatter.
Abstracting things a bit, we have a function A (torch.dist.scatter) that wraps function B (ProcessGroupXLA.scatter). We have a test for A. Should we have a test for B as well? In many cases the answer is yes, especially if B is used in multiple places, if the test for B is logically self-contained and informative, if A adds significant additional logic that we want to test without having to think about B, etc. In this case I'm advocating for not testing B, because:
- The test for B is unreliable (as we saw for send/recv, the IR might look reasonable but not work)
- A is a fairly thin wrapper around B
- The contents of A are in upstream PT, so we're not testing them independently of B
- B isn't used anywhere else, nor would it be directly called by other code
What do you think?
Ah, I see. I misunderstood a couple of things. The clarification here helps a lot. It is preferable to test the calling API to a method rather than the method itself. Given that torch.dist.scatter is serving as the external API to ProcessGroupXLA.scatter, it is reasonable to only test it.
I can see this popping up as an issue in some coverage tests, so I would add an explicit comment to the tests in addition to your comment on torch.distributed.scatter.
Mostly questions, and a request for extra documentation.
dist.init_process_group("xla", init_method='xla://')
device = torch_xla.device()
world_size = xr.world_size()
if xr.global_ordinal() == 0:
Minor readability improvement:
tensors = None
if xr.global_ordinal() == 0:
  tensors = [
      torch.tensor([i], device=device, dtype=torch.float)
      for i in range(world_size)
  ]
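Putting that suggestion together with the snippet above, the per-rank body of such a scatter test would look roughly like the following (a sketch of the shape being discussed, not the PR's actual test; the output shape and the final check are assumptions):

import torch
import torch.distributed as dist
import torch_xla
import torch_xla.runtime as xr

def _scatter():
  # Per-rank body, run under a multiprocess harness.
  dist.init_process_group("xla", init_method='xla://')
  device = torch_xla.device()
  world_size = xr.world_size()

  # Only the source rank needs a scatter list; other ranks pass None.
  tensors = None
  if xr.global_ordinal() == 0:
    tensors = [
        torch.tensor([i], device=device, dtype=torch.float)
        for i in range(world_size)
    ]

  output = torch.zeros(1, device=device, dtype=torch.float)
  dist.scatter(output, tensors, src=0)
  # Each rank should end up with the tensor holding its own ordinal.
  return output.cpu()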
Follow-up seems good. Let me know if you have any questions on https://github.com/pytorch/xla/pull/9365/files#r2151351304.
Otherwise, LGTM.
One minor thing: I believe the failing tests are due to flakiness. Can you confirm?
#9315
XLA doesn't have a distributed Scatter op, but we can put dummy tensor lists on the non-source ranks and use reduce_scatter.
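To illustrate the idea in the description, here is a hand-written sketch of emulating scatter with reduce_scatter at the torch.distributed level (not the PR's actual implementation; the zero-filled dummy inputs on non-source ranks are the key assumption):

import torch
import torch.distributed as dist

def scatter_via_reduce_scatter(output, scatter_list, src=0):
  # Emulate scatter: every rank except `src` contributes all-zero dummy
  # inputs, so the summed reduce_scatter leaves exactly the source rank's
  # i-th tensor on rank i.
  world_size = dist.get_world_size()
  if dist.get_rank() == src:
    inputs = list(scatter_list)
  else:
    inputs = [torch.zeros_like(output) for _ in range(world_size)]
  dist.reduce_scatter(output, inputs, op=dist.ReduceOp.SUM)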